1"""
2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
3with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg)
4"""
5
6# Owner(s): ["module: dynamo"]
7import collections
8import contextlib
9import copy
10import dataclasses
11import functools
12import gc
13import inspect
14import itertools
15import os
16import random
17import unittest
18import warnings
19import weakref
20from abc import ABC
21from collections import namedtuple
22from copy import deepcopy
23from enum import Enum
24from functools import wraps
25from typing import Any, Dict, Iterator, List, Tuple
26from unittest import mock
27
28import numpy as np
29
30import torch
31import torch._dynamo.test_case
32import torch._dynamo.testing
33import torch._dynamo.utils
34import torch._functorch.config
35import torch.library
36import torch.utils._pytree as pytree
37from torch import nn
38from torch._dynamo.debug_utils import same_two_models
39from torch._dynamo.testing import CompileCounter, rand_strided, same
40from torch._inductor.utils import fresh_inductor_cache
41from torch.nn import functional as F
42from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
43from torch.testing._internal.common_utils import (
44    disable_translation_validation_if_dynamic_shapes,
45    instantiate_parametrized_tests,
46    parametrize,
47    skipIfWindows,
48    TEST_WITH_ROCM,
49)
50from torch.testing._internal.two_tensor import TwoTensor
51
52
53_orig_module_call = torch.nn.Module.__call__
54
55# Custom operator that only supports CPU and Meta
56lib = torch.library.Library("test_sample", "DEF")  # noqa: TOR901
57lib.define("foo(Tensor self) -> Tensor")
58lib.impl("foo", torch.sin, "CPU")
59
60
61requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
62
63
64_GLOBAL_CPU_TENSOR = torch.randn(3)
65
66
67def exists(val):
68    return val is not None
69
70
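# Decorator helper: wraps `fn` so that a `None` first argument is returned
# unchanged instead of being passed to `fn`.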
def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)

    return inner


def is_fx_tracing_test() -> bool:
    """
    Copied from the hpc trainer codebase
    """
    return torch.nn.Module.__call__ is not _orig_module_call


def has_detectron2():
    try:
        from detectron2.layers.mask_ops import _paste_masks_tensor_shape

        return _paste_masks_tensor_shape is not None
    except ImportError:
        return False


def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    # from detectron2 mask_ops.py
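    # Pastes the given masks into an img_h x img_w canvas via grid_sample.
    # With skip_empty=True, only the tight bounding region of the boxes is
    # rendered, and the slices locating that region are returned as well.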

    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
            dtype=torch.int32
        )
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
            dtype=torch.int32
        )
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()


def global_fn(x):
    return torch.sin(x)


def cat(tensors, dim=0):
    # from detectron2 wrappers.py
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def shapes_to_tensor(x, device=None):
    # from detectron2 wrappers.py
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            isinstance(t, torch.Tensor) for t in x
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)


fw_graph = [None]
bw_graph = [None]


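# Test backend that runs aot_autograd and stashes the captured forward and
# backward graphs into the module-level fw_graph / bw_graph lists above.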
def aot_graph_capture_backend(gm, args):
    from functorch.compile import min_cut_rematerialization_partition
    from torch._functorch.aot_autograd import aot_module_simplified

    def fw_compiler(gm, _):
        fw_graph[0] = gm
        return gm

    def bw_compiler(gm, _):
        bw_graph[0] = gm
        return gm

    return aot_module_simplified(
        gm,
        args,
        fw_compiler,
        bw_compiler,
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )


class Boxes:
    # from detectron2 poolers.py
    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix.  Each row is (x1, y1, x2, y2).
        """
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
        self.tensor = tensor

    def __len__(self) -> int:
        return self.tensor.shape[0]

    @property
    def device(self):
        return self.tensor.device


def convert_boxes_to_pooler_format(box_lists):
    # from detectron2 structures.py
    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
    # __len__ returns Tensor in tracing.
    sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
    indices = torch.repeat_interleave(
        torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
    )
    return cat([indices[:, None], boxes], dim=1)


ReformerBackwardOutput = namedtuple(
    "ReformerBackwardOutput",
    ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
)
ReformerEncoderOutput = namedtuple(
    "ReformerEncoderOutput",
    ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)


class _ReversibleFunction(torch.autograd.Function):
    # taken from modeling_reformer.py in huggingface
    @staticmethod
    def forward(
        ctx,
        hidden_states,
        layers,
        attention_mask,
        head_mask,
        num_hashes,
        all_hidden_states,
        all_attentions,
        past_buckets_states,
        use_cache,
        orig_sequence_length,
        output_hidden_states,
        output_attentions,
    ):
        all_buckets = ()

        # split duplicated tensor
        hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)

        for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
            if output_hidden_states is True:
                all_hidden_states.append(hidden_states)

            attn_output = layer(attn_output)
            all_buckets = all_buckets + (attn_output,)

        # Add last layer
        if output_hidden_states is True:
            all_hidden_states.append(hidden_states)

        # attach params to ctx for backward
        ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
        ctx.layers = layers
        ctx.all_buckets = all_buckets
        ctx.head_mask = head_mask
        ctx.attention_mask = attention_mask

        # Concatenate 2 RevNet outputs
        return torch.cat([attn_output, hidden_states], dim=-1)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        grad_attn_output, grad_hidden_states = torch.chunk(
            grad_hidden_states, 2, dim=-1
        )

        # free memory
        del grad_attn_output

        # num of return vars has to match num of forward() args
        # return gradient for hidden_states arg and None for other args
        return (
            grad_hidden_states,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class ReformerEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dropout = 0.5
        self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
        self.layers = [torch.nn.Linear(256, 256)]

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=[None] * 6,
        num_hashes=None,
        use_cache=False,
        orig_sequence_length=64,
        output_hidden_states=False,
        output_attentions=False,
    ):
        # hidden_states and attention lists to be filled if wished
        all_hidden_states = []
        all_attentions = []
        past_buckets_states = [((None), (None)) for i in range(len(self.layers))]

        # concat same tensor for reversible ResNet
        hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
        hidden_states = _ReversibleFunction.apply(
            hidden_states,
            self.layers,
            attention_mask,
            head_mask,
            num_hashes,
            all_hidden_states,
            all_attentions,
            past_buckets_states,
            use_cache,
            orig_sequence_length,
            output_hidden_states,
            output_attentions,
        )

        # Apply layer norm to concatenated hidden states
        hidden_states = self.layer_norm(hidden_states)

        # Apply dropout
        hidden_states = torch.nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        return ReformerEncoderOutput(
            hidden_states=hidden_states,
            all_hidden_states=all_hidden_states,
            all_attentions=all_attentions,
            past_buckets_states=past_buckets_states,
        )


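# Minimal stand-in for omegaconf's ListConfig: iterating it resolves each
# stored ValueNode to its underlying value.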
class ListConfig:
    class ValueNode:
        def __init__(self, value):
            self.value = value

        def _dereference_node(self):
            return self

        def _is_missing(self):
            return False

        def _value(self):
            return self.value

    # Based on an example from omegaconf.listconfig
    class ListIterator(Iterator[Any]):
        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            self.index = 0

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise AssertionError

            self.index = self.index + 1
            if isinstance(x, ListConfig.ValueNode):
                return x._value()
            raise AssertionError

    def __iter__(self):
        return self._iter_ex(True)

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        try:
            return ListConfig.ListIterator(self, resolve)
        except Exception:
            raise AssertionError from None

    def __init__(self) -> None:
        self._content = [
            ListConfig.ValueNode(1),
            ListConfig.ValueNode(3),
            ListConfig.ValueNode(torch.tensor([7.0])),
        ]


def longformer_chunk(hidden_states, window_overlap=256):
    """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""

    # non-overlapping chunks of size = 2w
    hidden_states = hidden_states.view(
        hidden_states.size(0),
        hidden_states.size(1) // (window_overlap * 2),
        window_overlap * 2,
        hidden_states.size(2),
    )

    # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)


class PartialT5(torch.nn.Module):
    # Highly simplified T5Attention prefix
    def __init__(self) -> None:
        super().__init__()
        self.q = torch.nn.Linear(512, 512)
        self.k = torch.nn.Linear(512, 512)
        self.v = torch.nn.Linear(512, 512)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        query_length=None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, 8, 64).transpose(1, 2)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        # (truncated here )
        return scores, value_states


class ChunkReformerFeedForward(torch.nn.Module):
    # simplified from HF modeling_reformer.py
    def __init__(self) -> None:
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
        self.dense = torch.nn.Linear(256, 256)
        self.output = torch.nn.Linear(256, 256)

    def forward(self, attention_output):
        return apply_chunking_to_forward(
            self.forward_chunk,
            attention_output + 1,
        )

    def forward_chunk(self, hidden_states):
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dense(hidden_states)
        return self.output(hidden_states)


def apply_chunking_to_forward(forward_fn, *input_tensors):
    # simplified from HF model_utils.py
    assert len(input_tensors) > 0
    tensor_shape = input_tensors[0].shape[1]
    assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError

    return forward_fn(*input_tensors)


def _validate_model_kwargs(fn, model_kwargs):
    # simplified from transformers.generation.utils._validate_model_kwargs
    unused_model_args = []
    model_args = set(inspect.signature(fn).parameters)
    for key, value in model_kwargs.items():
        if value is not None and key not in model_args:
            unused_model_args.append(key)
    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )


class FakeMamlInner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(784, 5)

    def forward(self, x, ignored=None, bn_training=False):
        return self.linear(x.view(x.shape[0], -1))


class PartialMaml(torch.nn.Module):
    # Highly simplified version of maml.meta.Meta.finetuning
    def __init__(self) -> None:
        super().__init__()
        self.net = FakeMamlInner()
        self.update_step_test = 10
        self.update_lr = 0.4

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        querysz = x_qry.size(0)

        corrects = [0 for _ in range(self.update_step_test + 1)]

        # in order to not ruin the state of running_mean/variance and bn_weight/bias,
        # we fine-tune on a copy of the model instead of self.net
        net = deepcopy(self.net)

        # 1. run the i-th task and compute loss for k=0
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = [
            p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
        ]

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, fast_weights, bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        del net

        accs = torch.tensor(corrects) / querysz

        return accs


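# Thin wrapper with the call signature XSoftmax.backward expects; it forwards
# to torch._softmax_backward_data using the dim saved on the autograd context.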
def softmax_backward_data(parent, grad_output, output, dim, self):
    from torch import _softmax_backward_data

    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)


class XSoftmax(torch.autograd.Function):
    # transformers.models.deberta.modeling_deberta.XSoftmax
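    # Masked softmax: masked positions are filled with the dtype minimum before
    # the softmax and zeroed afterwards; the output and mask are saved for backward.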
    @staticmethod
    def forward(self, input, mask, dim):
        self.dim = dim
        rmask = ~(mask.to(torch.bool))
        output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
        output = torch.softmax(output, self.dim)
        output.masked_fill_(rmask, 0)
        self.save_for_backward(output, rmask)
        return output

    @staticmethod
    def backward(self, grad_output):
        (output, rmask) = self.saved_tensors
        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
        return inputGrad, None, None


class ModelOutput(collections.OrderedDict):
    """based on file_utils.py in HuggingFace"""
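    # An OrderedDict whose entries are also reachable by attribute access,
    # positional indexing, and to_tuple().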

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self):
        return tuple(self[k] for k in self.keys())


def create_rand_mask_from_inputs(
    from_blocked_mask,
    to_blocked_mask,
    rand_attn,
    num_attention_heads,
    num_rand_blocks,
    batch_size,
    from_seq_length,
    from_block_size,
):
    """taken from HF modeling_big_bird.py"""
    num_windows = from_seq_length // from_block_size - 2
    rand_mask = torch.stack(
        [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
    )
    rand_mask = rand_mask.view(
        batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
    )
    rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
    return rand_mask


class SequentialAppendList(torch.nn.Sequential):
    """from timm/models/vovnet.py"""

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        for i, module in enumerate(self):
            if i == 0:
                concat_list.append(module(x))
            else:
                concat_list.append(module(concat_list[-1]))
        x = torch.cat(concat_list, dim=1)
        return x, concat_list


class BatchNormAct2d(torch.nn.BatchNorm2d):
    """Taken from timm"""
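    # BatchNorm2d followed by an activation. Note: the scripting branch below
    # calls _forward_jit, which this simplified copy does not define; it is only
    # reached under torch.jit.is_scripting().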

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        act_layer=torch.nn.ReLU,
        inplace=True,
    ):
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.act = act_layer(inplace=inplace)

    @torch.jit.ignore
    def _forward_python(self, x):
        return super().forward(x)

    def forward(self, x):
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.act(x)
        return x


def get_parameter_dtype(parameter):
    """from huggingface model_utils.py"""
    try:
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


class DummyConfig:
    attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
    lsh_attn_chunk_length = 64
    local_attn_chunk_length = 64


def _get_min_chunk_len(config):
    """from hf_Reformer"""
    attn_types = config.attn_layers
    attn_types_set = set(attn_types)
    if len(attn_types_set) == 1 and attn_types[0] == "lsh":
        return config.lsh_attn_chunk_length
    elif len(attn_types_set) == 1 and attn_types[0] == "local":
        return config.local_attn_chunk_length
    elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
        return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
    else:
        raise NotImplementedError(
            f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
            "attn layer types from ['lsh', 'local'] only."
        )


def _stable_argsort(vector, dim):
    """from hf_Reformer"""
    # this function scales the vector so that torch.argsort is stable.
    # torch.argsort is not stable on its own
    scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
    scale_offset = scale_offset.expand(vector.shape)
    scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
    return torch.argsort(scaled_vector, dim=dim)


def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
    """from hf_Reformer"""
    # no gradients are needed
    with torch.no_grad():
        # hash-based sort
        sorted_bucket_idx = _stable_argsort(buckets, dim=-1)

        # create simple indices to scatter to, to have undo sort
        indices = (
            torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
            .view(1, 1, -1)
            .expand(sorted_bucket_idx.shape)
        )

        # get undo sort
        undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
        undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)

    return sorted_bucket_idx, undo_sorted_bucket_idx


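# list subclasses whose instances are callable: calling one applies each stored
# callable to the input in order.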
class CustomList1(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def clear(self):
        pass  # this prevents RestrictedListSubclassVariable from kicking in


class CustomList2(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def length_times_10(self):
        return len(self) * 10

    def append_twice(self, x):
        self.extend([x, x])


def _merge_criteria_processor_list(default_list, custom_list):
    # simplified transformers/generation/utils.py
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                raise ValueError
    default_list.extend(custom_list)
    return default_list


class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = activation
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(
            self.linear2(self.dropout1(self.activation(self.linear1(x))))
        )


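# Pared-down post-norm transformer encoder layer: self-attention and a
# feed-forward block, each wrapped in a residual connection and LayerNorm.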
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(),
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout(x)

    # feed forward block
    def _ff_block(self, x):
        return self.ff_block(x)


class MockModule(torch.nn.Module):
    def inner_fn(self, left, right):
        return tuple(left) == tuple(right)

    def fn(self, tensor):
        if type(tensor) is int:
            return False

        torch.add(tensor, tensor)
        return self.inner_fn(tensor.shape, (1, 2, 3))


class IncByOne:
    def __init__(self, x):
        self.x = x + 1


class IncByTwo:
    def __init__(self, x):
        self.x = x + 2


class ReproTests(torch._dynamo.test_case.TestCase):
    def test_do_paste_mask(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 1,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 3,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 4,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            False,
        )
        # (dynamic shapes, static shapes)
        self.assertIn(cnt.frame_count, (5, 7))
        self.assertIn(cnt.op_count, (92, 106, 119))

    def test_convert_boxes_to_pooler_format(self):
        boxes1 = [
            Boxes(torch.arange(0, 8).reshape((2, 4))),
            Boxes(torch.arange(8, 16).reshape((2, 4))),
        ]
        boxes2 = [
            Boxes(torch.arange(16, 20).reshape((1, 4))),
            Boxes(torch.arange(20, 24).reshape((1, 4))),
        ]
        correct1 = convert_boxes_to_pooler_format(boxes1)
        correct2 = convert_boxes_to_pooler_format(boxes2)
        fn = convert_boxes_to_pooler_format
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), correct1))
        self.assertTrue(same(opt_fn(boxes2), correct2))

        # repeat_interleave is a dynamic shape operator we do not execute.
        # In the future, we could reduce the frame_count down to 1
        # by guarding on the exact values of `Tensor repeats` arg
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """10""")
        else:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """14""")

    def test_boxes_len(self):
        def fn(boxes):
            return len(boxes) + boxes.__len__() + boxes.tensor

        boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """1""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """2""")

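    # Shared helper: runs ReformerEncoder eagerly and under dynamo with the same
    # seed, checks the outputs match, and returns the CompileCounter.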
    def _reformer(self, nopython):
        input = torch.randn([1, 64, 256])
        model = ReformerEncoder()
        torch.manual_seed(1337)
        correct = copy.deepcopy(model)(input)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.manual_seed(1337)
        opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model)
        self.assertTrue(same(opt_model(input), correct))
        return cnt

    @requires_cuda
    def test_sub_alpha_scalar_repro(self):
        @torch.compile(backend="aot_eager")
        def f(x):
            return x.sub(1, alpha=2)

        f(torch.ones(2, device="cuda", dtype=torch.float64))

    # https://github.com/pytorch/pytorch/issues/113010
    def test_out_overload_non_contiguous(self):
        def f(x, y):
            return torch.abs(x, out=y.T)

        f_compiled = torch.compile(f, backend="aot_eager")

        x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        out_ref = f(x_ref, y_ref)
        out_test = f_compiled(x_test, y_test)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(y_ref, y_test)

    # https://github.com/pytorch/pytorch/issues/109053
    def test_view_dtype_overload(self):
        def f(x):
            return x.view(torch.int32)

        f_compiled = torch.compile(f, backend="aot_eager")

        x1 = torch.ones(4, requires_grad=True)
        out_ref = f(x1)
        out_test = f_compiled(x1)
        self.assertEqual(out_ref, out_test)

        x2 = torch.ones(4, requires_grad=False)
        out_ref = f(x2)
        out_test = f_compiled(x2)
        self.assertEqual(out_ref, out_test)

    # https://github.com/pytorch/pytorch/issues/90552
    def test_intermediate_leaf_requires_grad(self):
        def f(x):
            leaf = torch.ones(2, requires_grad=True)
            return leaf, leaf * 2

        f_compiled = torch.compile(f, backend="aot_eager")
        x = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        leaf, out = f(x)
        leaf_test, out_test = f_compiled(x)
        out.sum().backward()
        out_test.sum().backward()
        self.assertEqual(leaf.grad, leaf_test.grad)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_dont_run_during_tracing(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        # unpack hook shouldn't run during compilation, while we trace the forward
        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            out_test = f_compiled(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
            out_test.sum().backward()
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 1)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_can_be_disabled(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
            out_test = f_compiled(x, y)
            out_test.sum().backward()

    # https://github.com/pytorch/pytorch/issues/113263
    def test_disabling_unpack_hooks_within_compiled_region(self):
        def g(z):
            with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
                return z + 5

        def f(x, y):
            z = x * y
            return g(z)

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        out_test = f_compiled(x, y)
        out_test.sum().backward()

    # See https://github.com/pytorch/pytorch/issues/97745
    def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self):
        def f(a, b):
            c = torch.ones(2, 2)
            d = torch.ones(2, 2)
            e = torch.matmul(a, c)
            g_loss = torch.abs(e - d).mean()
            g_loss.backward()
            fake_d_pred = torch.matmul(b, e.detach())
            d_loss = fake_d_pred.mean()
            d_loss.backward()

        a_ref = torch.randn(2, 2, requires_grad=True)
        b_ref = torch.randn(2, 2, requires_grad=True)
        out_ref = f(a_ref, b_ref)

        a_test = a_ref.clone().detach().requires_grad_(True)
        b_test = b_ref.clone().detach().requires_grad_(True)
        out_test = torch.compile(f, backend="aot_eager")(a_test, b_test)

        self.assertEqual(out_ref, out_test)
        self.assertEqual(a_ref.grad, a_test.grad)
        self.assertEqual(b_ref.grad, b_test.grad)

    # https://github.com/pytorch/pytorch/issues/111603
    def test_tuple_enum_as_key_dict(self):
        class MyEnum(Enum):
            A = "a"

        class SomeModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)

            def forward(self, x) -> torch.Tensor:
                return self.linear(x[MyEnum.A])

        x = {MyEnum.A: torch.rand(8, 1)}
        model_pytorch = SomeModel()
        model = torch.compile(model_pytorch)
        # Executing twice works
        model(x)
        y = model(x)
        self.assertEqual(y, model_pytorch(x))

    def test_embedding_backward_broadcasting_decomp(self):
        def f(grad_output, indices):
            num_weights = 10
            padding_idx = 1
            scale_grad_by_freq = True
            return torch.ops.aten.embedding_dense_backward(
                grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
            )

        f_compiled = torch.compile(f, backend="aot_eager")

        grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
        indices = torch.ones(2, 4, dtype=torch.int64)

        out_ref = f(grad_output, indices)
        out_test = f_compiled(grad_output, indices)

        self.assertEqual(out_ref, out_test)

    def test_reformer_eval(self):
        with torch.no_grad():
            cnt = self._reformer(nopython=True)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 11)

    def test_reformer_train(self):
        with torch.enable_grad():
            cnt = self._reformer(nopython=False)
        expected_op_count = (
            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
        )

        self.assertExpectedInline(cnt.frame_count, """1""")
        self.assertExpectedInline(cnt.op_count, expected_op_count)

    @disable_translation_validation_if_dynamic_shapes
    def test_longformer_chunk(self):
        input1 = torch.randn([1, 4096, 1])
        input2 = torch.randn([12, 4096, 64])
        correct1 = longformer_chunk(input1)
        correct2 = longformer_chunk(input2)
        fn = longformer_chunk
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))

        if torch._dynamo.config.assume_static_by_default:
            if torch._dynamo.config.automatic_dynamic_shapes:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """8""")
            else:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """4""")
        else:
            self.assertExpectedInline(cnt.frame_count, """2""")
            self.assertExpectedInline(cnt.op_count, """19""")

    def test_hf_t5_forward(self):
        input = torch.randn([1, 2048, 512])
        model = PartialT5()
        correct = model(input)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(input), correct))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")

    def test_module_in_skipfiles(self):
        model = nn.Linear(10, 10)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_function_in_skipfiles(self):
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_slicing_dynamic_shape(self):
        def fn(y):
            x = torch.ones(8)
            idx = y[0]
            out = x[idx:]
            return (out + 3) * 5

        counter = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counter)(fn)
        out = opt_fn(torch.ones(10, dtype=torch.long))
        # idx should be 1 -> slicing off [1:] of 8 elem tensor
        self.assertEqual(list(out.shape), [7])

        self.assertEqual(counter.op_count, 2)
        self.assertEqual(counter.frame_count, 1)

        self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])

    def test_slicing_dynamic_shape_setitem(self):
        def fn(input_lengths: torch.Tensor, new_ones_1):
            getitem_13 = input_lengths[3]
            new_ones_1[(3, slice(getitem_13, None, None))] = 0
            setitem_13 = new_ones_1
            return (setitem_13,)

        x = torch.randn(10).to(dtype=torch.int64)
        y = torch.randn(10, 204)
        ref = fn(x, y)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x, y)
        self.assertTrue(same(ref, res))

    @torch._dynamo.config.patch(error_on_recompile=True)
    @torch.fx.experimental._config.patch(use_duck_shape=False)
    def test_dynamic_shape_disable_duck_size(self):
        class TestModel(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x + val

        main_model = TestModel().to(memory_format=torch.channels_last)
        opt_model = torch.compile(main_model, backend="eager", dynamic=True)

        x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last)
        x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last)

        o1_ref = main_model(x1, 4)
        o1 = opt_model(x1, 4)

        o2_ref = main_model(x2, 20)
        o2 = opt_model(x2, 20)

    def test_chunk_reformer_ff(self):
        input = torch.randn([1, 4096, 256])
        model = ChunkReformerFeedForward()
        correct = model(input)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(input), correct))

        self.assertEqual(cnt.frame_count, 1)
        self.assertLessEqual(cnt.op_count, 10)

    # see: https://github.com/pytorch/pytorch/issues/80067
    # NB: When you remove the expectedFailure, don't forget to
    # uncomment/adjust the assertEqual below
    @unittest.expectedFailure
    @torch._dynamo.config.patch(
        fake_tensor_propagation=True, capture_scalar_outputs=True
    )
    def test_maml_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
        b = torch.zeros(5, dtype=torch.int64)
        c = torch.randn(75, 1, 28, 28)
        d = torch.zeros(75, dtype=torch.int64)
        model = PartialMaml()
        correct = model(a, b, c, d)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt)(model)
        for _ in range(10):
            self.assertTrue(same(opt_model(a, b, c, d), correct))

        # if torch._dynamo.config.assume_static_by_default:
        #     self.assertExpectedInline(cnt.frame_count, """2""")
        # else:
        #     self.assertExpectedInline(cnt.frame_count, """3""")
        # TODO(jansel): figure out why op count depends on imports
        self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))

    # see: https://github.com/pytorch/pytorch/issues/80067
    @torch._dynamo.config.patch(capture_scalar_outputs=False)
    def test_maml_no_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
        b = torch.zeros(5, dtype=torch.int64)
        c = torch.randn(75, 1, 28, 28)
        d = torch.zeros(75, dtype=torch.int64)
        model = PartialMaml()
        correct = model(a, b, c, d)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt)(model)
        for _ in range(10):
            self.assertTrue(same(opt_model(a, b, c, d), correct))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """4""")
        else:
            self.assertExpectedInline(cnt.frame_count, """5""")

    def test_hf_model_output(self):
        ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))

        def fn1(x):
            return x["a"] + 1

        def fn2(x):
            return x.a + 1

        def fn3(x):
            return x.to_tuple()[0] + 1

        def fn4(x):
            return x[0] + 1

        cnt = torch._dynamo.testing.CompileCounter()
        for fn in (fn1, fn2, fn3, fn4):
            cnt.clear()
            opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
            self.assertTrue(same(opt_fn(ex), ex.a + 1))
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(cnt.op_count, 1)

    @disable_translation_validation_if_dynamic_shapes
    def test_create_rand_mask_from_inputs(self):
        args = [
            torch.randn([1, 64, 64]),
            torch.randn([1, 64, 64]),
            torch.zeros([1, 12, 62, 3], dtype=torch.int64),
            12,
            3,
            1,
            4096,
            64,
        ]
        correct = create_rand_mask_from_inputs(*args)
        fn = create_rand_mask_from_inputs

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(*args), correct))
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """8""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")

    def test_rng_state(self):
        def fn():
            state = torch.get_rng_state()
            before = torch.rand(1000)
            torch.set_rng_state(state)
            after = torch.rand(1000)
            return before, after

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)

        before, after = opt_fn()
        self.assertTrue(same(before, after))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)  # rand, rand
        try:
            graph, _ = torch._dynamo.export(fn)()
            # See https://github.com/pytorch/pytorch/pull/87490
            self.fail("unexpected export success")
        except torch._dynamo.exc.Unsupported:
            pass

    def test_threading_local(self):
        import threading

        foo = threading.local()
        foo.x = torch.rand(1)

        def f(x):
            return torch.cat([x, foo.x])

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)

        inp = torch.ones(1)
        out = f(inp)
        opt_out = opt_f(inp)
        self.assertEqual(opt_out, out)
        self.assertEqual(cnt.frame_count, 1)

    def test_seq_append_list(self):
        x = torch.randn(4, 10)
        model = SequentialAppendList(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )
        # this one is tricky because it mutates the list provided as an input
        l1 = [x]
        l2 = [x]
        correct, _ = model(x, l1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        result, l3 = opt_model(x, l2)
        self.assertTrue(same(result, correct))
        self.assertTrue(same(l1, l2))
        self.assertIs(l2, l3)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 5)

    def test_batch_norm_act(self):
        a = torch.randn(5, 1, 28, 28)
        model = BatchNormAct2d(1).eval()
        correct = model(a)
        cnt = torch._dynamo.testing.CompileCounter()
        if not torch._dynamo.config.specialize_int:
            # _local_scalar_dense causes a graph break with a 0-dim tensor
            opt_model = torch._dynamo.optimize(cnt)(model)
            self.assertTrue(same(opt_model(a), correct))
            return

        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(a), correct))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_get_parameter_dtype(self):
        model = SequentialAppendList(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

        def fn(model, x):
            return x + torch.randn(10, dtype=get_parameter_dtype(model))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_nn_parameter(self):
        def test_fn():
            a = torch.nn.Parameter(torch.randn(5, 5))
            # Checks that TensorVariable stores the type information correctly
            self.assertTrue(isinstance(a, torch.nn.Parameter))
            return a

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        out = opt_test_fn()
        self.assertTrue(isinstance(out, torch.nn.Parameter))

    def test_Size(self):
        def test_fn():
            a = torch.randn(4)
            x = torch.Size([1, 2, 3])
            # Checks that SizeVariable return torch.Size object
            assert isinstance(x, torch.Size)
            # Causes graph breaks and checks reconstruction of SizeVariable
            # object
            self.assertIsInstance(x, torch.Size)
            return a

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        opt_test_fn()

    # See https://github.com/pytorch/pytorch/issues/100067
    def test_copy_weird_strides(self):
        # This test requires inductor's copy() decomp to preserve strides properly.
        def test_fn(a):
            b = torch.zeros(48, 4, 256, 513)
            b[:, 0, 1:256, 1:256] = a
            c = b.view(4, 12, 1024, 513)
            d = c.transpose(2, 1)
            d.add_(1)
            return d

        sh, st, dt, dev, rg = (
            (48, 255, 255),
            (787968, 513, 1),
            torch.float16,
            "cpu",
            True,
        )
        a = rand_strided(sh, st, dt, dev).requires_grad_(rg)
        compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition")
        out1 = test_fn(a)
        out2 = compiled_f(a)
        self.assertEqual(out1, out2)

    def test_indexing_with_list(self):
        def test_fn():
            def run_test(tensor, *idx):
                npt = tensor.numpy()
                assert npt[idx].shape == tensor[idx].shape

            x = torch.arange(0, 10)
            cases = [
                [None, None],
                [1, None],
            ]

            for case in cases:
                run_test(x, *case)

            return torch.randn(4)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        opt_test_fn()

    def test_reformer_min_chunk_len(self):
        def fn(cfg):
            t = torch.empty(10)
            t.fill_(_get_min_chunk_len(cfg))
            return t[0]

        cfg = DummyConfig()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertEqual(opt_fn(cfg), 64)
        # With unspec int, maximum computation is preserved
        self.assertExpectedInline(cnt.frame_count, """1""")
        self.assertExpectedInline(cnt.op_count, """3""")

    def test_reformer_sorting(self):
        x = torch.zeros([1, 12, 4096], dtype=torch.int64)
        correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x)
        fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(x), correct))
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """14""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """16""")

    def test_recursive_map(self):
        # https://github.com/pytorch/torchdynamo/issues/132
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = v

        def toy_example(a, b, v):
            x = a / (torch.abs(a) + 1)
            if v is not None:
                _recursive_map(v)
            return x * b

        cnt = torch._dynamo.testing.CompileCounter()
        opt_toy_example = torch._dynamo.optimize(cnt)(toy_example)
        opt_toy_example(
            torch.randn(10),
            torch.randn(10),
            {"layer0": {"memory_keys": torch.randn(10)}},
        )
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 4)

    def test_issue114171(self):
1655        device = torch.device("cpu")
1656
1657        def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU):
1658            layers = [
1659                torch.nn.Linear(in_dim, hidden_dim, device=device),
1660                activation(),
1661                torch.nn.Linear(hidden_dim, out_dim, device=device),
1662            ]
1663            return torch.nn.Sequential(*layers)
1664
1665        class testmodel(torch.nn.Module):
1666            def __init__(self) -> None:
1667                super().__init__()
1668                self.interaction_networks = torch.nn.ModuleList(
1669                    [fcnn(262, 1174, 400) for _ in range(4)]
1670                )
1671
1672            def interact(self, x, cycle):
1673                return self.interaction_networks[cycle](x)
1674
1675        model = testmodel()
1676        forward_aot = torch.compile(
1677            model.interact, fullgraph=True, dynamic=True, backend="eager"
1678        )
1679
1680        x = torch.rand([111, 262], device=device)
1681        y2 = forward_aot(x, 2)  # previously failed
1682
1683    def test_issue175(self):
1684        n_heads = 2
1685        d_model = 64
1686        model = TransformerEncoderLayer(d_model, n_heads)
1687        inp = torch.randn(1, d_model)
1688        cnt = torch._dynamo.testing.CompileCounter()
1689        opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
1690        opt_model(inp)
1691        opt_model(inp)
1692        self.assertEqual(cnt.frame_count, 1)
1693
1694        self.assertEqual(
1695            15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count
1696        )
1697
1698    def test_exec_import(self):
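        # Names bound via exec("import math") should not leak into other frames:
        # fn2 returns True only if `math` is undefined there (NameError)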
1699        def fn1():
1700            exec("import math")
1701
1702        def fn2():
1703            try:
1704                math.sqrt(4)
1705                return False
1706            except NameError:
1707                return True
1708
1709        def fn3():
1710            fn1()
1711            return fn2()
1712
1713        self.assertTrue(fn3())
1714        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
1715        self.assertTrue(opt_fn3())
1716
1717    def test_exec_wildcard_import(self):
1718        # Test that globals are not carried over from frame to frame
1719        def fn1():
1720            exec("from torch import *")
1721
1722        def fn2():
1723            x = torch.zeros(4)
1724            for i in range(5):
1725                x = x + i
1726            return x
1727
1728        def fn3():
1729            fn1()
1730            return fn2()
1731
1732        ref = fn3()
1733        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
1734        res = opt_fn3()
1735        self.assertTrue(same(ref, res))
1736
1737    def test_with_on_graph_break_inst(self):
1738        def reversible(x):
1739            print("Hello world")  # Cause graph break so inline fails
1740            return torch.sin(torch.cos(x))
1741
1742        def fn(x):
1743            with torch.enable_grad():
1744                a = torch.sin(x)
1745                b = reversible(a)
1746                c = torch.sigmoid(b)
1747                c.sum().backward()
1748                return x.grad
1749
1750        x = torch.randn(3, requires_grad=True)
1751        x.grad = None
1752        with torch.no_grad():
1753            ref = fn(x)
1754
1755        x.grad = None
1756        opt_fn = torch._dynamo.optimize("eager")(fn)
1757        with torch.no_grad():
1758            res = opt_fn(x)
1759        self.assertTrue(same(ref, res))
1760
1761    def test_with_on_graph_break_nested(self):
1762        def reversible(x):
1763            torch._dynamo.graph_break()  # Causes a graph break so inlining fails
1764            return torch.sin(torch.cos(x))
1765
1766        def fn(x):
1767            # nested context managers previously failed here
1768            with torch.no_grad():
1769                with torch.enable_grad():
1770                    a = torch.sin(x)
1771                    b = reversible(a)
1772                    c = torch.sigmoid(b)
1773                    c.sum().backward()
1774                    return x.grad
1775
1776        x = torch.randn(3, requires_grad=True)
1777        x.grad = None
1778        with torch.no_grad():
1779            ref = fn(x)
1780
1781        x.grad = None
1782        opt_fn = torch._dynamo.optimize("eager")(fn)
1783        with torch.no_grad():
1784            res = opt_fn(x)
1785        self.assertTrue(same(ref, res))
1786
1787    # https://github.com/pytorch/torchdynamo/issues/1446
1788    def test_grad_mode_carrying_correct_state_after_graph_break(self):
1789        def fn(x):
1790            with torch.no_grad():
1791                y = x * 3
1792                print("Break")
1793                z = x + 2
1794            return y, z
1795
1796        x = torch.randn(3, requires_grad=True)
1797        opt_fn = torch._dynamo.optimize("eager")(fn)
1798        y, z = opt_fn(x)
1799        self.assertFalse(y.requires_grad)
1800        self.assertFalse(z.requires_grad)
1801
1802    def test_abc_setattr(self):
1803        # tests that we correctly bail out of __setattr__ calls
1804
1805        # TODO: does not ensure ABC classes are correctly inferred as ClassVariables
1806        # (doesn't test the fix for 'super()')
1807
1808        class BaseModule(torch.nn.Module, ABC):
1809            def blah(self, x):
1810                return x + 1
1811
1812        class Derived(BaseModule):
1813            def __setattr__(self, name, value) -> None:
1814                super().__setattr__(name, value)
1815
1816            def forward(self, x):
1817                # expect a graph break on __setattr__
1818                self.foo = 0
1819                return self.blah(x)
1820
1821            def blah(self, x):
1822                return super().blah(x)
1823
1824        x = torch.randn(3, requires_grad=True)
1825        mod = Derived()
1826        opt_mod = torch._dynamo.optimize("eager")(mod)
1827        opt_mod(x)
1828
1829        # Not sure what this test is testing. It used to graph break on
1830        # __dict__, so the counters were >= 2. With __dict__ support, there is
1831        # no graph break.
1832        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
1833        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1)
1834
1835    @torch._dynamo.config.patch("suppress_errors", True)
1836    def test_guard_fail_tensor_bool(self):
1837        @torch._dynamo.disable(recursive=False)
1838        def fn():
1839            condition_shape = (5, 5)
1840            dtypes = (torch.bool,)
1841            shapes = (
1842                (),
1843                (5,),
1844                (1, 5),
1845            )
1846
1847            tensors = [
1848                torch.empty(shape, dtype=dtype).fill_(17)
1849                for shape, dtype in itertools.product(shapes, dtypes)
1850            ]
1851
1852            x_vals = (5.0, *tensors)
1853            y_vals = (6.0, *tensors)
1854
1855            @torch._dynamo.disable
1856            def get_expected(condition, x, y):
1857                x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
1858                y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
1859                return torch.from_numpy(
1860                    np.where(condition.cpu().numpy(), x_np, y_np)
1861                ).to(common_dtype)
1862
1863            for x, y in zip(x_vals, y_vals):
1864                condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
1865                common_dtype = torch.result_type(x, y)
1866
1867                def check_equal(condition, x, y):
1868                    # NumPy aggressively promotes to double, hence cast the output to the correct dtype
1869                    expected = get_expected(condition, x, y)
1870                    result = torch.where(condition, x, y)
1871                    assert torch.allclose(expected, result)
1872
1873                check_equal(condition, x, y)
1874                check_equal(condition, y, x)
1875
1876        fn()
1877        opt_fn = torch._dynamo.optimize("eager")(fn)
1878        opt_fn()
1879
1880    def test_guard_fail_nested_tuple(self):
1881        def fn(args):
1882            return torch.ones(()), args[0] * 2
1883
1884        # This adds a tensor check on args[1][0] and args[1][1]
1885        args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
1886        args2 = (torch.ones(1), torch.ones(1))
1887        opt_fn = torch._dynamo.optimize("eager")(fn)
1888        ref = opt_fn(args1)
1889        res = opt_fn(args2)
1890
1891        self.assertTrue(same(ref, res))
1892
1893    def test_nullcontext1(self):
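        # nullcontext is passed in as an already-constructed instance (used as `with ctx:`)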
1894        @torch.compile(fullgraph=True, backend="eager")
1895        def fn(x, ctx):
1896            x = x.sin()
1897            with ctx:
1898                x = x.cos()
1899            x = x.sin()
1900            return x
1901
1902        y = torch.randn(10)
1903        self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))
1904
1905    def test_nullcontext2(self):
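        # nullcontext is passed in as the class and constructed inside the compiled region (`with ctx():`)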
1906        @torch.compile(fullgraph=True, backend="eager")
1907        def fn(x, ctx):
1908            x = x.sin()
1909            with ctx():
1910                x = x.cos()
1911            x = x.sin()
1912            return x
1913
1914        y = torch.randn(10)
1915        self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))
1916
1917    def test_no_grad_inline(self):
1918        @torch.no_grad()
1919        def a(x):
1920            return x.sin()
1921
1922        @torch.compile(backend="eager", fullgraph=True)
1923        def b(x):
1924            return a(x).cos()
1925
1926        y = torch.randn(10)
1927        self.assertTrue(same(b(y), y.sin().cos()))
1928
1929    @skipIfWindows(
1930        msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function <class 'torch.LongTensor'>(*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):"  # noqa: B950
1931    )
1932    def test_longtensor_list(self):
1933        for partition in [0, 5, 10]:
1934
1935            @torch._dynamo.disable
1936            def rand_gen():
1937                rand_vals = [random.randint(5, 10) for _ in range(10)]
1938                # List of tensors mixed with np.arrays
1939                return list(np.array(rand_vals[:partition])) + [
1940                    torch.tensor(val) for val in rand_vals[partition:]
1941                ]
1942
1943            def fn(x):
1944                random_list = rand_gen()
1945                z = torch.LongTensor(random_list)
1946                return x * z
1947
1948            x = torch.ones(10) * 2
1949
1950            random.seed(0)
1951            ref0 = fn(x)
1952            ref1 = fn(x)
1953
1954            random.seed(0)
1955            opt_fn = torch._dynamo.optimize("eager")(fn)
1956            res0 = opt_fn(x)
1957            res1 = opt_fn(x)
1958
1959            self.assertTrue(same(ref0, res0))
1960            self.assertTrue(same(ref1, res1))
1961
1962    def test_primtorch(self):
1963        @torch._dynamo.optimize("eager")
1964        def fn(x):
1965            torch._refs.abs(x)
1966
1967        fn(torch.randn(3))
1968
1969    @unittest.expectedFailure
1970    # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)]
1971    def test_primtorch_no_graph_break(self):
1972        @torch._dynamo.optimize("eager", nopython=True)
1973        def fn(x):
1974            torch._refs.abs(x)
1975
1976        fn(torch.randn(3))
1977
1978    def test_torch_tensor_ops_no_graph_break(self):
1979        @torch._dynamo.optimize("eager", nopython=True)
1980        def fn(x):
1981            torch.Tensor.abs_(x)
1982
1983        fn(torch.randn(3))
1984
1985    @unittest.skipIf(
1986        not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
1987        "old pt doesn't work",
1988    )
1989    def test_torch_ops_aten(self):
1990        # Picked an op that doesn't show up in the default list
1991        @torch._dynamo.optimize("eager", nopython=True)
1992        def fn(x):
1993            return torch.ops.aten.absolute(x)
1994
1995        fn(torch.randn(3))
1996
1997    def test_hf_gelu_inline(self):
1998        class GELUActivation(nn.Module):
1999            def __init__(self) -> None:
2000                super().__init__()
2001                self.act = nn.functional.gelu
2002
2003            def forward(self, input):
2004                return self.act(input)
2005
2006        @torch._dynamo.optimize("eager", nopython=True)
2007        def fn(x):
2008            return GELUActivation()(x)
2009
2010        y = torch.randn(10)
2011        self.assertTrue(same(fn(y), nn.functional.gelu(y)))
2012
2013        @torch._dynamo.optimize("eager", nopython=True)
2014        def fn_returns(x):
2015            return GELUActivation(), x + 1
2016
2017        act, _ = fn_returns(y)
2018        self.assertIsInstance(act, GELUActivation)
2019        self.assertIs(act.act, nn.functional.gelu)
2020        self.assertTrue(hasattr(act, "_buffers"))  # check that __init__ got called
2021
2022    def test_dropout_inline(self):
2023        @torch._dynamo.optimize("eager")
2024        def fn(x):
2025            return torch.nn.Dropout(0.1)(x)
2026
2027        y = torch.randn(10)
2028        torch.manual_seed(1337)
2029        ref = nn.functional.dropout(y, 0.1)
2030        torch.manual_seed(1337)
2031        res = fn(y)
2032        self.assertTrue(same(ref, res))
2033
2034    def test_setitem_boolean_mask_diff(self):
2035        def fn(x, b, y):
2036            x = x.clone()
2037            x[b] = y
2038            return x
2039
2040        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
2041        x = torch.randn(4, requires_grad=True)
2042        b = torch.tensor([True, False, True, False])
2043        y = torch.randn(2, requires_grad=True)
2044        opt_fn(x, b, y)
2045
2046    def test_setitem_tuple_boolean_mask_diff(self):
2047        def fn(x, b, y):
2048            x = x.clone()
2049            x[:, b] = y
2050            return x
2051
2052        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
2053        x = torch.randn(8, 4, requires_grad=True)
2054        b = torch.tensor([True, False, True, False])
2055        y = torch.randn(2, requires_grad=True)
2056        opt_fn(x, b, y)
2057
2058    def test_torch_tensor_ops(self):
2059        def fn(x):
2060            return torch.Tensor.abs_(x)
2061
2062        x = torch.randn(3)
2063        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
2064        y = fn(x)
2065        y_ = opt_fn(x)
2066        self.assertTrue(same(y, y_))
2067
2068    def test_guard_ordering_shape_fail(self):
2069        # If a function that takes a tensor has an inner function that
2070        # is compiled and generates a guard on the tensor's shape,
2071        # the guards are evaluated in the wrong order. So if an int is passed
2072        # instead of a tensor on a subsequent call, guard evaluation will crash
2073        # with a "no attribute: shape" error
2074        m = MockModule()
2075        opt_m = torch._dynamo.optimize("eager")(m)
2076        opt_m.fn(torch.ones((5, 5)))
2077        opt_m.fn(-3)
2078
2079    def test_tensor_isinstance_tuple(self):
2080        @torch._dynamo.optimize("eager")
2081        def fn():
2082            t = torch.ones(5, 5)
2083            if not isinstance(t, (int, torch.Tensor)):
2084                msg = str.format(
2085                    "{0} is not an instance of {1}",
2086                    type(t),
2087                    (int, torch.Tensor),
2088                )
2089                raise ValueError(msg)
2090            return True
2091
2092        fn()
2093
2094    def test_isinstance_dtype(self):
2095        @torch._dynamo.optimize("eager", nopython=True)
2096        def fn(x):
2097            isinstance(torch.bfloat16, torch.dtype)
2098            return x
2099
2100        fn(torch.randn(3))
2101
2102    def test_isinstance_storage(self):
2103        @torch._dynamo.optimize("eager")
2104        def fn(x):
2105            f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
2106            bools = torch.BoolStorage.from_buffer(f, "big")
2107            assert isinstance(bools, torch.BoolStorage)
2108            return x
2109
2110        fn(torch.randn(3))
2111
2112    def test_issue111522(self):
2113        @torch.compile(backend="eager", fullgraph=True)
2114        def f(x, y):
2115            return x + y.a
2116
2117        class A:
2118            a = 2
2119
2120        self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0))
2121
2122        del A.a
2123
2124        # graph break on missing attr
2125        with self.assertRaises(torch._dynamo.exc.Unsupported):
2126            f(torch.zeros(2), A())
2127
2128    def test_dict_list_values(self):
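        # zip(itertools.count(), tensors["args"]) yields (index, tensor) pairs,
        # so x[1] in inner_fn is the tensor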
2129        def inner_fn(args):
2130            return [x[1].shape for x in args]
2131
2132        @torch._dynamo.optimize("eager")
2133        def fn(tensors):
2134            return inner_fn(zip(itertools.count(), tensors["args"]))
2135
2136        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
2137        fn({"args": [torch.ones(5, 5)]})
2138
2139    def test_dict_iter(self):
2140        class MyMod(torch.nn.Module):
2141            def forward(self, x):
2142                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
2143                tot = 0
2144                for key in z:
2145                    tot += z[key]
2146
2147                return tot
2148
2149        x = torch.tensor([0])
2150        model = MyMod()
2151        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
2152        y = opt_model(x)
2153
2154        self.assertEqual(y, 10)
2155
2156    def test_sort_out(self):
2157        dtype = torch.float32
2158        device = "cpu"
2159
2160        def fn():
2161            tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
2162            values1 = torch.tensor(0, dtype=dtype, device=device)
2163            indices1 = torch.tensor(0, dtype=torch.long, device=device)
2164            torch.sort(tensor, out=(values1, indices1))
2165            self.assertEqual(values1.stride(), (1,))
2166            self.assertEqual(indices1.stride(), (1,))
2167
2168        fn()
2169        opt_fn = torch._dynamo.optimize("eager")(fn)
2170        opt_fn()
2171
2172    def test_sort_out2(self):
2173        class MyModule(torch.nn.Module):
2174            def __init__(self) -> None:
2175                super().__init__()
2176                self.sorted = torch.nn.Buffer(torch.ones(4, 4))
2177                self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))
2178
2179            def forward(self, x):
2180                torch.sort(x, out=(self.sorted, self.indices))
2181                return (x + 1, self.sorted, self.indices)
2182
2183        x = torch.randn(4, 4)
2184        m = MyModule()
2185        ref = m(x)
2186        opt_m = torch._dynamo.optimize("eager")(m)
2187        res = opt_m(x)
2188        self.assertTrue(same(ref, res))
2189
2190    def test_sigmoid_out(self):
2191        dtype = torch.float32
2192        device = "cpu"
2193
2194        def fn():
2195            inp = torch.randn((3, 5), dtype=dtype, device=device)
2196            out1 = torch.tensor(0, dtype=dtype, device=device)
2197            torch.sigmoid(inp, out=out1)
2198            self.assertEqual(out1.numel(), 15)
2199
2200        fn()
2201        opt_fn = torch._dynamo.optimize("eager")(fn)
2202        opt_fn()
2203
2204    def test_sigmoid_out2(self):
2205        class MyModule(torch.nn.Module):
2206            def __init__(self) -> None:
2207                super().__init__()
2208                self.base = torch.nn.Buffer(torch.ones(4, 4))
2209
2210            def forward(self, x):
2211                torch.sigmoid(x, out=self.base)
2212                return x + self.base
2213
2214        x = torch.randn(4, 4)
2215        m = MyModule()
2216        ref = m(x)
2217        opt_m = torch._dynamo.optimize("eager")(m)
2218        res = opt_m(x)
2219        self.assertTrue(same(ref, res))
2220
2221    def test_slice_into_list_mutable(self):
2222        class Mod(torch.nn.Module):
2223            def forward(self, listy):
2224                x = listy[3:5]
2225                for i in range(10):
2226                    z = torch.abs(torch.randn(10)) + 1
2227                    x[0] = z
2228                return x
2229
2230        m = Mod()
2231        listy = [torch.randn(10)] * 10
2232
2233        cnt = torch._dynamo.testing.CompileCounter()
2234        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
2235        opt_m.forward(listy)
2236
2237        self.assertEqual(cnt.frame_count, 1)
2238
2239    @torch._dynamo.config.patch(capture_scalar_outputs=True)
2240    def test_issue111918(self):
2241        cnt = CompileCounter()
2242
2243        @torch.compile(backend=cnt, dynamic=True)
2244        def fn(x):
2245            x = x + 1
2246            y = x.item()
2247            if y > 2:
2248                return x * 2
2249            else:
2250                return x * 3
2251
2252        x = torch.tensor([3.0])
2253        fn(x)
2254        self.assertEqual(cnt.frame_count, 2)
2255        self.assertEqual(cnt.op_count, 4)
2256
2257        torch._dynamo.reset()
2258        fn = torch.compile(fn, fullgraph=True, backend="eager")
2259        with self.assertRaises(torch._dynamo.exc.UserError):
2260            fn(x)
2261
2262    def test_vdd_duplicate_error(self):
2263        def fn(a, dt):
2264            keys = list(dt._jt_dict.keys())
2265            p = torch.cos(dt._jt_dict[keys[0]]._value)
2266            q = torch.sin(a)
2267            r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
2268            return p + q + r
2269
2270        class Value:
2271            def __init__(self) -> None:
2272                self._value = torch.randn(4)
2273
2274        class Sample:
2275            def __init__(self) -> None:
2276                self._jt_dict = {}
2277                self._jt_dict["POSITION_ID"] = Value()
2278
2279        a = torch.randn(4)
2280        sample = Sample()
2281
2282        ref = fn(a, sample)
2283
2284        optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
2285        res = optimized_fn(a, sample)
2286
2287        self.assertTrue(same(ref, res))
2288
2289    def test_specialized_stride(self):
2290        def f():
2291            e = torch.empty(4)
2292            x = e[::2]
2293            return x.stride()
2294
2295        self.assertEqual(f(), torch._dynamo.optimize("eager")(f)())
2296
2297    def test_out_none(self):
2298        # https://github.com/pytorch/pytorch/issues/92814
2299        def fn(input):
2300            return torch.nn.functional.normalize(input, dim=0, out=None)
2301
2302        x = torch.rand([1])
2303        self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x))
2304
2305    def test_multi_import(self):
2306        if not has_detectron2():
2307            raise unittest.SkipTest("requires detectron2")
2308
2309        @torch._dynamo.optimize("eager", nopython=True)
2310        def to_bitmasks(boxes):
2311            from detectron2.layers.mask_ops import (
2312                _paste_masks_tensor_shape,
2313                paste_masks_in_image,
2314            )
2315
2316            if (
2317                paste_masks_in_image is not None
2318                and _paste_masks_tensor_shape is not None
2319            ):
2320                return boxes + 1
2321
2322        self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())
2323
2324    def test_multi_dot_import(self):
2325        def fn1(x):
2326            return torch.sin(x)
2327
2328        def fn(x):
2329            import torch.fx
2330
2331            _ = torch.fx.symbolic_trace(fn1)
2332            return x * 2
2333
2334        x = torch.randn(10)
2335        fn(x)
2336        cnt = torch._dynamo.testing.CompileCounter()
2337        opt_fn = torch._dynamo.optimize(cnt)(fn)
2338        opt_fn(x)
2339        self.assertEqual(cnt.frame_count, 1)
2340
2341    def test_relative_import(self):
2342        try:
2343            from . import utils as _  # noqa: F401
2344
2345            def fn(x):
2346                from .utils import tensor_for_import_testing
2347
2348                return x * 2 * tensor_for_import_testing
2349
2350        except ImportError:
2351
2352            def fn(x):
2353                from utils import tensor_for_import_testing
2354
2355                return x * 2 * tensor_for_import_testing
2356
2357        x = torch.randn(10)
2358        fn(x)
2359        cnt = torch._dynamo.testing.CompileCounter()
2360        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
2361        opt_fn(x)
2362        self.assertEqual(cnt.frame_count, 1)
2363
2364    def test_relative_import_no_modulename(self):
2365        try:
2366            from . import utils as _  # noqa: F401
2367
2368            def fn(x):
2369                from . import utils
2370
2371                return x * 2 * utils.tensor_for_import_testing
2372
2373        except ImportError:
2374
2375            def fn(x):
2376                import utils
2377
2378                return x * 2 * utils.tensor_for_import_testing
2379
2380        x = torch.randn(10)
2381        fn(x)
2382        cnt = torch._dynamo.testing.CompileCounter()
2383        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
2384        opt_fn(x)
2385        self.assertEqual(cnt.frame_count, 1)
2386
2387    def test_bigbird_unsqueeze_inplace(self):
2388        def fn(reshape_2):
2389            view_2 = reshape_2.clone()
2390            view_2.unsqueeze_(2)
2391            cat_11 = torch.cat([view_2], dim=2)
2392            view_13 = cat_11.view((2, 12, 64, -1))
2393            return (view_13,)
2394
2395        x = torch.randn(2, 12, 64, 64, requires_grad=True)
2396        ref = fn(x)
2397        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
2398        res = opt_fn(x)
2399        self.assertTrue(same(ref, res))
2400
2401    def test_issue1466_size_aot_autograd(self):
2402        def fn(x):
2403            # do a tensor op and a size compute
2404            y = x * 2
2405            x_size = x.size()
2406            # trigger a graph break
2407            print("arf")
2408            # use the tensor op and size compute
2409            z = y.view(x_size) + 1
2410            return z
2411
2412        x = torch.randn(2, 3, requires_grad=True)
2413        ref = fn(x)
2414        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
2415        res = opt_fn(x)
2416        self.assertTrue(same(ref, res))
2417
2418    def test_ellipsis(self):
2419        class Repro(torch.nn.Module):
2420            def __init__(self) -> None:
2421                super().__init__()
2422                self.lnorm = torch.nn.LayerNorm(
2423                    (256,), eps=1e-06, elementwise_affine=True
2424                )
2425                self.linear = torch.nn.Linear(
2426                    in_features=256, out_features=256, bias=True
2427                )
2428
2429            def forward(self, cat_10):
2430                lnorm = self.lnorm(cat_10)
2431                getitem_64 = lnorm[
2432                    (slice(None, None, None), slice(0, 1, None), Ellipsis)
2433                ]
2434                linear = self.linear(getitem_64)
2435                return (linear,)
2436
2437        args = [torch.randn(2, 197, 256)]
2438
2439        mod = Repro()
2440        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
2441
2442        self.assertTrue(same(mod(*args), opt_mod(*args)))
2443
2444    def test_reinplacing(self):
2445        class MockModule(torch.nn.Module):
2446            def __init__(self) -> None:
2447                super().__init__()
2448                self.self_layoutlm_embeddings_x_position_embeddings = (
2449                    torch.nn.Embedding(1024, 768)
2450                )
2451                self.self_layoutlm_embeddings_y_position_embeddings = (
2452                    torch.nn.Embedding(1024, 768)
2453                )
2454
2455            def forward(self, getitem_1, getitem_2, add):
2456                self_layoutlm_embeddings_x_position_embeddings = (
2457                    self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
2458                )
2459                self_layoutlm_embeddings_y_position_embeddings = (
2460                    self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
2461                )
2462                add_1 = add + self_layoutlm_embeddings_x_position_embeddings
2463                add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
2464                return (add_2,)
2465
2466        mod = MockModule()
2467        opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod)
2468
2469        args = [
2470            ((2, 512), (2048, 4), torch.int64, "cpu", False),
2471            ((2, 512), (2048, 4), torch.int64, "cpu", False),
2472            ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
2473        ]
2474        args = [
2475            rand_strided(sh, st, dt, dev).requires_grad_(rg)
2476            for (sh, st, dt, dev, rg) in args
2477        ]
2478        self.assertTrue(same_two_models(mod, opt_mod, args))
2479
2480    def test_optimized_deepcopy(self):
2481        # See https://github.com/pytorch/pytorch/pull/88629
2482        class Foo(torch.nn.Module):
2483            def __init__(self) -> None:
2484                super().__init__()
2485                self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True)
2486
2487            def forward(self, x):
2488                return self.fc(x)
2489
2490        mod = Foo()
2491        opt_mod = torch._dynamo.optimize("eager")(mod)
2492        args = [torch.randn(1, 2)]
2493        self.assertTrue(same_two_models(mod, opt_mod, args))
2494
2495    def test_class_member(self):
2496        class Foo(torch.nn.Module):
2497            a = 4
2498            b = torch.ones(3, 4)
2499
2500            def __init__(self) -> None:
2501                super().__init__()
2502                self.c = 4
2503
2504            def forward(self, x):
2505                return x.cos() + self.a + self.b + self.c
2506
2507        mod = Foo()
2508        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
2509        args = (torch.randn(3, 4),)
2510        self.assertTrue(same(mod(*args), opt_mod(*args)))
2511
2512    def test_named_buffers(self):
2513        class Foo(torch.nn.Module):
2514            def __init__(self) -> None:
2515                super().__init__()
2516                self.x = torch.nn.Buffer(torch.ones(3))
2517                self.y = torch.nn.Buffer(torch.ones(3))
2518
2519            def forward(self, inp):
2520                res = 0
2521                for name, buffer in self.named_buffers():
2522                    res += buffer.sum()
2523
2524                return inp.cos() + res
2525
2526        mod = Foo()
2527        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
2528        args = (torch.randn(3, 4),)
2529        self.assertTrue(same(mod(*args), opt_mod(*args)))
2530
2531    def test_requires_grad_guards_with_grad_mode1(self):
2532        def f(x):
2533            if x.requires_grad:
2534                return x + 1
2535            else:
2536                return x + 2
2537
2538        x = torch.ones(2, requires_grad=True)
2539
2540        f_compiled = torch.compile(f)
2541        with torch.no_grad():
2542            # compile an inference graph
2543            f_compiled(x)
2544
2545        # Test: we should fail guards and recompile (even though it's still an inference graph)
2546        out_ref = f(x.detach())
2547        out = f_compiled(x.detach())
2548
2549        self.assertEqual(out_ref, out)
2550        self.assertEqual(out_ref.requires_grad, out.requires_grad)
2551
2552    def test_requires_grad_guards_with_grad_mode2(self):
2553        x = torch.ones(2, requires_grad=True)
2554        x_ref = x.clone().detach().requires_grad_(True)
2555
2556        m = torch.nn.Linear(2, 2)
2557        m_compiled = torch.compile(m)
2558
2559        with torch.no_grad():
2560            # compile an inference graph
2561            m_compiled(x)
2562
2563        # Test: we should fail guards and recompile a training graph
2564        out_ref = m(x_ref)
2565        out = m_compiled(x)
2566        self.assertEqual(out_ref, out)
2567        self.assertEqual(out_ref.requires_grad, out.requires_grad)
2568
2569    def test_is_symbolic_tracing(self):
2570        # Ensure no graph break here
2571        def fn(x):
2572            if is_fx_tracing_test():
2573                return x * 2
2574            return x * 4
2575
2576        a = torch.randn(4)
2577        ref = fn(a)
2578        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
2579        res = opt_fn(a)
2580        self.assertTrue(same(ref, res))
2581
2582    def test_tokenization(self):
2583        from collections import UserDict
2584
2585        class BatchEncoding(UserDict):
2586            """
2587            Copied from tokenization
2588            """
2589
2590            def __init__(
2591                self,
2592                data,
2593            ):
2594                super().__init__(data)
2595
2596            def __getattr__(self, item: str):
2597                try:
2598                    return self.data[item]
2599                except KeyError as e:
2600                    raise AttributeError from e
2601
2602        def tokenization(x):
2603            encoding = BatchEncoding({"key": x})
2604            return encoding["key"]
2605
2606        opt_fn = torch._dynamo.optimize("eager")(tokenization)
2607        x = torch.rand((1, 4))
2608        ref = tokenization(x)
2609        res = opt_fn(x)
2610        self.assertTrue(same(ref, res))
2611
2612    def test_modules(self):
2613        class Foo(torch.nn.Module):
2614            def __init__(self) -> None:
2615                super().__init__()
2616                self.fc = torch.nn.Linear(4, 3)
2617
2618            def forward(self, inp):
2619                res = torch.zeros(3, 3)
2620                for mod in self.modules():
2621                    res += self.fc(inp)
2622                return res
2623
2624        mod = Foo()
2625        args = (torch.ones(3, 4),)
2626        cnt = torch._dynamo.testing.CompileCounter()
2627        opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod)
2628        self.assertTrue(same(mod(*args), opt_mod(*args)))
2629        self.assertEqual(cnt.op_count, 5)
2630        self.assertEqual(cnt.frame_count, 1)
2631
2632    def test_omegaconf_listconfig_iter(self):
2633        obj = ListConfig()
2634        x = torch.zeros(2)
2635
2636        def fn():
2637            y = x
2638            for i in obj:
2639                y += i
2640            return y
2641
2642        expected = fn()
2643        actual = torch.compile(fn, fullgraph=True, backend="eager")()
2644        self.assertEqual(actual, expected)
2645
2646    def test_user_defined_iter(self):
2647        class MyIter:
2648            def __init__(self) -> None:
2649                self.i = 0
2650
2651            def __iter__(self):
2652                return self
2653
2654            def __next__(self):
2655                if self.i < 3:
2656                    self.i += 1
2657                    return self.i
2658                raise StopIteration
2659
2660        @torch.compile(backend="eager", fullgraph=True)
2661        def fn(x):
2662            for i in MyIter():
2663                x += i
2664            return x
2665
2666        self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0))
2667
2668    def test_stop_iteration_reconstruct(self):
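        # The StopIteration instance created inside the compiled function must be
        # reconstructed and returned alongside the tensor output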
2669        @torch.compile(backend="eager", fullgraph=True)
2670        def fn(x):
2671            return x.sin(), StopIteration(1, 2, 3)
2672
2673        _, res = fn(torch.ones(1))
2674        self.assertEqual(str(res), str(StopIteration(1, 2, 3)))
2675
2676    def test_tensor_data_kwarg(self):
2677        # https://github.com/pytorch/pytorch/issues/96278
2678        def f():
2679            return torch.tensor(data=[[1.0, -1.0]])
2680
2681        cnt = torch._dynamo.testing.CompileCounter()
2682        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
2683        self.assertTrue(same(f(), opt_fn()))
2684        self.assertEqual(cnt.frame_count, 1)
2685
2686    @requires_cuda
2687    def test_norm_dtype(self):
2688        def foo(_stack0):
2689            getitem = _stack0[(slice(None, None, None), -1)]
2690            _stack0 = None
2691            normalize = torch.nn.functional.normalize(getitem, p=2, dim=1)
2692            getitem = None
2693            return (normalize,)
2694
2695        args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)]
2696        args = [
2697            rand_strided(sh, st, dt, dev).requires_grad_(rg)
2698            for (sh, st, dt, dev, rg) in args
2699        ]
2700
2701        opt_foo = torch._dynamo.optimize("aot_eager_decomp_partition")(foo)
2702        with torch.cuda.amp.autocast(enabled=True):
2703            ref = foo(*args)[0]
2704            res = opt_foo(*args)[0]
2705            self.assertEqual(ref.dtype, res.dtype)
2706
2707            self.assertTrue(same(res, ref))
2708
2709    def test_for_loop_graph_break(self):
2710        def inner(x):
2711            return torch.sin(x)
2712
2713        def fn(x):
2714            for _ in range(100):
2715                inner(x)
2716                torch._dynamo.graph_break()
2717            return x
2718
2719        cnt = torch._dynamo.testing.CompileCounter()
2720        opt_fn = torch._dynamo.optimize(cnt)(fn)
2721        x = torch.randn(4)
2722        opt_fn(x)
2723        self.assertEqual(cnt.frame_count, 1)
2724        self.assertEqual(cnt.op_count, 1)
2725
2726    def test_for_loop_graph_break_before(self):
2727        # Checks that the backedge is calculated correctly
2728        def inner(x):
2729            return torch.sin(x)
2730
2731        def fn(x):
2732            torch._dynamo.graph_break()
2733            for _ in range(100):
2734                inner(x)
2735            return x
2736
2737        cnt = torch._dynamo.testing.CompileCounter()
2738        opt_fn = torch._dynamo.optimize(cnt)(fn)
2739        x = torch.randn(4)
2740        opt_fn(x)
2741        self.assertEqual(cnt.frame_count, 1)
2742        self.assertEqual(cnt.op_count, 100)
2743
2744    def test_avoid_dupe_specialization(self):
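        # f is called with aliased inputs (x, x) and with distinct inputs (x, y);
        # both should match eager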
2745        def f(x, y):
2746            return (x + y) * 1
2747
2748        opt_f = torch._dynamo.optimize("aot_eager")(f)
2749
2750        for b in [True, False]:
2751            x = torch.randn(4, requires_grad=b)
2752            y = torch.randn(4, requires_grad=b)
2753            self.assertEqual(f(x, x), opt_f(x, x))
2754            self.assertEqual(f(x, y), opt_f(x, y))
2755
2756    def test_validate_model_kwargs(self):
2757        cnt = CompileCounter()
2758
2759        def f1(a, b):
2760            return torch.sin(a) + torch.cos(b)
2761
2762        @torch.compile(backend=cnt, fullgraph=True)
2763        def f2(**kwargs):
2764            _validate_model_kwargs(f1, kwargs)
2765            return f1(**kwargs)
2766
2767        x = torch.randn(10)
2768        y = torch.randn(10)
2769
2770        self.assertEqual(f2(a=x, b=y), f1(x, y))
2771        self.assertEqual(cnt.frame_count, 1)
2772        self.assertEqual(cnt.op_count, 3)
2773
2774    def test_swin_base_tensor_attr(self):
2775        class Foo(torch.nn.Module):
2776            def __init__(self) -> None:
2777                super().__init__()
2778                # NB: not a parameter or buffer
2779                self.t = torch.randn(3)
2780
2781            def forward(self, x):
2782                return x + torch.cat((self.t, self.t))
2783
2784        mod = Foo()
2785        opt_mod = torch._dynamo.optimize("eager")(mod)
2786        args = [torch.randn(6)]
2787        self.assertTrue(same_two_models(mod, opt_mod, args))
2788        opt_mod(*args)
2789
2790    def test_pointless_graph_removal(self):
2791        cnt = torch._dynamo.testing.CompileCounter()
2792
2793        @torch.compile(backend=cnt)
2794        def fn(x):
2795            with torch.no_grad():
2796                torch._dynamo.graph_break()
2797                return x + 1
2798
2799        fn(torch.randn(4))
2800        self.assertEqual(cnt.frame_count, 1)
2801        self.assertEqual(cnt.op_count, 3)
2802
2803    def test_output_aliases_intermediate(self):
2804        def f(x):
2805            intermediate = x.mul(2)
2806            return intermediate.view(-1), intermediate
2807
2808        opt_f = torch._dynamo.optimize("aot_eager")(f)
2809
2810        for b in [True, False]:
2811            x = torch.randn(4, requires_grad=b)
2812            out = f(x)
2813            out_test = opt_f(x)
2814            self.assertEqual(out[0], out_test[0])
2815            self.assertEqual(out[1], out_test[1])
2816            self.assertEqual(out[0].requires_grad, out_test[0].requires_grad)
2817            self.assertEqual(out[1].requires_grad, out_test[1].requires_grad)
2818            # test that the aliasing relationship of outputs is preserved
2819            out[0].mul_(2)
2820            out_test[0].mul_(2)
2821            self.assertEqual(out[0], out_test[0])
2822            self.assertEqual(out[1], out_test[1])
2823
2824    def test_while_loop_graph_break(self):
2825        # Repro of tacotron2 cache_size_recompilation
2826        def inner(x):
2827            return torch.sin(x)
2828
2829        def fn(x):
2830            i = 20
2831            while i > 10:
2832                x = inner(x)
2833                i -= 1
2834                torch._dynamo.graph_break()
2835            return x
2836
2837        cnt = torch._dynamo.testing.CompileCounter()
2838        opt_fn = torch._dynamo.optimize(cnt)(fn)
2839        x = torch.randn(4)
2840        opt_fn(x)
2841        self.assertEqual(cnt.frame_count, 1)
2842        self.assertEqual(cnt.op_count, 1)
2843
2844    def test_nested_while_loop_graph_break(self):
2845        def inner_loop(x):
2846            i = 3
2847            while i > 0:
2848                i -= 1
2849                x += 1
2850                torch._dynamo.graph_break()
2851            return x
2852
2853        def inner(x):
2854            inner_loop(x)
2855            return torch.sin(x)
2856
2857        def fn(x):
2858            i = 20
2859            while i > 10:
2860                x = inner(x)
2861                i -= 1
2862                torch._dynamo.graph_break()
2863            return x
2864
2865        cnt = torch._dynamo.testing.CompileCounter()
2866        opt_fn = torch._dynamo.optimize(cnt)(fn)
2867        x = torch.randn(4)
2868        opt_fn(x)
2869        self.assertEqual(cnt.frame_count, 1)
2870        self.assertEqual(cnt.op_count, 1)
2871
2872    def test_while_loop_graph_break_inside_call_function(self):
2873        # Repro of huggingface graph break inside loop in `get_parameter_dtype`.
2874        # Skip only the inner frame whose loop contains the graph break.
2875        def inner(x):
2876            for i in range(3):
2877                x += 1
2878                torch._dynamo.graph_break()
2879            return x
2880
2881        def fn(x):
2882            x += 2
2883            inner(x)
2884            x += 3
2885            return x
2886
2887        cnt = torch._dynamo.testing.CompileCounter()
2888        opt_fn = torch._dynamo.optimize(cnt)(fn)
2889        x = torch.randn(4)
2890        opt_fn(x)
2891        self.assertEqual(cnt.frame_count, 2)
2892        self.assertEqual(cnt.op_count, 2)
2893
2894    def test_exception_in_dynamo_handling(self):
2895        hit_handler = False
2896
2897        # See https://github.com/pytorch/pytorch/pull/96488
2898        @contextlib.contextmanager
2899        def ctx():
2900            try:
2901                yield
2902            except RuntimeError:
2903                nonlocal hit_handler
2904                hit_handler = True
2905
2906        @torch._dynamo.optimize("eager")
2907        def f():
2908            with ctx():
2909                h()
2910
2911        def h():
2912            raise RuntimeError("boof")
2913
2914        # Should not error
2915        f()
2916        self.assertTrue(hit_handler)
2917
2918    def test_generator_dealloc(self):
2919        # See https://github.com/pytorch/pytorch/pull/96488
2920        #
2921        # NB: yes, [(...)] is intentional, this is a list containing a
2922        # generator
2923        generator_box = [(x for x in [1, 2, 3])]
2924
2925        counter = torch._dynamo.testing.CompileCounter()
2926
2927        def g(x):
2928            return x + 2
2929
2930        # TODO: This test is pretty delicate.  To test if it's actually doing
2931        # anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
2932        # and then look at the logs for:
2933        #
2934        # TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
2935        # TRACE[_custom_eval_frame:664] throw <genexpr>
2936        #
2937        # This means we're actually hitting the relevant codepath
2938
2939        # NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
2940        # this frame, Dynamo actually DOES understand list.clear and will
2941        # arrange for the generator deallocation to happen when the eval frame
2942        # handler is disabled, which will prevent the bug from happening (we
2943        # specifically want to trigger the generator deallocation WHILE the
2944        # dynamo eval frame handler is active), as that will cause the
2945        # generator to become exhausted and trigger the throw_flag == TRUE
2946        # case.
2947        @torch._dynamo.disable(recursive=False)
2948        def f(x):
2949            generator_box.clear()
2950            return g(x)
2951
2952        self.assertNoUnraisable(
2953            lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
2954        )
2955
2956        # Make sure the x + 2 is captured (a previous incorrect implementation
2957        # of this fix would have disabled the eval frame callback, which means
2958        # g wouldn't get traced)
2959        self.assertEqual(counter.op_count, 1)
2960
2961    def test_error_return_without_exception_set(self):
2962        # https://github.com/pytorch/pytorch/issues/93781
2963        @torch.compile
2964        def f():
2965            _generator_type = type(_ for _ in ())
2966
2967        self.assertNoUnraisable(f)
2968
2969    def common_merge_criteria_processor_list(self, list_cls, fullgraph):
2970        cnt = CompileCounter()
2971
2972        @torch.compile(backend=cnt, fullgraph=fullgraph)
2973        def f(x, left, right):
2974            combined = _merge_criteria_processor_list(left, right)
2975            return combined(x)
2976
2977        l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()])
2978        l2 = list_cls([])
2979        input = torch.randn(16)
2980        result = f(input, l1, l2)
2981        self.assertEqual(result, l1(input))
2982        self.assertEqual(cnt.frame_count, 1)
2983        self.assertEqual(cnt.op_count, 2)
2984
2985        cnt.clear()
2986        l3 = list_cls([torch.nn.SiLU()])
2987        expected = l3(l1(input))
2988        result = f(input, l1, l3)
2989        self.assertEqual(len(l1), 3)
2990        self.assertEqual(result, expected)
2991        self.assertEqual(cnt.frame_count, 1)
2992        self.assertEqual(cnt.op_count, 3)
2993
2994    def test_merge_criteria_processor_list1(self):
2995        self.common_merge_criteria_processor_list(CustomList1, False)
2996
2997    def test_merge_criteria_processor_list2(self):
2998        self.common_merge_criteria_processor_list(CustomList2, True)
2999
3000    def test_restricted_list_subclass1(self):
3001        cnt = CompileCounter()
3002
3003        @torch.compile(backend=cnt, fullgraph=True)
3004        def fn(a, b):
3005            l = CustomList2()
3006            l.extend([True])
3007            l.append(a)
3008            l.extend([b])
3009            l.pop(0)
3010            l.append(l.length_times_10())
3011            return sum(l)
3012
3013        x = torch.randn(10)
3014        y = torch.randn(10)
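        # after the mutations, l is [a, b, 20] inside fn (length_times_10() of a
        # 2-element list is presumably 20), so the result is x + y + 20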
3015        self.assertEqual(fn(x, y), x + y + 20)
3016        self.assertEqual(cnt.op_count, 3)
3017
3018    def test_restricted_list_subclass2(self):
3019        cnt = CompileCounter()
3020
3021        @torch.compile(backend=cnt, fullgraph=True)
3022        def fn(a, b):
3023            l1 = CustomList2([a + 1])
3024            l2 = CustomList2([b + 2])
3025            l1.extend(l2)
3026            return l1
3027
3028        x = torch.randn(10)
3029        y = torch.randn(10)
3030        z = fn(x, y)
3031        self.assertEqual(type(z), CustomList2)
3032        self.assertEqual(len(z), 2)
3033        self.assertEqual(z.length_times_10(), 20)
3034        self.assertEqual(list(z), [x + 1, y + 2])
3035
3036    def test_restricted_list_subclass3(self):
3037        cnt = CompileCounter()
3038
3039        @torch.compile(backend=cnt, fullgraph=True)
3040        def fn(a: CustomList2, b: CustomList2):
3041            a.extend(b)
3042            a.append_twice(b[2] + 1)
3043            a.append(b[3] + 2)
3044            return b
3045
3046        x = torch.randn(10)
3047        y = torch.randn(10)
3048        l = CustomList2([x, y])
3049        self.assertIs(fn(l, l), l)
3050        self.assertEqual(len(l), 7)
3051        self.assertIs(l[0], x)
3052        self.assertIs(l[1], y)
3053        self.assertIs(l[2], x)
3054        self.assertIs(l[3], y)
3055        self.assertEqual(l[4], x + 1)
3056        self.assertIs(l[5], l[4])
3057        self.assertEqual(l[6], y + 2)
3058
3059    def test_rewrite_assert_with_msg(self):
3060        def f(x):
3061            b = x.sin()
3062            assert x[0] == 3, "First dim needs to be 3"
3063            return x.cos() + b
3064
3065        args = (torch.Tensor([3, 4, 5]),)
3066        cnt = torch._dynamo.testing.CompileCounter()
3067
3068        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
3069        self.assertTrue(same(f(*args), opt_f(*args)))
3070        self.assertEqual(cnt.op_count, 6)
3071        self.assertEqual(cnt.frame_count, 1)
3072
3073        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
3074        self.assertTrue(same(exported(*args), f(*args)))
3075
3076    def test_list_aliasing(self):
3077        cnt = CompileCounter()
3078
3079        @torch.compile(backend=cnt, fullgraph=True)
3080        def fn(a):
3081            a.append(torch.sin(a[0]))
3082            return a
3083
3084        x = torch.randn(10)
3085        l = [x]
3086        self.assertIs(fn(l), l)
3087        self.assertEqual(len(l), 2)
3088        self.assertIs(l[0], x)
3089        self.assertEqual(l[1], torch.sin(x))
3090        self.assertEqual(cnt.frame_count, 1)
3091        self.assertEqual(cnt.op_count, 1)
3092
3093    def test_not_rewrite_assert_for_other_errors(self):
3094        def f(x):
3095            b = x.sin()
3096            if not x.sum() <= 3:
3097                raise ValueError("input sum needs to be 3")
3098            return x.cos() + b
3099
3100        args = (torch.Tensor([3, 4, 5]),)
3101        opt_fn = torch._dynamo.optimize("eager")(f)
3102        with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
3103            opt_fn(*args)
3104
3105    def test_rewrite_assert_dont_change_bytecode(self):
3106        def fn(x):
3107            with torch.no_grad():
3108                assert x.max() < 5, f"invalid max {x.max()}"
3109                x = torch.sin(x)
3110            return x
3111
3112        x = torch.ones(4)
3113        opt_fn = torch._dynamo.optimize("eager")(fn)
3114        self.assertTrue(same(fn(x), opt_fn(x)))
3115
3116    def test_rewrite_assert_without_msg(self):
3117        def f(x):
3118            b = x.sin()
3119            assert x[0] == 3
3120            return x.cos() + b
3121
3122        args = (torch.Tensor([3, 4, 5]),)
3123        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
3124        self.assertTrue(same(exported(*args), f(*args)))
3125
3126        with self.assertRaisesRegex(RuntimeError, "assertion error"):
3127            exported(torch.Tensor([5, 6, 7]))
3128
3129    def test_rewrite_assert_with_non_string_msg(self):
3130        def f(x):
3131            b = x.sin()
3132            assert x[0] == 2, x.size()
3133            return x.cos() + b
3134
3135        torch._dynamo.utils.counters.clear()
3136        args = torch.Tensor([3, 4, 5])
3137        opt_f = torch._dynamo.optimize("eager")(f)
3138        with self.assertRaisesRegex(AssertionError, "torch.Size"):
3139            opt_f(args)
3140        self.assertEqual(
3141            torch._dynamo.utils.counters["graph_break"][
3142                "assert with non-string message"
3143            ],
3144            1,
3145        )
3146
3147    def test_rewrite_assert_noop(self):
3148        def f(x):
3149            b = x.sin()
3150            assert True
3151            assert x.dtype == torch.float32
3152            return x.cos() + b
3153
3154        args = (torch.Tensor([3, 4, 5]),)
3155        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
3156        self.assertTrue(same(exported(*args), f(*args)))
3157
3158        cnt = torch._dynamo.testing.CompileCounter()
3159        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
3160        self.assertTrue(same(f(*args), opt_f(*args)))
3161        # torch._assert shouldn't be in the graph
3162        self.assertEqual(cnt.op_count, 3)
3163        self.assertEqual(cnt.frame_count, 1)
3164
3165        exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
3166        self.assertTrue(same(exported(*args), f(*args)))
3167
3168    def test_size_typematch(self):
3169        def f(x, y):
3170            if isinstance(x, torch.Size):
3171                return y + 1
3172            else:
3173                return y + 2
3174
3175        y = torch.zeros(1)
3176        x1 = torch.Size((3,))
3177        x2 = (3,)
3178
3179        cnt = torch._dynamo.testing.CompileCounter()
3180        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
3181        self.assertTrue(same(f(x1, y), opt_f(x1, y)))
3182        self.assertTrue(same(f(x2, y), opt_f(x2, y)))
3183        self.assertEqual(cnt.frame_count, 2)
3184
3185    def test_dict_subclass_contains(self):
3186        # pattern from huggingface
3187        class ClassInstantier(collections.OrderedDict):
3188            pass
3189
3190        @torch.compile(fullgraph=True, backend="eager")
3191        def f(x, d):
3192            if "key1" in d:
3193                x = x + 2
3194            if "key2" in d:
3195                x = x + 4
3196            x = x + 8
3197            return x
3198
3199        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
3200        self.assertTrue(same(result, torch.full([8], 11.0)))
3201
3202        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
3203        self.assertTrue(same(result, torch.full([8], 13.0)))
3204
3205    def test_hf_classinstantier(self):
3206        # hf activations.py
3207        class ClassInstantier(collections.OrderedDict):
3208            def __getitem__(self, key):
3209                content = super().__getitem__(key)
3210                cls, kwargs = content if isinstance(content, tuple) else (content, {})
3211                return cls(**kwargs)
3212
3213        ACT2CLS = ClassInstantier(
3214            {
3215                "relu": (nn.ReLU, {"inplace": False}),
3216                "tanh": nn.Tanh,
3217            }
3218        )
3219
3220        @torch.compile(fullgraph=True, backend="eager")
3221        def f(x, act):
3222            return ACT2CLS[act](x)
3223
3224        y = torch.randn(10)
3225        self.assertTrue(same(f(y, "tanh"), torch.tanh(y)))
3226        self.assertTrue(same(f(y, "relu"), torch.relu(y)))
3227
3228    def test_ephemeral_module(self):
3229        # hf activations.py
3230        class ReLUSquaredActivation(nn.Module):
3231            def forward(self, input):
3232                relu_applied = torch.nn.functional.relu(input)
3233                squared = torch.square(relu_applied)
3234                return squared
3235
3236        @torch.compile(fullgraph=True, backend="eager")
3237        def f(x):
3238            x = x + 0.2
3239            x = ReLUSquaredActivation()(x)
3240            x = x + 1
3241            return x
3242
3243        y = torch.randn(10)
3244        self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1))
3245
3246    def test_inplace_unsqueeze_input(self):
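        # The in-place unsqueeze_ on the graph input should be reflected both in the
        # example_inputs seen by the backend and in the caller's tensor afterwards.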
3247        def backend(gm, example_inputs):
3248            self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4]))
3249            return gm
3250
3251        @torch.compile(backend=backend)
3252        def fn(x):
3253            x.unsqueeze_(0)
3254            return x + 1
3255
3256        inputs = [torch.randn(3, 4)]
3257        self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4]))
3258        self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4]))
3259
3260    def test_batchnorm_e2e(self):
3261        class Repro(torch.nn.Module):
3262            def __init__(self) -> None:
3263                super().__init__()
3264                self.bn = torch.nn.BatchNorm2d(
3265                    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
3266                )
3267                self.conv1 = torch.nn.Conv2d(
3268                    64,
3269                    64,
3270                    kernel_size=(3, 3),
3271                    stride=(1, 1),
3272                    padding=(1, 1),
3273                    bias=False,
3274                )
3275
3276            def forward(self, x):
3277                x1 = self.bn(x)
3278                x2 = self.conv1(x1)
3279                out = torch.nn.functional.relu(x2)
3280                return (out,)
3281
3282        torch.manual_seed(1337)
3283
3284        m_ref = Repro()
3285        m_test = deepcopy(m_ref)
3286
3287        @torch._dynamo.optimize("aot_eager_decomp_partition")
3288        def compiled_fn(x):
3289            return m_test(x)
3290
3291        x_ref = torch.randn(2, 64, 32, 32, requires_grad=True)
3292        x_test = x_ref.clone()
3293
3294        # Loop multiple times: on each iteration the BatchNorm running_mean/var
3295        # buffers update, which changes the output of the next iteration.
3296        for _ in range(3):
3297            ref = m_ref(x_ref)
3298            res = compiled_fn(x_test)
3299
3300            self.assertTrue(same(ref, res))
3301
3302            for r in ref:
3303                if r.requires_grad:
3304                    r.sum().backward()
3305            for r in res:
3306                if r.requires_grad:
3307                    r.sum().backward()
3308
3309            for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()):
3310                self.assertTrue(same(param_ref, param_test))
3311            # Assert running_mean/var
3312            for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
3313                self.assertTrue(same(buffer_ref, buffer_test))
3314
3315    @torch._dynamo.config.patch("assume_static_by_default", False)
3316    def test_dynamic_shapes_right_side(self):
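        # Export with a 4-row input, then run with a 6-row input: the expression
        # 5 * x.shape[0] should stay symbolic instead of being specialized to 20.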
3317        def f(x):
3318            return torch.ones(5 * x.shape[0])
3319
3320        inp = torch.randn(6, 5)
3321
3322        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
3323        self.assertEqual(gm(inp).shape, f(inp).shape)
3324
3325    @torch._dynamo.config.patch("specialize_int", False)
3326    def test_maybe_multiply_symint(self):
3327        # https://github.com/pytorch/pytorch/issues/97346
3328        from torch._functorch.aot_autograd import aot_module_simplified
3329
3330        def my_aot_compiler(gm, example_inputs):
3331            def my_compiler(gm, example_inputs):
3332                return gm.forward
3333
3334            # Invoke AOTAutograd
3335            return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)
3336
3337        def my_example(t1, t2, d):
3338            out = torch.add(t1, t2, alpha=d)
3339            return out
3340
3341        compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example)
3342
3343        t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
3344        t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
3345
3346        ra = compiled_fn(t1, t2, 5)
3347        self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0]))
3348
3349        ra = compiled_fn(t1, t2, 6)
3350        self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0]))
3351
3352    def test_build_map_unpack_with_call(self):
3353        def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2):
3354            return x.sin() + t + cond_scale + self_cond + other1 + other2
3355
3356        @torch.compile(backend="eager", fullgraph=True)
3357        def fn(x):
3358            d1 = dict(other1=5)
3359            d2 = dict(other2=4)
3360            text_cond = {**d1, **d2}
3361            return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond)
3362
3363        self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15))
3364
3365    @torch._dynamo.config.patch(verbose=True)
3366    def test_graph_break_unsupported_fake(self):
3367        counter = torch._dynamo.testing.CompileCounter()
3368
3369        @torch._dynamo.optimize(counter)
3370        def f(x):
3371            return torch.ops.test_sample.foo(x + 1) + 1
3372
3373        f(torch.randn(3))
3374
3375        self.assertEqual(counter.op_count, 2)
3376        self.assertEqual(counter.frame_count, 2)
3377
3378    def test_delattr(self):
3379        class MyObj:
3380            def __init__(self, a, b):
3381                self.a = a
3382                self.b = b
3383
3384        @torch.compile(backend="eager", fullgraph=True)
3385        def fn(x, obj):
3386            del obj.a
3387            obj.c = x + 1
3388            del obj.c
3389            tmp = MyObj(x + 2, x + 3)
3390            del tmp.b
3391            if hasattr(obj, "a"):
3392                return x + 1
3393            return tmp
3394
3395        x = torch.zeros([])
3396        obj1 = MyObj(x, x)
3397        obj2 = fn(x, obj1)
3398        self.assertFalse(hasattr(obj1, "a"))
3399        self.assertFalse(hasattr(obj1, "c"))
3400        self.assertFalse(hasattr(obj2, "b"))
3401        self.assertEqual(obj1.b.item(), 0)
3402        self.assertEqual(obj2.a.item(), 2)
3403
3404    def test_delattr_raises(self):
3405        class MyObj:
3406            def __init__(self, a, b):
3407                self.a = a
3408                self.b = b
3409
3410        @torch.compile(backend="eager")
3411        def fn(x, obj):
3412            del obj.a
3413            x = x + 1
3414            obj.a  # will raise
3415            return x
3416
3417        x = torch.zeros([])
3418        obj1 = MyObj(x, x)
3419        self.assertRaises(AttributeError, lambda: fn(x, obj1))
3420
3421    def test_delsubscr(self):
3422        @torch.compile(backend="eager")
3423        def fn(x):
3424            del x["a"]
3425            y = x["b"] + 1
3426            return y
3427
3428        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
3429        result = fn(x)
3430        self.assertFalse(hasattr(x, "a"))
3431        self.assertEqual(result.item(), 2)
3432
3433    def test_delsubscr_raises(self):
3434        @torch.compile(backend="eager")
3435        def fn(x):
3436            del x["a"]
3437            y = x["a"] + 1  # should raise KeyError
3438            return y
3439
3440        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
3441        self.assertRaises(KeyError, lambda: fn(x))
3442
3443    def test_attached_attribute_in_dir(self):
3444        class MyModule(torch.nn.Module):
3445            def __init__(self) -> None:
3446                super().__init__()
3447                self.linear = torch.nn.Linear(16, 16)
3448                self.relu = torch.nn.ReLU()
3449
3450            def forward(self, x):
3451                return self.relu(self.linear(x))
3452
3453        mod = torch.compile(MyModule(), backend="eager")
3454        mod.is_compiled = True
3455        self.assertTrue("is_compiled" in dir(mod))
3456
3457    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
3458    def test_dynamic_shapes_implicit_guard(self):
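        # Sizes derived from other sizes (x.size(x.shape[0]), summing over y.shape[0])
        # introduce implicit guards; this should still compile in a single frame.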
3459        def f(x):
3460            y = x * x.size(x.shape[0])
3461            torch.sum(y, [y.shape[0]])
3462            return y
3463
3464        cnt = torch._dynamo.testing.CompileCounter()
3465        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
3466        opt_fn(torch.randn(3, 1, 1, 1, 1))
3467        self.assertEqual(cnt.frame_count, 1)
3468
3469    def test_dalle2_maybe(self):
3470        def normalize(x):
3471            return x.cos()
3472
3473        @torch.compile(backend="eager", fullgraph=True)
3474        def fn(x, normalize_img):
3475            lowres_cond_img = x.sin()
3476            lowres_cond_img = maybe(normalize_img)(lowres_cond_img)
3477            return lowres_cond_img
3478
3479        self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos())
3480
3481    def test_functools_wraps(self):
3482        def cool_name(x):
3483            return x.sin()
3484
3485        @torch.compile(backend="eager", fullgraph=True)
3486        def fn(x):
3487            y = x.cos()
3488
3489            @functools.wraps(cool_name)
3490            def uncool_name():
3491                return cool_name(y)
3492
3493            return uncool_name
3494
3495        result = fn(torch.ones([]))
3496        self.assertEqual(result.__name__, "cool_name")
3497        self.assertEqual(result(), torch.ones([]).cos().sin())
3498
3499    def test_dynamic_shapes_float_guard(self):
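        # The dropout probability is a float computed from a tensor size; this
        # should compile in a single frame without erroring.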
3500        def f(x):
3501            return torch.nn.functional.dropout(x, x.shape[0] / 6)
3502
3503        cnt = torch._dynamo.testing.CompileCounter()
3504        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
3505        opt_fn(torch.randn(3))
3506        self.assertEqual(cnt.frame_count, 1)
3507
3508    @torch._dynamo.config.patch(capture_scalar_outputs=True)
3509    def test_tensor_item(self):
3510        def f(x, y):
3511            val = y.item()
3512            return x.sum() + val
3513
3514        gm, _ = torch._dynamo.export(
3515            f,
3516            aten_graph=True,
3517        )(
3518            torch.zeros(6, 4),
3519            torch.tensor(1),
3520        )
3521        self.assertEqual(
3522            f(torch.zeros(6, 4), torch.tensor(1)),
3523            gm(torch.zeros(6, 4), torch.tensor(1)),
3524        )
3525        self.assertEqual(
3526            f(torch.zeros(6, 4), torch.tensor(2)),
3527            gm(torch.zeros(6, 4), torch.tensor(2)),
3528        )
3529
3530    def test_dataclass_init_with_default_factory_with_inputs(self):
3531        @dataclasses.dataclass
3532        class DClass:
3533            sharding_contexts: Any = dataclasses.field(default_factory=list)
3534            a: int = 1
3535
3536        def fn(x, inp_list):
3537            d = DClass(inp_list)
3538            d.sharding_contexts.append(x.sin() + d.a)
3539            return d
3540
3541        x = torch.randn(4)
3542        inp_list1 = [1, 2, 3]
3543        inp_list2 = [2, 3, 4]
3544        inp_list3 = [1, 2]
3545        ref1 = fn(x, inp_list1)
3546        ref2 = fn(x, inp_list2)
3547        ref3 = fn(x, inp_list3)
3548
3549        cnt = torch._dynamo.testing.CompileCounter()
3550        opt_fn = torch.compile(fn, fullgraph=True)
3551
3552        opt_ret1 = opt_fn(x, inp_list1)
3553        opt_ret2 = opt_fn(x, inp_list2)
3554        opt_ret3 = opt_fn(x, inp_list3)
3555        self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts)
3556        self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts)
3557        self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts)
3558
3559    def test_list_index(self):
3560        for i, list_type in enumerate(
3561            (
3562                list,
3563                tuple,
3564                torch.Size,
3565                collections.deque,
3566                namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
3567            )
3568        ):
3569            torch._dynamo.reset()
3570            for index in ([], [2], [0, 3]):
3571
3572                def f(t):
3573                    if i == 4:  # namedtuple
3574                        xs = list_type(1, 2, 3, 4)
3575                    else:
3576                        xs = list_type([1, 2, 3, 4])
3577                    res = xs.index(3, *index)
3578                    return t + res
3579
3580                res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
3581                    torch.zeros(1)
3582                )
3583
3584                self.assertEqual(res, torch.tensor([2.0]))
3585
3586    def test_list_index_not_found(self):
3587        def f(t):
3588            xs = ["bar", "foo", "baz", "buzz"]
3589            res = xs.index("non-existent")
3590            return t + res
3591
3592        # Raising ValueError from item not found is unsupported
3593        with self.assertRaises(
3594            torch._dynamo.exc.Unsupported,
3595        ):
3596            torch._dynamo.optimize(backend="eager", nopython=True)(f)(torch.zeros(1))
3597
3598    def test_list_index_tensor_unsupported(self):
3599        for index in ([], [2], [0, 3]):
3600
3601            def f(t):
3602                xs = [torch.tensor([i]) for i in range(4)]
3603                res = xs.index(torch.tensor([2]), *index)
3604                return t + res
3605
3606            with self.assertRaisesRegex(
3607                torch._dynamo.exc.UserError, "Dynamic control flow is not supported"
3608            ):
3609                torch._dynamo.optimize(backend="eager", nopython=True)(f)(
3610                    torch.zeros(1)
3611                )
3612
3613    def test_hf_xsoftmax_inference(self):
3614        def fn(input, mask):
3615            return XSoftmax.apply(input + 1, mask, 1) + 2
3616
3617        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
3618
3619        inputs = [
3620            torch.randn(4, 10),
3621            torch.randn(4, 10) < 0,
3622        ]
3623        expected = fn(*inputs)
3624        actual = fn_opt(*inputs)
3625        self.assertTrue(same(actual, expected))
3626
3627    @mock.patch("torch._dynamo.config.guard_nn_modules", True)
3628    def test_hf_xsoftmax_training(self):
3629        from torch._dynamo.utils import counters
3630
3631        counters.clear()
3632
3633        def fn(input, mask):
3634            return XSoftmax.apply(input, mask, 1)
3635
3636        cnt = torch._dynamo.testing.CompileCounter()
3637        fn_opt = torch.compile(fn, backend=cnt, fullgraph=False)
3638
3639        torch.manual_seed(1234)
3640        inputs1 = [
3641            torch.randn(4, 10, requires_grad=True),
3642            torch.randn(4, 10) < 0,
3643        ]
3644        torch.manual_seed(1234)
3645        inputs2 = [
3646            torch.randn(4, 10, requires_grad=True),
3647            torch.randn(4, 10) < 0,
3648        ]
3649
3650        expected = fn(*inputs1)
3651        actual = fn_opt(*inputs2)
3652        self.assertTrue(same(actual, expected))
3653        self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
3654        self.assertEqual(cnt.op_count, 2)
3655        self.assertEqual(cnt.frame_count, 1)
3656        cnt.clear()
3657        counters.clear()
3658
3659        expected.sum().backward()
3660        actual.sum().backward()
3661        self.assertTrue(same(inputs1[0].grad, inputs2[0].grad))
3662
3663        # currently we don't capture the backwards frame
3664        self.assertEqual(cnt.frame_count, 0)
3665        self.assertEqual(cnt.op_count, 0)
3666        self.assertEqual(dict(counters["frames"]), {})
3667        self.assertEqual(dict(counters["graph_break"]), {})
3668
3669    def test_autograd_function_graph_break(self):
3670        class MySin(torch.autograd.Function):
3671            @staticmethod
3672            def forward(ctx, x):
3673                torch._dynamo.graph_break()
3674                ctx.save_for_backward(x)
3675                return x.sin()
3676
3677            @staticmethod
3678            def backward(ctx, gx):
3679                (x,) = ctx.saved_tensors
3680                return gx * x.cos()
3681
3682        x = torch.randn([], requires_grad=True)
3683
3684        @torch.compile(backend="eager")
3685        def fn(x):
3686            return MySin.apply(x)
3687
3688        y = fn(x)
3689        self.assertEqual(y, x.sin())
3690
3691        (gx,) = torch.autograd.grad(y, x)
3692        self.assertEqual(gx, x.cos())
3693
3694    def test_jit_trace_errors(self):
3695        @torch.compile(backend="eager", dynamic=True)
3696        def f(x):
3697            return x + 1
3698
3699        with self.assertRaises(RuntimeError):
3700            torch.jit.trace(f, torch.randn(3))
3701
3702        with torch._dynamo.config.patch(error_on_nested_jit_trace=False):
3703            torch.jit.trace(f, torch.randn(3))
3704
3705    @torch._dynamo.config.patch("assume_static_by_default", False)
3706    def test_tensor_split(self):
3707        def f(x):
3708            return torch.split(x, x.shape[0] // 2, dim=0)[0]
3709
3710        gm, _ = torch._dynamo.export(
3711            f,
3712            aten_graph=True,
3713        )(
3714            torch.zeros(6, 4),
3715        )
3716
3717        self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
3718
3719    def test_optim_state_references_cleared(self):
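        # The compiled optimizer step must not keep optimizer state alive: once the
        # optimizer is dropped, the weakref to its state tensor should be dead.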
3720        model = torch.nn.Linear(2048, 2048, bias=False)
3721        x = torch.ones(2048)
3722        state_ref = 0
3723
3724        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
3725
3726        def opt_step():
3727            optimizer.step()
3728
3729        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)
3730
3731        def compiled_model_step(x):
3732            optimizer.zero_grad()
3733            y = model(x)
3734            torch.sum(y).backward()
3735            compiled_opt_step()
3736
3737        compiled_model_step(x)
3738
3739        # Picked "square_avg" arbitrarily to check that
3740        # optimizer state tensors are deallocated
3741        state_ref = weakref.ref(
3742            optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
3743        )
3744        optimizer = None
3745
3746        self.assertIsNone(state_ref())
3747
3748    def test_grad_references_cleared(self):
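        # zero_grad(set_to_none=True) after a compiled optimizer step should actually
        # free the .grad tensors rather than leaving references alive.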
3749        model = torch.nn.Linear(2048, 2048, bias=False)
3750        x = torch.ones(2048)
3751        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
3752
3753        def opt_step():
3754            optimizer.step()
3755
3756        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)
3757
3758        def compiled_model_step(x):
3759            optimizer.zero_grad(True)
3760            y = model(x)
3761            torch.sum(y).backward()
3762            compiled_opt_step()
3763
3764        compiled_model_step(x)
3765        param_grad_ref = weakref.ref(next(iter(model.parameters())).grad)
3766        optimizer.zero_grad(True)
3767        self.assertIsNone(param_grad_ref())
3768
3769    def test_batch_encoding_clone_inputs(self):
3770        class BatchEncoding(dict):
3771            """
3772            Copied from test_tokenization
3773            """
3774
3775            def __init__(
3776                self,
3777                data,
3778            ):
3779                super().__init__(data)
3780
3781            def __getattr__(self, item: str):
3782                try:
3783                    return self.data[item]
3784                except KeyError as e:
3785                    raise AttributeError from e
3786
3787        encoding = BatchEncoding({"key": torch.rand((1, 4))})
3788        cloned_encoding = torch._dynamo.utils.clone_inputs(encoding)
3789        self.assertTrue(type(cloned_encoding) is not dict)
3790
3791    def test_iadd_graph_break(self):
3792        def fn(x):
3793            a = ()
3794            x = torch.sin(x)
3795            a += (x,)
3796            return a
3797
3798        x = torch.randn(4)
3799        ref = fn(x)
3800
3801        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
3802        res = opt_fn(x)
3803        self.assertTrue(same(ref, res))
3804
3805    def test_odict_get_item_index_name(self):
3806        d = {float: torch.float32, np.float16: torch.float16}
3807
3808        @torch.compile(backend="eager")
3809        def f(x, y1, y2):
3810            return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])
3811
3812        f(torch.zeros(4), float, np.float16)
3813
3814    def test_dedup_global(self):
3815        @torch.compile()
3816        def f():
3817            return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR
3818
3819        self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR)
3820
3821    def test_randint_out_dynamic(self):
3822        def randint_fn(high, size, out):
3823            return torch.randint(high, size, out=out)
3824
3825        opt_model = torch.compile(randint_fn)
3826
3827        out1 = torch.empty(10, dtype=torch.int32)
3828        opt_model(17, (10,), out1)
3829
3830        out2 = torch.empty(12, dtype=torch.int32)
3831        opt_model(17, (12,), out2)
3832
3833    @requires_cuda
3834    def test_guard_default_device(self):
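        # Changing the default device should trigger a recompile (frame_count goes
        # from 1 to 2), since the compiled code guards on the default device.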
3835        try:
3836            torch.set_default_device("cuda")
3837
3838            counter = torch._dynamo.testing.CompileCounter()
3839
3840            @torch._dynamo.optimize(counter)
3841            def f():
3842                x = torch.randn(3)
3843                return x * 2
3844
3845            self.assertEqual(f().device.type, "cuda")
3846            self.assertEqual(counter.frame_count, 1)
3847
3848            torch.set_default_device("cpu")
3849
3850            self.assertEqual(f().device.type, "cpu")
3851            self.assertEqual(counter.frame_count, 2)
3852
3853        finally:
3854            torch.set_default_device(None)
3855
3856    def test_list_self_reference(self):
3857        # Issue - https://github.com/pytorch/pytorch/issues/100150
3858        root = []
3859        root[:] = [root, root, None, None]
3860
3861        @torch._dynamo.optimize("eager")
3862        def test_bug():
3863            return root
3864
3865        test_bug()
3866
3867    def test_hf_bigbird_unsqueeze(self):
3868        def torch_bmm_nd(inp_1, inp_2, ndim=None):
3869            torch._dynamo.graph_break()
3870            return torch.bmm(inp_1, inp_2)
3871
3872        def fn(inp1, inp2, inp3, inp4, c):
3873            a = torch_bmm_nd(inp1, inp2, 4)
3874            a.unsqueeze_(2)
3875            a = a * 2
3876
3877            b = torch_bmm_nd(inp3, inp4, 4)
3878            b.unsqueeze_(2)
3879            l = a + b
3880
3881            out = torch.cat([a, b, c], dim=2)
3882            return out, l
3883
3884        inp1 = torch.rand(1, 64, 448)
3885        inp2 = torch.rand(1, 448, 64)
3886        inp3 = torch.rand(1, 64, 448)
3887        inp4 = torch.rand(1, 448, 64)
3888        c = torch.rand(1, 64, 1, 64)
3889
3890        cnt = torch._dynamo.testing.CompileCounter()
3891        opt_fn = torch._dynamo.optimize(cnt)(fn)
3892        opt_fn(inp1, inp2, inp3, inp4, c)
3893        self.assertEqual(cnt.frame_count, 3)
3894
3895    def test_torch_variable_type(self):
3896        # from torchvision
3897        def check_type(obj, types_or_checks):
3898            for type_or_check in types_or_checks:
3899                if (
3900                    isinstance(obj, type_or_check)
3901                    if isinstance(type_or_check, type)
3902                    else type_or_check(obj)
3903                ):
3904                    return True
3905            return False
3906
3907        opt_check_type = torch._dynamo.optimize("eager")(check_type)
3908        ref = check_type(torch.randn(4), [torch.Tensor])
3909        res = opt_check_type(torch.randn(4), [torch.Tensor])
3910        self.assertEqual(ref, res)
3911
3912    # Test for https://github.com/pytorch/pytorch/issues/103132
3913    @torch._dynamo.config.patch("assume_static_by_default", False)
3914    def test_inference_mode_dynamic_shapes(self):
3915        class Repro(torch.nn.Module):
3916            def __init__(self) -> None:
3917                super().__init__()
3918
3919            def forward(self, param):
3920                z = torch.matmul(param, param)
3921                return z
3922
3923        model = Repro()
3924        # Need a 3d tensor to actually cause the error:
3925        # we go down a path of the C++ matmul decomp that calls sizes().
3926        inp = torch.randn(4, 4, 4, requires_grad=True)
3927        model = torch.compile(model, backend="aot_eager", dynamic=True)
3928        with torch.inference_mode():
3929            model(inp)
3930
3931    def test_kwargs_out_list_variable(self):
3932        class Repro(torch.nn.Module):
3933            def __init__(self) -> None:
3934                super().__init__()
3935
3936            def forward(self, param):
3937                z = torch.frexp(**param)
3938                return z
3939
3940        model = Repro()
3941        params = {"input": torch.tensor([[0.0, 1, 2, 4]])}
3942        params["out"] = [
3943            torch.empty(0, dtype=torch.float32),  # mantissa
3944            torch.empty(0, dtype=torch.int32),  # exponent
3945        ]
3946
3947        model = torch.compile(model, backend="eager")
3948        mantissa, exponent = model(params)
3949        ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]])
3950        ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
3951        self.assertEqual(ref_mantissa, mantissa)
3952        self.assertEqual(ref_exponent, exponent)
3953
3954    @torch._dynamo.config.patch(capture_scalar_outputs=True)
3955    def test_split_with_sizes_aot_autograd(self):
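        # split_sizes.tolist() produces data-dependent ints (capture_scalar_outputs
        # is enabled); aot_eager should still handle this with fullgraph=True.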
3956        def fn(result, split_sizes):
3957            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
3958            return rs
3959
3960        example_inputs = (
3961            torch.randn(32, requires_grad=True),
3962            torch.tensor((7, 16, 9)),
3963        )
3964        actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
3965        expected = fn(*example_inputs)
3966        self.assertEqual(actual, expected)
3967
3968    def test_unspecialized_nn_module_with_torch_variable_attribute(self):
3969        """
3970        In this case self.fn is something that should be treated as a TorchVariable.
3971        When it's not a TorchVariable, dynamo tries to trace through it and fails.
3972        This makes sure that self.fn is handled as a TorchVariable.
3973        """
3974
3975        class UserModule(torch.nn.Module):
3976        torchdynamo_force_dynamic = True  # forced to be an UnspecializedNNModule
3977
3978            def __init__(self, fn):
3979                super().__init__()
3980                self.fn = fn
3981
3982            def forward(self, **inp):
3983                return self.fn(**inp)
3984
3985        inputs = {
3986            "input": torch.randn([2, 9]).uniform_(0, 1),
3987            "target": torch.randn([2, 9]).uniform_(0, 1),
3988            "reduction": "mean",
3989        }
3990
3991        mod = UserModule(torch.nn.functional.binary_cross_entropy)
3992        ref = mod(**inputs)
3993        res = torch._dynamo.optimize("eager", nopython=True)(mod)(**inputs)
3994        self.assertEqual(ref, res)
3995
3996    def test_call_finally_python_3_8(self):
3997        # Issue - https://github.com/pytorch/pytorch/issues/97811
3998        def make_fn(g):
3999            def fn():
4000                while True:
4001                    try:
4002                        print(g)
4003                        break
4004                    except Exception as _:
4005                        break
4006
4007            return torch.compile(fn, backend="eager")
4008
4009        make_fn(None)()
4010
4011    def test_call_finally_python_3_8_2(self):
4012        def f(x):
4013            while x:
4014                try:
4015                    pass
4016                except Exception as _:
4017                    continue
4018
4019        torch.compile(f, backend="eager")(0)
4020
4021    def test_call_finally_opcode_python_3_8(self):
4022        def fn():
4023            try:
4024                return torch.zeros(4)
4025            finally:
4026                return torch.ones(4)  # noqa: SIM107, B012
4027
4028        result = torch.compile(fn, backend="aot_eager")()
4029        self.assertEqual(result, torch.ones(4))
4030
4031    def test_string_format(self):
4032        s = "temp{i}"
4033
4034        @torch.compile(backend="eager", fullgraph=True)
4035        def fn(x):
4036            if s.format(i=4) == "temp4":
4037                return torch.sin(x)
4038            return torch.cos(x)
4039
4040        x = torch.randn(4)
4041        self.assertEqual(fn(x), torch.sin(x))
4042
4043    # Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards'
4044    # due to bad empty list handling
4045    def test_empty_list_contains_with_jump(self):
4046        def fn(x, l):
4047            if x in l:
4048                return x.cos()
4049            return x.sin()
4050
4051        counter = CompileCounter()
4052        compiled_fn = torch._dynamo.optimize(counter)(fn)(torch.randn([2, 2]), [])
4053        self.assertEqual(counter.frame_count, 1)
4054
4055    def test_graph_break_on_jit_isinstance(self):
4056        @torch.compile(backend="eager")
4057        def fn(x):
4058            if torch.jit.isinstance(x, List[str]):
4059                return x * 2
4060            return x
4061
4062        opt_fn = torch.compile(fn, backend="eager")
4063        x = torch.rand(4)
4064        self.assertTrue(same(fn(x), opt_fn(x)))
4065
4066    def test_add_sub_alpha_out(self):
4067        inp = torch.randn(2, 3, 4)
4068        other = 1
4069        alpha = 2
4070        for op in [torch.add, torch.sub]:
4071            out = torch.zeros(2, 3, 4)
4072            compile_out = torch.zeros(2, 3, 4)
4073            op(inp, other, alpha=alpha, out=out)
4074            compiled_fn = torch.compile(op, dynamic=True)
4075            compiled_fn(inp, other, alpha=alpha, out=compile_out)
4076            self.assertTrue(same(out, compile_out))
4077
4078    def test_negative_shape_guard(self):
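        # The size() != ... check takes a different branch for the two inputs, so
        # exactly two frames are expected even with dynamic=True.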
4079        def fn(x):
4080            if x.size() != (5, 1, 2, 3):
4081                return x.cos()
4082            return x.sin()
4083
4084        counter = torch._dynamo.testing.CompileCounter()
4085        opt_fn = torch.compile(fn, backend=counter, dynamic=True)
4086
4087        x = torch.ones(5, 1, 3, 4)
4088        x2 = torch.ones(5, 1, 2, 3)
4089        self.assertEqual(fn(x), opt_fn(x))
4090        self.assertEqual(fn(x2), opt_fn(x2))
4091        self.assertEqual(counter.frame_count, 2)
4092
4093    @torch._dynamo.config.patch(capture_scalar_outputs=True)
4094    def test_deferred_runtime_asserts(self):
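        # y comes from .item(), so torch._check_is_size defers the y >= 0 check to
        # runtime; the negative input below should therefore raise a RuntimeError.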
4095        @torch.compile(fullgraph=True)
4096        def f(x):
4097            y = x.item()
4098            torch._check_is_size(y)
4099            if y >= 0:
4100                return x * 2
4101            else:
4102                return x * 3
4103
4104        f(torch.tensor([3]))
4105        self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2])))
4106
4107    def test_addr_alpha_beta_out(self):
4108        inp = torch.randn(2, 3)
4109        vec1 = torch.randn(2)
4110        vec2 = torch.randn(3)
4111        alpha = 2
4112        beta = 5
4113
4114        out = torch.zeros(2, 3)
4115        compile_out = torch.zeros(2, 3)
4116
4117        torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
4118        compiled_fn = torch.compile(torch.addr, dynamic=True)
4119        compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
4120        self.assertTrue(same(out, compile_out))
4121
4122    def test_setattr_requires_grad_graph_breaks(self):
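        # Setting requires_grad inside the compiled region is expected to graph
        # break (frame_count == 2 below) while still matching eager outputs and grads.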
4123        def fn(x):
4124            z = x + 4
4125            x.requires_grad = True
4126            y = x * z
4127            return y
4128
4129        for backend in ["count", "eager", "aot_eager"]:
4130            if backend == "count":
4131                backend = CompileCounter()
4132            opt_fn = torch.compile(fn, backend=backend)
4133
4134            eager = torch.zeros(5)
4135            compiled = eager.clone()
4136
4137            out_eager = fn(eager)
4138            out_opt = opt_fn(compiled)
4139
4140            self.assertEqual(out_eager, out_opt)
4141
4142            out_eager.sum().backward()
4143            out_opt.sum().backward()
4144
4145            self.assertEqual(eager, compiled)
4146            if isinstance(backend, CompileCounter):
4147                self.assertEqual(backend.frame_count, 2)  # graph breaks
4148
4149    def test_dynamic_shapes_double_not_equal(self):
4150        # https://github.com/pytorch/pytorch/issues/113393
4151        def fn(x):
4152            if x.size() != (5, 1, 2, 3):
4153                return x.cos()
4154            return x.sin()
4155
4156        opt_fn = torch.compile(fn, backend="eager")
4157
4158        x = torch.ones(5, 1, 2, 3)
4159        x2 = torch.ones(5, 1, 3, 4)
4160        self.assertEqual(fn(x), opt_fn(x))
4161        self.assertEqual(fn(x2), opt_fn(x2))
4162
4163    def test_inductor_no_recursionerror_on_for_loops(self):
4164        def forward(x):
4165            for _ in range(1000):
4166                x = 1.0 * x
4167            return x
4168
4169        self.assertTrue(
4170            same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0]))
4171        )
4172
4173    def test_user_defined_object_callable(self):
4174        # https://github.com/pytorch/pytorch/issues/114019
4175        class MyCallable:
4176            def __call__(self, x):
4177                return x + 1
4178
4179        def fn(x):
4180            # Create in graph - will not have source
4181            return MyCallable()(x)
4182
4183        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
4184        self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))
4185
4186    @torch._dynamo.config.patch(log_compilation_metrics=True)
4187    def test_many_views_with_mutation(self):
4188        # When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap
4189        # began adding shape guards - a quadratic amount relative to the number of inputs.
4190        # Test this configuration, and test that a reasonable number of guards are added.
4191        # Note: when dynamic shapes are turned on, this test fails and we still get quadratic guards.
4192        def fn(x):
4193            x[0].relu_()
4194            return torch.cat(x).sum()
4195
4196        AMT = 32
4197        src = torch.rand(16 * (AMT + 1))
4198
4199        x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)]
4200
4201        torch._dynamo.reset()
4202        torch._dynamo.utils.clear_compilation_metrics()
4203
4204        res = torch.compile(fn, backend="aot_eager")(x)
4205
4206        all_metrics = torch._dynamo.utils.get_compilation_metrics()
4207
4208        total_guards = sum(metric.guard_count for metric in all_metrics)
4209        self.assertLess(total_guards, AMT * 8)
4210
4211        total_shape_env_guards = sum(
4212            metric.shape_env_guard_count for metric in all_metrics
4213        )
4214        self.assertLess(total_shape_env_guards, AMT * 8)
4215
4216    # https://github.com/pytorch/pytorch/issues/118799
4217    def test_subclass_graph_output_repro(self):
4218        @torch._dynamo.allow_in_graph
4219        def to_subclass(x):
4220            return TwoTensor(x.clone(), x.clone())
4221
4222        def f(x):
4223            tmp_subclass = to_subclass(x)
4224            return tmp_subclass.view(-1)
4225
4226        x = torch.ones(2)
4227        out_ref = f(x)
4228        out_test = torch.compile(f, backend="aot_eager")(x)
4229        self.assertEqual(out_ref, out_test)
4230
4231    def test_numpy_tobytes_no_error(self):
4232        def fn(x):
4233            x += 1
4234            z = x.tobytes()
4235            x += 1
4236            return z
4237
4238        cnt = torch._dynamo.testing.CompileCounter()
4239        opt_fn = torch._dynamo.optimize(cnt)(fn)
4240        opt_arg, arg = np.array([1, 2]), np.array([1, 2])
4241        self.assertEqual(opt_fn(opt_arg), fn(arg))
4242        self.assertEqual(cnt.frame_count, 2)
4243
4244    def test_numpy_not_ndarray_recompiles(self):
4245        import torch
4246
4247        def fn(x=None):
4248            if x is None:
4249                x = np.ones(3)
4250            elif isinstance(x, int):
4251                x = np.ones(6)
4252            elif isinstance(x, str):
4253                x = np.ones(9)
4254            return x**2
4255
4256        cnt = torch._dynamo.testing.CompileCounter()
4257        opt_fn = torch._dynamo.optimize(cnt)(fn)
4258
4259        x = np.zeros((2, 2))
4260
4261        self.assertEqual(opt_fn(x), fn(x))
4262        self.assertEqual(cnt.frame_count, 1)
4263        self.assertEqual(opt_fn(), fn())
4264        self.assertEqual(cnt.frame_count, 2)
4265        self.assertEqual(opt_fn(10), fn(10))
4266        self.assertEqual(cnt.frame_count, 3)
4267        self.assertEqual(opt_fn("10"), fn("10"))
4268        self.assertEqual(cnt.frame_count, 4)
4269
4270    @parametrize(
4271        "backend",
4272        ["eager", "aot_eager", "inductor"],
4273    )
4274    @parametrize(
4275        "func_name",
4276        ["func1", "func2", "func3"],
4277    )
4278    def test_tensor_set_data(self, backend, func_name):
4279        # https://github.com/pytorch/pytorch/issues/113030
4280        def func1(x, y):
4281            x.data = y
4282            x.add_(1)
4283            return x
4284
4285        def func2(x, y):
4286            x.data = y
4287            y.data = torch.zeros([0])
4288            return x
4289
4290        def func3(x, y):
4291            z = x
4292            x.data = y
4293            y.data = torch.zeros([0])
4294            return torch.tensor(x is z)
4295
4296        funcs = {"func1": func1, "func2": func2, "func3": func3}
4297        func = funcs[func_name]
4298
4299        if backend != "eager" and func is func1:
4300            # add_ not working w/ aot_autograd?
4301            return
4302
4303        torch._dynamo.reset()
4304        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
4305
4306        compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
4307        requires_grad = func is not func1
4308        for i in range(0, 5):
4309            # Inputs
4310            eager_a = torch.ones([6], requires_grad=requires_grad)
4311            compiled_a = torch.ones([6], requires_grad=requires_grad)
4312
4313            eager_b = torch.ones([6], requires_grad=requires_grad)
4314            compiled_b = torch.ones([6], requires_grad=requires_grad)
4315
4316            # Eager
4317            out_eager = func(eager_a, eager_b)
4318            # Compiled
4319            out_compiled = compiled_fn(compiled_a, compiled_b)
4320            self.assertEqual(eager_a, compiled_a)
4321            self.assertEqual(eager_b, compiled_b)
4322            self.assertTrue(torch.equal(out_eager, out_compiled))
4323
4324            # func1 hits a leaf Variable that requires grad is being used in an in-place operation
4325            if requires_grad:
4326                bwd_inp_eager = torch.randn([6])
4327                bwd_inp_compiled = torch.clone(bwd_inp_eager)
4328                eager_a.backward(bwd_inp_eager)
4329                compiled_a.backward(bwd_inp_compiled)
4330                self.assertEqual(eager_a.grad, compiled_a.grad)
4331
4332        # Prove guarding works - we run the compiled_fn 5 times
4333        # frame_count should stay at 1.
4334        self.assertEqual(cnt.frame_count, 1)
4335
4336    @unittest.skipIf(
4337        TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
4338        "flash attention not supported",
4339    )
4340    def test_flash_attn_backward_mixed_strides(self):
4341        # in this repro, "grad_out" and "value" are transposed tensors,
4342        # but "query" and "key" are contiguous
4343        def gen_inputs(device):
4344            return (
4345                torch.randn(
4346                    2, 513, 16, 64, dtype=torch.float16, device=device
4347                ).transpose(1, 2),
4348                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
4349                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
4350                torch.randn(
4351                    2, 513, 16, 64, dtype=torch.float16, device=device
4352                ).transpose(1, 2),
4353                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
4354                torch.randn(2, 16, 513, device=device),
4355                None,
4356                None,
4357                513,
4358                513,
4359                0.0,
4360                False,
4361                torch.tensor(1, dtype=torch.int64),
4362                torch.tensor(1, dtype=torch.int64),
4363            )
4364
4365        inps_cuda = gen_inputs("cuda")
4366        inps_meta = gen_inputs("meta")
4367        (
4368            out1_ref,
4369            out2_ref,
4370            out3_ref,
4371        ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
4372            *inps_cuda, scale=0.125
4373        )
4374        from torch._meta_registrations import meta__scaled_dot_product_flash_backward
4375
4376        out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
4377            *inps_meta, scale=0.125
4378        )
4379
4380        self.assertEqual(out1_ref.shape, out1_test.shape)
4381        self.assertEqual(out1_ref.stride(), out1_test.stride())
4382        self.assertEqual(out2_ref.shape, out2_test.shape)
4383        self.assertEqual(out2_ref.stride(), out2_test.stride())
4384        self.assertEqual(out3_ref.shape, out3_test.shape)
4385        self.assertEqual(out3_ref.stride(), out3_test.stride())
4386
4387    def test_user_ctor_ctx_manager(self):
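        # Constructing (but never entering) a user-defined context manager should not
        # force a graph break; the function compiles as a single frame.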
4388        class UserCtxManager:
4389            def __enter__(self):
4390                return 1
4391
4392            def __exit__(self, exc_type, exc_val, exc_tb):
4393                pass
4394
4395        def fn(x, y):
4396            ucm = UserCtxManager()
4397            return x * x
4398
4399        cnt = torch._dynamo.testing.CompileCounter()
4400        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
4401        x = torch.rand([2, 2])
4402        opt_fn(x, x)
4403        self.assertExpectedInline(cnt.frame_count, """1""")
4404
4405    @torch._dynamo.config.patch(capture_scalar_outputs=True)
4406    def test_unbacked_arange_in_bounds(self):
4407        # see https://github.com/pytorch/pytorch/issues/113002
4408        class PaddingNet(nn.Module):
4409            def __init__(self) -> None:
4410                super().__init__()
4411
4412            def forward(self, lengths):
4413                max_seq_len = lengths.max().item()
4414                row_vector = torch.arange(0, max_seq_len, 1)
4415                matrix = torch.unsqueeze(lengths, dim=-1)
4416                mask = row_vector < matrix
4417                mask = mask.type(torch.float32)
4418                mask_3d_btd = mask[:, :, None]
4419                return mask_3d_btd
4420
4421        model = PaddingNet()
4422        lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32)
4423
4424        cnt = torch._dynamo.testing.CompileCounter()
4425        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(model)
4426        opt_fn(lengths)
4427        self.assertEqual(cnt.frame_count, 1)
4428
4429    def test_overlapping_inputs_with_dynamic_shapes_error(self):
4430        @torch.compile(backend="aot_eager")
4431        def fn(a, b, c, d, e, f):
4432            a.mul_(2)
4433            b.mul_(2)
4434            c.mul_(2)
4435            d.mul_(2)
4436            e.mul_(2)
4437            f.mul_(2)
4438
4439            base = torch.ones(2, 20)
4440            a = base[:, 0:2]
4441            b = base[:, 2:4]
4442            c = base[:, 4:6]
4443            d = base[:, 6:8]
4444            e = base[:, 8:10]
4445            f = base[:, 10:12]
4446            f2 = base[:, 10:14]
4447            out = fn(a, b, c, d, e, f)
4448            with self.assertRaisesRegex(
4449                AssertionError, "is being compiled with dynamic shapes"
4450            ):
4451                out2 = fn(a, b, c, d, e, f2)
4452
4453    def test_user_ctor_ctx_manager_custom_init(self):
4454        class UserCtxManager:
4455            def __init__(self, x):
4456                x[0] = 10
4457
4458            def __enter__(self):
4459                return 1
4460
4461            def __exit__(self, exc_type, exc_val, exc_tb):
4462                pass
4463
4464        def fn(x, y):
4465            ucm = UserCtxManager(y)
4466            return x * y[0]
4467
4468        cnt = torch._dynamo.testing.CompileCounter()
4469        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
4470        x = torch.rand([2, 2])
4471        self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
4472        self.assertExpectedInline(cnt.frame_count, """1""")
4473
4474    def test_user_ctor_ctx_manager_custom_init_graph_break(self):
4475        counter = [0]
4476
4477        class UserCtxManager:
4478            def __init__(self, k):
4479                k[0] += 1
4480
4481            def __enter__(self):
4482                return 1
4483
4484            def __exit__(self, exc_type, exc_val, exc_tb):
4485                pass
4486
4487        def fn(x, counter):
4488            x = x * x
4489            ucm = UserCtxManager(counter)
4490            return x * x
4491
4492        cnt = torch._dynamo.testing.CompileCounter()
4493        opt_fn = torch._dynamo.optimize(cnt)(fn)
4494        x = torch.rand([2, 2])
4495        self.assertEqual(opt_fn(x, counter), fn(x, counter))
4496        self.assertEqual(counter[0], 2)
4497        for i in range(0, 10):
4498            opt_fn(x, counter)
4499        self.assertEqual(counter[0], 12)
4500        if torch._dynamo.config.assume_static_by_default:
4501            self.assertExpectedInline(cnt.frame_count, """2""")
4502        else:
4503            self.assertExpectedInline(cnt.frame_count, """1""")
4504
4505    @unittest.expectedFailure
4506    def test_many_overlapping_inputs_does_not_explode_guards(self):
4507        from torch._dynamo.backends.common import aot_autograd
4508
4509        # Before, this was (9702, 0)
4510        num_shape_guards = None
4511        num_aot_guards = None
4512        num_compiles = 0
4513
4514        def guard_count_backend(gm, *args):
4515            nonlocal num_shape_guards
4516            nonlocal num_aot_guards
4517            nonlocal num_compiles
4518            num_shape_guards = len(
4519                torch._guards.TracingContext.try_get().fake_mode.shape_env.guards
4520            )
4521            num_aot_guards = len(
4522                torch._guards.TracingContext.try_get().guards_context.aotautograd_guards
4523            )
4524            num_compiles += 1
4525            return gm
4526
4527        aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend)
4528
4529        @torch.compile(backend=aot_guard_counter, dynamic=True)
4530        def f(*args):
4531            for a in args:
4532                a.add_(1)
4533
4534        x = torch.ones(1000, requires_grad=True)
4535        args = x.split(10)
4536
4537        with torch.no_grad():
4538            f(*args)
4539        # In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2),
4540        # because every pair of aliased inputs needs a guard.
4541        self.assertTrue(num_aot_guards < 5000)
4542        # But there are no dynamic shape guards.
4543        self.assertEqual(num_shape_guards, 0)
4544        # don't recompile
4545        with torch.no_grad():
4546            f(*args)
4547        self.assertEqual(num_compiles, 1)
4548
4549    def test_invalid_seq_unpack(self):
4550        def myfn(arg):
4551            (a, b) = arg
4552
4553        def fn():
4554            return myfn((1, 2, 3))
4555
4556        try:
4557            torch.compile(fn)()
4558        except ValueError:
4559            pass
4560        else:
4561            self.fail("expected exception")
4562
4563    def test_megablocks_moe(self):
4564        try:
4565            from megablocks.layers import moe
4566            from megablocks.layers.arguments import Arguments
4567        except ImportError as e:
4568            raise unittest.SkipTest("requires megablocks") from e
4569        bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1)
4570        args = Arguments(
4571            hidden_size=hs,
4572            ffn_hidden_size=hs * 2,
4573            moe_num_experts=num_experts,
4574            moe_capacity_factor=1,
4575            moe_top_k=top_k,
4576        )
4577        moe_mlp = moe.MoE(args)
4578        moe_mlp.cuda(torch.cuda.current_device()).half()
4579        x = torch.randn(sl, bs, hs).cuda().half()
4580        out1, _ = moe_mlp(x)
4581        out2, _ = torch.compile(moe_mlp, backend="eager")(x)
4582        self.assertEqual(out1, out2)
4583
4584    def test_udf_classes_reconstruction(self):
4585        def fn(x):
4586            o = T(5)
4587            return o.x + x
4588
4589        opt_fn = torch.compile(fn, backend="eager")
4590        T = IncByOne
4591
4592        x = torch.randn(4)
4593        self.assertEqual(fn(x), opt_fn(x))
4594
4595        # This should recompile
4596        T = IncByTwo
4597        self.assertEqual(fn(x), opt_fn(x))
4598
4599    def test_contains_range_constprop(self):
4600        def fn(x):
4601            # dynamo should const prop this to True
4602            if 3 in range(0, 10):
4603                return x + 1
4604            else:
4605                return x + 2
4606
4607        opt_fn = torch.compile(fn, backend="eager")
4608        x = torch.zeros(4)
4609        self.assertEqual(fn(x), opt_fn(x))
4610
4611    # https://github.com/pytorch/pytorch/issues/104505
4612    def test_as_strided_on_base_with_mutation_works(self):
4613        def foo(a):
4614            f = a.as_strided((2,), (1,), 0)
4615            f.add_(1.0)
4616            return a
4617
4618        a = torch.randn(2, 4)
4619        a_ref = a.clone()
4620        out_ref = foo(a_ref)
4621        f_compiled = torch.compile(foo, backend="aot_eager")
4622        out = f_compiled(a)
4623        self.assertEqual(out_ref, out)
4624        self.assertEqual(a_ref, a)
4625
4626    # https://github.com/pytorch/pytorch/issues/104505
4627    def test_as_strided_on_existing_view_banned(self):
4628        def foo(a):
4629            e = a.diagonal()
4630            f = e.as_strided((2,), (1,), 0)
4631            f.add_(1.0)
4632            return a
4633
4634        a = torch.randn(2, 4)
4635        a_ref = a.clone()
4636        out_ref = foo(a_ref)
4637        f_compiled = torch.compile(foo, backend="aot_eager")
4638        with self.assertRaisesRegex(
4639            RuntimeError,
4640            "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
4641        ):
4642            out = f_compiled(a)
4643
4644    def test_dont_aggressively_write_assert(self):
4645        record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
4646
4647        @torch.compile(dynamic=True, backend=record_graph)
4648        def f(x):
4649            assert x.shape[0] > 3
4650            assert x[0].sum() > 0
4651            assert 1 % (x.shape[0] // 2) != 0
4652            assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
4653            return x.cos()
4654
4655        f(torch.ones(6, 4))
4656        graph = record_graph.graphs[0]
4657        # It is a bit annoying that we generate useless statements for
4658        # shape guards, but DCE should be able to remove them since
4659        # there is no backed assert on them. This is OK because dynamo
4660        # will only skip the assert statement itself, not the
4661        # instructions before it.
4662        self.assertExpectedInline(
4663            str(graph.code).strip(),
4664            """\
4665def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
4666    l_x_ = L_x_
4667    getitem_2 = l_x_[0]
4668    sum_1 = getitem_2.sum();  getitem_2 = None
4669    gt_1 = sum_1 > 0;  sum_1 = None
4670    _assert_async = torch._assert_async(gt_1, 'assertion error');  gt_1 = _assert_async = None
4671    cos = l_x_.cos();  l_x_ = None
4672    return (cos,)""",
4673        )
4674        for node in graph.graph.nodes:
4675            if "example_value" in node.meta and isinstance(
4676                node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
4677            ):
4678                shape_env = node.meta["example_value"].fake_mode.shape_env
4679                lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
4680                self.assertTrue(lower_ranges == [4, 2])
4681
4682        @torch.compile(dynamic=True, backend=record_graph)
4683        def f_fail(x):
4684            assert x.shape[0] < 3
4685
4686        # We graph-break here, so the failure should be eager
4687        with self.assertRaisesRegex(AssertionError, ""):
4688            f_fail(torch.ones(6, 4))
4689
4690    def test_detectron2_instances_cat(self):
4691        class Instances:
4692            def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
4693                self._image_size = image_size
4694                self._fields: Dict[str, Any] = {}
4695                for k, v in kwargs.items():
4696                    self.set(k, v)
4697
4698            @property
4699            def image_size(self) -> Tuple[int, int]:
4700                return self._image_size
4701
4702            def __setattr__(self, name: str, val: Any) -> None:
4703                if name.startswith("_"):
4704                    super().__setattr__(name, val)
4705                else:
4706                    self.set(name, val)
4707
4708            def __getattr__(self, name: str) -> Any:
4709                if name == "_fields" or name not in self._fields:
4710                    raise AttributeError(
4711                        f"Cannot find field '{name}' in the given Instances!"
4712                    )
4713                return self._fields[name]
4714
4715            def __len__(self) -> int:
4716                for v in self._fields.values():
4717                    # use __len__ because len() has to be int and is not friendly to tracing
4718                    return v.__len__()
4719                raise NotImplementedError("Empty Instances does not support __len__!")
4720
4721            def set(self, name: str, value: Any) -> None:
4722                with warnings.catch_warnings(record=True):
4723                    data_len = len(value)
4724                if len(self._fields):
4725                    assert (
4726                        len(self) == data_len
4727                    ), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
4728                self._fields[name] = value
4729
4730            def get(self, name: str) -> Any:
4731                return self._fields[name]
4732
4733            @staticmethod
4734            def cat(instance_lists: List["Instances"]) -> "Instances":
4735                assert all(isinstance(i, Instances) for i in instance_lists)
4736                assert len(instance_lists) > 0
4737                if len(instance_lists) == 1:
4738                    return instance_lists[0]
4739
4740                image_size = instance_lists[0].image_size
4741                if not isinstance(
4742                    image_size, torch.Tensor
4743                ):  # could be a tensor in tracing
4744                    for i in instance_lists[1:]:
4745                        assert i.image_size == image_size
4746                ret = Instances(image_size)
4747                for k in instance_lists[0]._fields.keys():
4748                    values = [i.get(k) for i in instance_lists]
4749                    v0 = values[0]
4750                    if isinstance(v0, torch.Tensor):
4751                        values = torch.cat(values, dim=0)
4752                    elif isinstance(v0, list):
4753                        values = list(itertools.chain(*values))
4754                    elif hasattr(type(v0), "cat"):
4755                        values = type(v0).cat(values)
4756                    else:
4757                        raise ValueError(
4758                            f"Unsupported type {type(v0)} for concatenation"
4759                        )
4760                    ret.set(k, values)
4761                return ret
4762
4763        instances = [
4764            Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
4765            for _ in range(3)
4766        ]
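        # The class above (trimmed from detectron2) routes attribute access through
        # __getattr__/__setattr__ and concatenates tensor fields in a staticmethod;
        # fullgraph=True below checks that Dynamo can trace all of this without a
        # graph break.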
4767
4768        @torch.compile(backend="eager", fullgraph=True)
4769        def fn(instances):
4770            return instances[0].cat(instances)
4771
4772        actual = fn(instances)
4773        expected = instances[0].cat(instances)
4774        self.assertEqual(type(actual), type(expected))
4775        self.assertEqual(actual.__dict__, expected.__dict__)
4776
4777    def test_weakref(self):
4778        def fn(x_weak, weight, y):
4779            if x_weak is not None and x_weak() is not weight:
4780                return torch.sin(y)
4781            return torch.cos(y)
4782
4783        weight = torch.randn(4)
4784        y = torch.randn(4)
4785        x_weak = weakref.ref(weight)
4786
4787        ref = fn(x_weak, weight, y)
4788
4789        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4790        res = opt_fn(x_weak, weight, y)
4791        self.assertEqual(ref, res)
4792
4793    def test_weakref_reconstruct(self):
4794        def fn(x_weak, weight, y):
4795            y = torch.sin(y)
4796            referent = x_weak()
4797            torch._dynamo.graph_break()
4798            if referent is not weight:
4799                return torch.sin(y)
4800            return torch.cos(y)
4801
4802        weight = torch.randn(4)
4803        y = torch.randn(4)
4804        x_weak = weakref.ref(weight)
4805
4806        ref = fn(x_weak, weight, y)
4807
4808        cnt = torch._dynamo.testing.CompileCounter()
4809        opt_fn = torch.compile(fn, backend=cnt)
4810        res = opt_fn(x_weak, weight, y)
4811        self.assertEqual(ref, res)
4812        self.assertEqual(cnt.frame_count, 2)
4813
4814    def test_weakref_del(self):
4815        def fn(x_weak, y):
4816            x = x_weak()
4817            if x is not None:
4818                return torch.sin(y)
4819            return torch.cos(y)
4820
4821        weight = torch.randn(4)
4822        x_weak = weakref.ref(weight)
4823        y = torch.randn(4)
4824
4825        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4826
4827        ref = fn(x_weak, y)
4828        res = opt_fn(x_weak, y)
4829        self.assertEqual(ref, res)
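        # Once the referent is garbage collected below, x_weak() returns None, so both
        # the eager and the compiled function should fall through to the cos branch
        # (presumably via a recompile, since the dereferenced value changed).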
4830
4831        del weight
4832        gc.collect()
4833        ref = fn(x_weak, y)
4834        res = opt_fn(x_weak, y)
4835        self.assertEqual(ref, res)
4836
4837    #     @torch._functorch.config.patch(
4838    #         recompute_views=True,
4839    #     )
4840    #     def test_storage_resize_forward_full_graph(self):
4841    #         class TestModule(torch.nn.Module):
4842    #             def __init__(self) -> None:
4843    #                 super().__init__()
4844    #                 self.param = torch.nn.Parameter(torch.randn(4, 4))
4845
4846    #             def forward(self, x):
4847    #                 self.param.untyped_storage().resize_(
4848    #                     self.param.numel() * self.param.itemsize
4849    #                 )
4850    #                 with torch.no_grad():
4851    #                     torch._foreach_copy_([self.param], [x])
4852    #                 out = torch.matmul(self.param, self.param)
4853    #                 self.param.untyped_storage().resize_(0)
4854    #                 return out
4855
4856    #         def post_accumulate_grad_hook(param):
4857    #             param.untyped_storage().resize_(0)
4858
4859    #         # Beginning of backward, resize and put data into the param
4860    #         def pre_backward_hook(module, grad) -> None:
4861    #             module.param.untyped_storage().resize_(
4862    #                 self.param.numel() * self.param.itemsize
4863    #             )
4864    #             with torch.no_grad():
4865    #                 # simulates loading data into param from allgather
4866    #                 module.param.fill_(2)
4867
4868    #         def post_forward_hook(module, args, output):
4869    #             output.register_hook(functools.partial(pre_backward_hook, module))
4870
4871    #         x = torch.randn(4, 4)
4872
4873    #         mod_ref = TestModule()
4874    #         mod_test = deepcopy(mod_ref)
4875
4876    #         # Start the param off with zero storage size to mimic fsdp
4877    #         mod_ref.param.untyped_storage().resize_(0)
4878    #         mod_test.param.untyped_storage().resize_(0)
4879
4880    #         # Resize storage at beginning of backward
4881    #         # Free storage at end of backward
4882    #         mod_ref.register_forward_hook(post_forward_hook, prepend=False)
4883    #         mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
4884    #         mod_test.register_forward_hook(post_forward_hook, prepend=False)
4885    #         mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
4886
4887    #         mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)
4888
4889    #         out_ref = mod_ref(x)
4890    #         out_test = mod_test(x)
4891    #         self.assertExpectedInline(
4892    #             str(fw_graph[0].code.strip()),
4893    #             """\
4894    # def forward(self, primals_1, primals_2):
4895    #     _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_1 = primals_2 = None
4896    #     getitem = _foreach_copy[0];  _foreach_copy = None
4897    #     mm = torch.ops.aten.mm.default(getitem, getitem)
4898    #     return [mm, getitem]""",
4899    #         )
4900    #         self.assertEqual(out_ref, out_test)
4901
4902    def test_super_in_staticmethod(self):
4903        class A:
4904            @staticmethod
4905            def foo():
4906                return super().__init__()
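        # Zero-argument super() inside a staticmethod has no class/instance to bind to,
        # so calling foo() raises (the eager message contains "no arguments"); the test
        # checks that the compiled error message matches the eager one.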
4907
4908        def fn(obj):
4909            return obj.foo()
4910
4911        obj = A()
4912
4913        try:
4914            fn(obj)
4915        except Exception as e:
4916            orig_str = str(e)
4917        self.assertIn("no arguments", orig_str)
4918
4919        try:
4920            torch.compile(backend="eager")(fn)(obj)
4921        except Exception as e:
4922            compiled_str = str(e)
4923        self.assertEqual(orig_str, compiled_str)
4924
4925    def test_super_staticmethod(self):
4926        class Parent:
4927            @staticmethod
4928            def greet():
4929                return 5
4930
4931        class Child(Parent):
4932            @staticmethod
4933            def greet(x):
4934                return x * super(Child, Child).greet()
4935
4936        child = Child()
4937
4938        def fn(x):
4939            return child.greet(x)
4940
4941        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4942        x = torch.ones(4)
4943        ref = fn(x)
4944        res = opt_fn(x)
4945        self.assertEqual(ref, res)
4946
4947    def test_super_diamond(self):
4948        class A:
4949            def __init__(self):
4950                super().__init__()
4951                self.a = 5
4952
4953        class Nothing:
4954            pass
4955
4956        class B(Nothing, A):
4957            def __init__(self):
4958                super().__init__()
4959                self.b = 10
4960
4961            def run(self, x):
4962                return self.a * self.b * x
4963
4964        def fn(x):
4965            b = B()
4966            return b.run(x)
4967
4968        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
4969        x = torch.randn(4)
4970        ref = fn(x)
4971        res = opt_fn(x)
4972        self.assertEqual(ref, res)
4973
4974    def test_vc_bumped_in_inference_graph(self):
4975        @torch.compile
4976        def f(x):
4977            return x.mul_(2)
4978
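        # x.mul_(2) is an in-place op; even when it runs through the compiled inference
        # graph, the tensor's version counter (_version) is expected to be bumped, just
        # as in eager.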
4979        x = torch.randn(4)
4980        vc_before = x._version
4981        f(x)
4982        vc_after = x._version
4983        self.assertTrue(vc_after > vc_before)
4984
4985    def test_nn_module_callable(self):
4986        class M(nn.Module):
4987            def forward(self, x):
4988                return x.sin()
4989
4990        def f(m):
4991            return callable(m)
4992
4993        res = torch.compile(f, fullgraph=True)(M())
4994        self.assertTrue(res)
4995
4996    def test_stk_sdd_is_transposed(self):
4997        trigger_graph_break = False
4998
4999        def _is_transposed(x):
5000            return (
5001                not x.is_contiguous()
5002                and x.stride()[0] == 1
5003                and x.stride()[1] == x.size()[0]
5004            )
5005
5006        class SDD(torch.autograd.Function):
5007            @staticmethod
5008            def forward(ctx, lhs, rhs):
5009                ctx.save_for_backward(lhs, rhs)
5010                out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device)
5011                return out
5012
5013            @staticmethod
5014            def backward(ctx, dy):
5015                saved_tensors = ctx.saved_tensors
5016                lhs, rhs = saved_tensors[:2]
5017                trans_a = _is_transposed(lhs)
5018                trans_b = _is_transposed(rhs)
5019                dlhs = None
5020                if ctx.needs_input_grad[0]:
5021                    dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0)
5022                drhs = None
5023                if ctx.needs_input_grad[1]:
5024                    drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0)
5025                if trigger_graph_break:
5026                    if _is_transposed(dy):
5027                        return dlhs + 1, drhs + 1, None, None
5028                return dlhs, drhs, None, None
5029
5030        x1 = torch.randn((8, 8), requires_grad=True)
5031        y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
5032        x2 = torch.randn((8, 8), requires_grad=True)
5033        y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
5034
5035        SDD.apply(x1, y1).sum().backward()
5036
5037        @torch.compile(backend="eager", fullgraph=True)
5038        def fn():
5039            return SDD.apply(x2, y2)
5040
5041        fn().sum().backward()
5042
5043        self.assertEqual(x1.grad, x2.grad)
5044        self.assertEqual(y1.grad, y2.grad)
5045
5046        trigger_graph_break = True
5047        with self.assertRaises(torch._dynamo.exc.Unsupported):
5048            fn().sum().backward()
5049
5050    def test_partially_initialized_module_property(self):
5051        class Matrix(torch.nn.Module):
5052            def __init__(self, data):
5053                super().__init__()
5054                self._data = data
5055                self.foo = 10 * self.blocking
5056
5057            @property
5058            def data(self):
5059                return self._data
5060
5061            @property
5062            def blocking(self):
5063                return self.data.shape[1]
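        # self.foo is computed in __init__ via the `blocking` property, i.e. while the
        # module is only partially initialized; compiling the constructor with
        # fullgraph=True checks that Dynamo can trace property access in that state.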
5064
5065        @torch.compile(backend="eager", fullgraph=True)
5066        def fn():
5067            return Matrix(torch.randn(10, 20))
5068
5069        v = fn()
5070        self.assertEqual(v.foo, 200)
5071        self.assertEqual(v.data.shape, (10, 20))
5072        self.assertEqual(type(v), Matrix)
5073
5074    def test_classmethod_with_slots(self):
5075        class Mock:
5076            __slots__ = ("_a",)
5077
5078            def __init__(self):
5079                self._a = 2
5080
5081            @classmethod
5082            def _m(cls):
5083                return 3
5084
5085            def run(self, x):
5086                return torch.sin(x) * self._a * self._m()
5087
5088        def fn(x):
5089            mock = Mock()
5090            return mock.run(x)
5091
5092        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
5093        x = torch.randn(4)
5094        self.assertEqual(fn(x), opt_fn(x))
5095
5096    def test_nn_parametrize(self):
5097        class Module(nn.Module):
5098            def __init__(self) -> None:
5099                super().__init__()
5100                self.param = torch.nn.Parameter(torch.randn(10, 10))
5101
5102            def forward(self, x):
5103                return self.param @ x
5104
5105        class Parametrization(torch.nn.Module):
5106            def forward(self, x):
5107                return torch.sin(x)
5108
5109        m = Module()
5110        torch.nn.utils.parametrize.register_parametrization(
5111            m, "param", Parametrization()
5112        )
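        # After registration, reads of m.param go through Parametrization.forward, so
        # the compiled graph is expected to contain a torch.sin call; the custom
        # backend below records whether one shows up.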
5113
5114        sin_found = False
5115
5116        def backend(gm, _):
5117            nonlocal sin_found
5118            for node in gm.graph.nodes:
5119                if node.target is torch.sin:
5120                    sin_found = True
5121            return gm
5122
5123        opt_m = torch.compile(m, backend=backend, fullgraph=True)
5124        inp = torch.randn(10, 10)
5125        self.assertEqual(m(inp), opt_m(inp))
5126        self.assertTrue(sin_found)
5127
5128        torch.nn.utils.parametrize.remove_parametrizations(m, "param")
5129        sin_found = False
5130        self.assertEqual(m(inp), opt_m(inp))
5131        self.assertFalse(sin_found)
5132
5133    def test_nn_module_property_closure(self):
5134        x = torch.randn(10, 10)
5135
5136        class Mod(torch.nn.Module):
5137            @property
5138            def y(self):
5139                return torch.ones(10, 10) + x
5140
5141            def forward(self, x):
5142                return x @ self.y
5143
5144        mod = Mod()
5145
5146        def fn(x):
5147            return mod(x)
5148
5149        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
5150
5151        inp = torch.randn(10, 10)
5152        self.assertEqual(fn(inp), opt_fn(inp))
5153
5154    def test_global_fn_mutation(self):
5155        def foo(x, y):
5156            return global_fn(x) + y
5157
5158        x = torch.ones(1)
5159        y = torch.ones(1)
5160
5161        opt = torch.compile(foo, fullgraph=True, backend="eager")
5162        self.assertEqual(opt(x, y), foo(x, y))
5163
5164        # Change global_fn
5165        global global_fn
5166
5167        def new_fn(x):
5168            return torch.cos(x)
5169
5170        global_fn = new_fn
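        # Rebinding the global should invalidate the compiled function's guard on
        # global_fn (presumably an identity guard on the global), triggering a
        # recompile that picks up torch.cos.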
5171        self.assertEqual(opt(x, y), foo(x, y))
5172
5173    # ref https://github.com/pytorch/pytorch/issues/123974
5174    def test_list_reverse(self):
5175        def ladder(x):
5176            trail = x.size(-1)
5177            assert trail > 2
5178            weights = []
5179            for s in [trail, trail - 1, trail - 2]:
5180                weights.append(torch.ones(s, s - 1))
5181
5182            for w in weights:
5183                x = x @ w
5184
5185            weights.reverse()
5186
5187            for w in weights:
5188                x = x @ w.t()
5189
5190            return x
5191
5192        data = torch.randn(3, 4)
5193        opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
5194        self.assertEqual(opt_ladder(data), ladder(data))
5195
5196    def test_trace_functional_tensor_with(self):
5197        from torch._subclasses.fake_tensor import FakeTensorMode
5198        from torch._subclasses.functional_tensor import (
5199            FunctionalTensor,
5200            FunctionalTensorMode,
5201        )
5202
5203        def f(a, tmp):
5204            a_view = a.view(-1)
5205            with torch.no_grad():
5206                a.set_(tmp)
5207                a_view.mul_(2)
5208            return a + tmp
5209
5210        fake_mode = FakeTensorMode()
5211        with FunctionalTensorMode():
5212            inp = torch.ones(3, 3, requires_grad=True)
5213            inp = fake_mode.from_tensor(inp, static_shapes=True)
5214            inp = FunctionalTensor.to_functional(inp)
5215
5216            tmp = torch.ones(3, 3, requires_grad=True)
5217            tmp = fake_mode.from_tensor(tmp, static_shapes=True)
5218            tmp = FunctionalTensor.to_functional(tmp)
5219
5220            opt_f = torch.compile(f, backend="eager")
5221            with self.assertRaisesRegex(
5222                RuntimeError, "cannot mutate tensors with frozen storage"
5223            ):
5224                opt_f(inp, tmp)
5225
5226    def test_const_dict_keyerror(self):
5227        d = {}
5228
5229        def fn(x):
5230            try:
5231                y = d[0]
5232            except KeyError:
5233                y = 1
5234            return x + y
5235
5236        opt_fn = torch.compile(fn, backend="eager")
5237        inp = torch.randn(3, 3)
5238        self.assertEqual(fn(inp), opt_fn(inp))
5239
5240    def test_dict_tag_guard(self):
5241        class Foo:
5242            def __init__(self) -> None:
5243                self.scalar = 10
5244
5245        def fn(d, x):
5246            return d["a"] * d["b"] * d["c"].scalar * x
5247
5248        foo = Foo()
5249
5250        d = {"a": 2, "b": 3, "c": foo}
5251
5252        opt_fn = torch.compile(fn, backend="eager")
5253        inp = torch.randn(3, 3)
5254        self.assertEqual(fn(d, inp), opt_fn(d, inp))
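        # The mutations below exercise two cases: changing a dict value directly and
        # changing an attribute of an object stored in the dict; in both cases the
        # compiled result must match eager, recompiling where the guards require it.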
5255
5256        d["a"] = 4
5257        self.assertEqual(fn(d, inp), opt_fn(d, inp))
5258
5259        # Check that recompilation happens
5260        foo.scalar = 12
5261        self.assertEqual(fn(d, inp), opt_fn(d, inp))
5262
5263    def test_nonconst_issubclass(self):
5264        def fn(x):
5265            if issubclass(x.__class__, np.ndarray):
5266                return 1
5267            return 0
5268
5269        opt_fn = torch.compile(fn, backend="eager")
5270        opt_fn(np.ones([3, 3]))
5271
5272    def test_issue126128(self):
5273        def fn():
5274            x = torch.randn(1, 10)
5275            y = torch.randn(10, 1)
5276            return torch.mm(x, y).sum()
5277
5278        def fn2():
5279            x = torch.randn(10, 100)
5280            y = torch.randn(100, 10)
5281            return torch.mm(x, y).sum()
5282
5283        with fresh_inductor_cache():
5284            torch.compile(fn)()
5285
5286        torch.compile(fn2)()
5287
5288    def test_jit_script_defaults(self):
5289        @torch.jit.script
5290        def fast_cos(x, c: float = 2.0):
5291            return torch.cos(x) * c
5292
5293        class Mod(torch.nn.Module):
5294            def __init__(self) -> None:
5295                super().__init__()
5296                self.fast_cos = fast_cos
5297
5298            def forward(self, x):
5299                return self.fast_cos(x)
5300
5301        mod = Mod()
5302        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
5303        x = torch.randn(4)
5304        self.assertEqual(mod(x), opt_mod(x))
5305
5306    def test_enum(self):
5307        class ExplicitEnum(str, Enum):
5308            @classmethod
5309            def _missing_(cls, value):
5310                raise ValueError(
5311                    f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
5312                )
5313
5314        class PaddingStrategy(ExplicitEnum):
5315            LONGEST = "longest"
5316            MAX_LENGTH = "max_length"
5317            DO_NOT_PAD = "do_not_pad"
5318
5319        def fn(x):
5320            a = PaddingStrategy("longest")
5321            if a == PaddingStrategy.LONGEST:
5322                return torch.sin(x)
5323            return torch.cos(x)
5324
5325        x = torch.randn(3, 3)
5326        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
5327        self.assertEqual(fn(x), opt_fn(x))
5328
5329    def test_hasattr_builtin(self):
5330        class MyClass:
5331            foo: int = 1
5332
5333        def func(x, m):
5334            if getattr(type(m), "foo", 0):
5335                return x + MyClass.foo
5336            return x
5337
5338        opt_func = torch.compile(func, backend="eager", fullgraph=True)
5339        m = MyClass()
5340        x = torch.zeros(())
5341        self.assertEqual(func(x, m), opt_func(x, m))
5342        self.assertEqual(func(x, 0), opt_func(x, 0))
5343
5344    def test_grad(self):
5345        def fn(x, y):
5346            x._grad = y
5347            return x.grad.data
5348
5349        x = torch.randn(4, requires_grad=True)
5350        y = torch.randn(4)
5351        opt_fn = torch.compile(fn, backend="eager")
5352        self.assertEqual(fn(x, y), opt_fn(x, y))
5353
5354    def test_nn_module_stack_bc(self):
5355        from torch._dynamo.mutation_guard import GenerationTracker
5356
5357        def compiler(gm, *args):
5358            module_stacks = [
5359                node.meta.get("nn_module_stack", None) for node in gm.graph.nodes
5360            ]
5361            module_stacks, _ = pytree.tree_flatten(module_stacks)
5362            module_stacks = [x for x in module_stacks if isinstance(x, str)]
5363            for stack in module_stacks:
5364                self.assertTrue("_module" not in stack)
5365            return gm.forward
5366
5367        class SubMod(torch.nn.Module):
5368            def __init__(self) -> None:
5369                super().__init__()
5370                self.linear = torch.nn.Linear(2, 2)
5371
5372            def forward(self, x):
5373                return self.linear(x)
5374
5375        class Mod(torch.nn.Module):
5376            def __init__(self) -> None:
5377                super().__init__()
5378                self.submod1 = SubMod()
5379                self.submod2 = SubMod()
5380
5381            def forward(self, x):
5382                return self.submod1(x) + self.submod2(x)
5383
5384        mod = Mod()
5385        opt_mod = torch.compile(mod, backend=compiler)
5386        opt_mod(torch.randn(2, 2))
5387
5388        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
5389            mod = Mod()
5390            opt_mod = torch.compile(mod, backend=compiler)
5391            opt_mod(torch.randn(2, 2))
5392
5393        # an example similar to the Pippy use case
5394        mod = Mod()
5395        GenerationTracker.tag(mod.submod1)
5396        GenerationTracker.mark_class_dynamic(type(mod.submod1))
5397        mod = Mod()
5398        opt_mod = torch.compile(mod, backend=compiler)
5399        opt_mod(torch.randn(2, 2))
5400
5401    def test_is_make_fx_tracing(self):
5402        @torch.compile(backend="eager", fullgraph=True)
5403        def fn(x):
5404            torch.nn.modules.activation._is_make_fx_tracing()
5405            return torch.sin(x)
5406
5407        fn(torch.rand(4))
5408
5409    def test_negative_floor_div_solve(self):
5410        class CompiledClass(nn.Module):
5411            def __init__(self) -> None:
5412                super().__init__()
5413                self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
5414                self.t = 5
5415
5416            def forward(self):
5417                self.num = self.nums[self.t // 12]
5418                self.t += 1
5419                return self.num
5420
5421        m = CompiledClass()
5422        m = torch.compile(m, backend="eager")
5423
5424        # the first call works
5425        m()
5426        # the second call used to cause a failure
5427        m()
5428
5429    # https://github.com/pytorch/pytorch/issues/121621
5430    def test_tensor_random(self):
5431        def random_op(tensor, params):
5432            res = tensor.random_(**params)
5433            return res
5434
5435        random_op = torch.compile(random_op)
5436        params = {"from": -10, "to": 10}
5437        tensor = torch.randn([2, 3])
5438        res = random_op(tensor, params)
5439
5440    # https://github.com/pytorch/pytorch/issues/131019
5441    def test_tensor_uniform(self):
5442        def uniform_op(tensor, params):
5443            res = tensor.uniform_(**params)
5444            return res
5445
5446        uniform_op = torch.compile(uniform_op)
5447        params = {"from": -10, "to": 10}
5448        tensor = torch.randn([2, 3])
5449        res = uniform_op(tensor, params)
5450
5451    def test_data_attr_mutation_after_saved_for_bw(self):
5452        def f(x):
5453            out = x.sin()
5454            x.data.mul_(2)
5455            return out
5456
5457        x = torch.randn(4, requires_grad=True)
5458        x_test = x.clone().detach().requires_grad_(True)
5459
5460        out = f(x)
5461        out_test = torch.compile(f, backend="aot_eager")(x_test)
5462        self.assertEqual(out, out_test)
5463
5464        out.sum().backward()
5465        out_test.sum().backward()
5466        self.assertEqual(x.grad, x_test.grad)
5467
5468    # https://github.com/pytorch/pytorch/issues/128072
5469    def test_map_with_multiple_args(self):
5470        def f(a, b):
5471            return a[0] * b[0] + a[1] * b[1]
5472
5473        def gen_inps(len_x, len_y):
5474            x = [torch.randn(5) for _ in range(len_x)]
5475            y = [torch.randn(5) for _ in range(len_y)]
5476            return x, y
5477
5478        def g(x, y):
5479            return map(f, x, y)
5480
5481        opt_g = torch.compile(g, fullgraph=True, backend="eager")
5482
5483        inps = gen_inps(3, 3)
5484        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
5485        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
5486
5487        inps = gen_inps(3, 5)
5488        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
5489        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
5490
5491    def test_staticmethod_allow_in_graph(self):
5492        class MyClass:
5493            i = 3
5494
5495            @staticmethod
5496            def foo_inner(x):
5497                return torch.mul(x, MyClass.i)
5498
5499            # if dynamo inlined this with fullgraph, the graph_break() would error;
5500            # verify that dynamo doesn't inline it
5501            @staticmethod
5502            @torch._dynamo.allow_in_graph
5503            def foo1(x):
5504                torch._dynamo.graph_break()
5505                return MyClass.foo_inner(x)
5506
5507        @torch.compile(backend="eager", fullgraph=True)
5508        def f_bad(x):
5509            return MyClass.foo1(x)
5510
5511        f_bad(torch.ones(2, 2))
5512
5513    def test_guard_with_tuple_mutation(self):
5514        class Foo:
5515            def __init__(self) -> None:
5516                self.x = 10
5517
5518        foo = Foo()
5519        d = {
5520            "a": 2,
5521            "b": (foo,),
5522        }
5523
5524        def fn(x, d):
5525            return x * d["a"] * d["b"][0].x
5526
5527        opt_fn = torch.compile(fn, backend="eager")
5528        inp = torch.randn(3, 3)
5529        self.assertEqual(fn(inp, d), opt_fn(inp, d))
5530        d["b"][0].x = 12
5531        self.assertEqual(fn(inp, d), opt_fn(inp, d))
5532
5533    def test_compile_complex_conj(self):
5534        def f(x):
5535            return torch.mul(x, 2j)
5536
5537        x_ref = torch.randn(4, 2, requires_grad=True)
5538        x_test = x_ref.clone().detach().requires_grad_(True)
5539
5540        out_ref = f(torch.view_as_complex(x_ref))
5541        out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
5542        self.assertEqual(out_ref, out_test)
5543
5544        torch.view_as_real(out_ref).sum().backward()
5545        torch.view_as_real(out_test).sum().backward()
5546        self.assertEqual(x_ref.grad, x_test.grad)
5547
5548    # https://github.com/pytorch/pytorch/issues/132200
5549    def test_partitioner_cse_respects_mutation_boundaries(self):
5550        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
5551        if not set_available:
5552            return
5553
5554        @torch.compile(backend="aot_eager_decomp_partition")
5555        def f(x, l):
5556            # z0 and z1 can be CSEd
5557            z0 = x.sin()
5558            z1 = x.sin()
5559            y = x + 1
5560            torch.ops.fsdp.set_.default(x, y)
5561            # z2 and z3 can be CSEd with each other,
5562            # but *not* with z0/z1 (they cross a mutation boundary)
5563            z2 = x.sin()
5564            z3 = x.sin()
5565            return z0, z1, z2, z3, l**2
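        # torch.ops.fsdp.set_ makes x take on y's storage in place, so the sin() calls
        # after it observe x + 1; that is why z2/z3 are compared against
        # (x_clone + 1).sin() below.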
5566
5567        x = torch.randn(3)
5568        x_clone = x.clone()
5569        l = torch.randn(3, requires_grad=True)
5570        z0, z1, z2, z3, _ = f(x, l)
5571
5572        # The partitioner runs CSE. Of the 4 sin() ops above, we expect that:
5573        # - the first 2 are CSE'd with each other
5574        # - the last 2 are CSE'd with each other
5575        # - the set_() op in the middle acts as a mutation barrier, preventing CSE across it
5576        self.assertEqual(z0, (x_clone).sin())
5577        self.assertEqual(z1, (x_clone).sin())
5578        self.assertEqual(z2, (x_clone + 1).sin())
5579        self.assertEqual(z3, (x_clone + 1).sin())
5580
5581    # https://github.com/pytorch/pytorch/issues/132197
5582    def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
5583        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
5584        if not set_available:
5585            return
5586
5587        @torch.compile(backend="aot_eager_decomp_partition")
5588        def f(x, l):
5589            z = x.sin()
5590            y = x + 1
5591            # graph input has its storage mutated
5592            torch.ops.fsdp.set_.default(x, y)
5593            z2 = x.sin()
5594            return z2, l**2
5595
5596        x = torch.randn(3)
5597        x_test = x.clone()
5598        l = torch.randn(3, requires_grad=True)
5599        result, _ = f(x, l)
5600        result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
5601            x_test, l
5602        )
5603
5604        self.assertEqual(result, result_test)
5605        self.assertEqual(x, x_test)
5606
5607    def test_changing_stride(self):
5608        cnt = torch._dynamo.testing.CompileCounter()
5609
5610        @torch.compile(backend=cnt)
5611        def fn(x, y):
5612            return x * y
5613
5614        for i in range(1, 4):
5615            x = torch.randn(4, i)
5616
5617            # create a view for i > 1
5618            if i == 1:
5619                x1 = x
5620            else:
5621                x1 = x[:, 0:1]
5622
5623            y = torch.randn(4, 1)
5624            print(x1.shape, y.shape)
5625            fn(x1, y)
5626
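        # Only the first shape/stride change is expected to force a recompile
        # (presumably by switching to dynamic shapes), so the frame count should stay
        # at 2 or below across the three calls.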
5627        self.assertTrue(cnt.frame_count <= 2)
5628
5629    @torch._dynamo.config.patch(guard_nn_modules=False)
5630    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
5631    def test_inlining_cornercase(self):
5632        """
5633        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the
5634        tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.
5635
5636        But there is a cornercase. Suppose you have NNModuleVariable with a submodule that is
5637        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
5638        guard.source().is_nn_module()). In retrospect, this is a mistake but there are dependencies of export and also
5639        cudagraphs which make it harder to fix the corner case right away. The long term solution is
5640        inline_inbuilt_nn_modules anyways, so we might have to live with this cornercase in the short term.
5641
5642        We are starting to annotate the source of each nn module more precisely - NNModuleVariable attribute is marked
5643        as NNModuleSource, UnspecilaizedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this
5644        changes the behavior for the cornercase. And fails some tests which have unfortunately relied on this behavior.
5645
5646
5647        To solve this, we tag the source only when inline_inbuilt_nn_module flag is turned on.
5648
5649        In this test, we purposely turn the flag off, testing that the tagging is disabled.
5650        """
5651
5652        class SubMod(torch.nn.Module):
5653            def __init__(self):
5654                super().__init__()
5655                self.linear = torch.nn.Linear(1, 1)
5656                self.a = torch.randn(1, 1)
5657                self.counter = 0
5658                self.multipliers = [2.2, 3.3]
5659
5660            def forward(self, x):
5661                self.counter += 1
5662                return (
5663                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
5664                )
5665
5666        class Mod(torch.nn.Module):
5667            def __init__(self):
5668                super().__init__()
5669                self.submod = SubMod()
5670
5671            def forward(self, x):
5672                return self.submod(x)
5673
5674        mod = Mod()
5675        opt_mod = torch.compile(mod, backend="eager")
5676
5677        x = torch.randn(1, 1)
5678        ref = mod(x)
5679        res = opt_mod(x)
5680
5681        mod.submod.multipliers = [3.3, 4.4]
5682        # Since guard_nn_modules is False, this will not recompile
5683        with torch._dynamo.config.patch(error_on_recompile=True):
5684            ref = mod(x)
5685            res = opt_mod(x)
5686
5687    def test_optimized_module_training(self):
5688        mod = torch.nn.Linear(3, 3)
5689        mod.eval()
5690
5691        opt_mod = torch.compile(mod, backend="eager")
5692        self.assertFalse(opt_mod.training)
5693
5694        opt_mod.train()
5695        self.assertTrue(opt_mod.training)
5696        self.assertTrue(mod.training)
5697
5698        mod.eval()
5699        self.assertFalse(opt_mod.training)
5700
5701    @requires_cuda
5702    def test_memleak_when_graph_input_has_tensor_attr(self):
5703        @torch.compile(backend="eager")
5704        def f(x):
5705            x.add_(1)
5706
5707        mem_before = torch.cuda.memory_allocated()
5708
5709        x = torch.ones(2, device="cuda")
5710        x.foo = torch.zeros(2, device="cuda")
5711        f(x)
5712        del x.foo
5713        del x
5714        mem_after = torch.cuda.memory_allocated()
5715        self.assertEqual(mem_before, mem_after)
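        # If Dynamo (or its guards) kept a strong reference to x or to its tensor-valued
        # attribute past this point, the CUDA allocations above would still be live and
        # the equality check would fail.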
5716
5717        # check when non-tensor data structure attribute contains a tensor
5718        @torch.compile(backend="eager")
5719        def f(x):
5720            x.add_(1)
5721
5722        mem_before = torch.cuda.memory_allocated()
5723        x = torch.ones(2, device="cuda")
5724        x.foo = [torch.zeros(2, device="cuda") for _ in range(5)]
5725        f(x)
5726        del x.foo
5727        del x
5728        mem_after = torch.cuda.memory_allocated()
5729        self.assertEqual(mem_before, mem_after)
5730
5731        # check with tensor refcycle
5732        @torch.compile(backend="eager")
5733        def g(x, y):
5734            return x + y
5735
5736        mem_before = torch.cuda.memory_allocated()
5737        x = torch.ones(2, device="cuda")
5738        y = torch.zeros(2, device="cuda")
5739        x.foo = [y]
5740        y.foo = [x]
5741        g(x, y)
5742        del x.foo
5743        del y.foo
5744        del x
5745        del y
5746        mem_after = torch.cuda.memory_allocated()
5747        self.assertEqual(mem_before, mem_after)
5748
5749    def test_os_fspath(self):
5750        @torch.compile(backend="eager", fullgraph=True)
5751        def fn(x):
5752            os.fspath(".")
5753            return torch.sin(x)
5754
5755        fn(torch.randn(4))
5756
5757    @requires_cuda
5758    # This test will fail because flip, in combination with particular input lengths,
5759    # produces incorrect results.
5760    # This is under investigation in
5761    # https://github.com/pytorch/pytorch/issues/131805
5762    @unittest.skip("Skip this flip test for the moment. It is under investigation")
5763    def test_flip_bad_accuracy(self):
5764        import torch
5765        import torch._dynamo.config
5766        import torch._functorch.config
5767        import torch._inductor.config
5768        import torch._inductor.inductor_prims
5769        import torch.fx.experimental._config
5770
5771        class Repro(torch.nn.Module):
5772            def __init__(self):
5773                super().__init__()
5774
5775            def forward(self, arg0_1):
5776                rev = torch.ops.prims.rev.default(arg0_1, [0])
5777                arg0_1 = None
5778                slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2)
5779                slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2)
5780                add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2)
5781                slice_1 = slice_2 = None
5782                slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2)
5783                slice_4 = torch.ops.aten.slice.Tensor(
5784                    add_1, 0, 1, 9223372036854775807, 2
5785                )
5786                add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4)
5787                slice_3 = slice_4 = None
5788                slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2)
5789                slice_6 = torch.ops.aten.slice.Tensor(
5790                    add_2, 0, 1, 9223372036854775807, 2
5791                )
5792                add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6)
5793                slice_5 = slice_6 = None
5794                slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1)
5795                add_2 = None
5796                unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1)
5797                slice_9 = None
5798                unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1)
5799                add_3 = None
5800                cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
5801                unsqueeze = unsqueeze_1 = None
5802                view = torch.ops.aten.view.default(cat, [2])
5803                cat = None
5804                slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1)
5805                slice_11 = torch.ops.aten.slice.Tensor(
5806                    add_1, 0, 2, 9223372036854775807, 2
5807                )
5808                add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11)
5809                slice_10 = slice_11 = None
5810                slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1)
5811                add_1 = None
5812                cat_1 = torch.ops.aten.cat.default([slice_12, add_5])
5813                slice_12 = add_5 = None
5814                unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1)
5815                cat_1 = None
5816                unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1)
5817                view = None
5818                cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1)
5819                unsqueeze_2 = unsqueeze_3 = None
5820                view_1 = torch.ops.aten.view.default(cat_2, [4])
5821                cat_2 = None
5822                slice_13 = torch.ops.aten.slice.Tensor(
5823                    rev, 0, 2, 9223372036854775807, 2
5824                )
5825                add_6 = torch.ops.aten.add.Tensor(view_1, slice_13)
5826                slice_13 = None
5827                slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1)
5828                rev = None
5829                cat_3 = torch.ops.aten.cat.default([slice_14, add_6])
5830                slice_14 = add_6 = None
5831                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
5832                    view_1, [0, 1], 0.0
5833                )
5834                view_1 = None
5835                unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1)
5836                cat_3 = None
5837                unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1)
5838                constant_pad_nd = None
5839                cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1)
5840                unsqueeze_4 = unsqueeze_5 = None
5841                view_2 = torch.ops.aten.view.default(cat_4, [10])
5842                cat_4 = None
5843                slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9)
5844                view_2 = None
5845                rev_1 = torch.ops.prims.rev.default(slice_15, [0])
5846                slice_15 = None
5847                return (rev_1,)
5848
5849        mod = Repro()
5850        x = torch.arange(9, device=torch.device("cuda"))
5851
5852        @torch.compile
5853        def f(x):
5854            return mod(x)
5855
5856        out = f(x)
5857        self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0])
5858
5859    # https://github.com/pytorch/pytorch/issues/88813
5860    def test_return_value_duplication_tensor(self) -> None:
5861        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5862            return val * 2, val * 2
5863
5864        x = torch.randn(2, requires_grad=True)
5865
5866        expect = fn(x)
5867        self.assertNotEqual(
5868            expect[0].untyped_storage().data_ptr(),
5869            expect[1].untyped_storage().data_ptr(),
5870        )
5871
5872        actual = torch.compile(fn, backend="aot_eager")(x)
5873        self.assertNotEqual(
5874            actual[0].untyped_storage().data_ptr(),
5875            actual[1].untyped_storage().data_ptr(),
5876        )
5877
5878    # https://github.com/pytorch/pytorch/issues/114344
5879    def test_return_value_duplication_mixed_grad(self) -> None:
5880        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5881            with torch.no_grad():
5882                out0 = val + 1
5883            out1 = val + 1
5884            return out0, out1
5885
5886        x = torch.randn(2, requires_grad=True)
5887
5888        with torch.enable_grad():
5889            expect = fn(x)
5890            actual = torch.compile(fn, backend="aot_eager")(x)
5891
5892            self.assertEqual(expect[0].requires_grad, actual[0].requires_grad)
5893            self.assertEqual(expect[1].requires_grad, actual[1].requires_grad)
5894
5895    # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
5896    def test_return_value_duplication_scalar(self) -> None:
5897        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5898            x, y = val * 2, val * 2
5899            return x[0], y[0]
5900
5901        x = torch.randn(2, requires_grad=True)
5902
5903        expect = fn(x)
5904        self.assertNotEqual(
5905            expect[0].untyped_storage().data_ptr(),
5906            expect[1].untyped_storage().data_ptr(),
5907        )
5908
5909        actual = torch.compile(fn, backend="aot_eager")(x)
5910        self.assertNotEqual(
5911            actual[0].untyped_storage().data_ptr(),
5912            actual[1].untyped_storage().data_ptr(),
5913        )
5914
5915
5916instantiate_parametrized_tests(ReproTests)
5917
5918
5919if __name__ == "__main__":
5920    from torch._dynamo.test_case import run_tests
5921
5922    run_tests()
5923