"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg)
"""

# Owner(s): ["module: dynamo"]
import collections
import contextlib
import copy
import dataclasses
import functools
import gc
import inspect
import itertools
import os
import random
import unittest
import warnings
import weakref
from abc import ABC
from collections import namedtuple
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Dict, Iterator, List, Tuple
from unittest import mock

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
import torch._functorch.config
import torch.library
import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import CompileCounter, rand_strided, same
from torch._inductor.utils import fresh_inductor_cache
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
    disable_translation_validation_if_dynamic_shapes,
    instantiate_parametrized_tests,
    parametrize,
    skipIfWindows,
    TEST_WITH_ROCM,
)
from torch.testing._internal.two_tensor import TwoTensor


_orig_module_call = torch.nn.Module.__call__

# Custom operator that only supports CPU and Meta
lib = torch.library.Library("test_sample", "DEF")  # noqa: TOR901
lib.define("foo(Tensor self) -> Tensor")
lib.impl("foo", torch.sin, "CPU")
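
# torch.sin also works on meta tensors, so the same function can serve as the
# Meta kernel mentioned above (a minimal sketch; this registration is
# illustrative):
lib.impl("foo", torch.sin, "Meta")

# Once registered, the op dispatches through the usual namespace, e.g.:
#   torch.ops.test_sample.foo(torch.randn(3))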


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


_GLOBAL_CPU_TENSOR = torch.randn(3)


def exists(val):
    return val is not None


def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)

    return inner
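
# Quick illustration of `maybe` (the name `add_one` is hypothetical, not part
# of the surrounding tests): the wrapper propagates None and otherwise applies
# the wrapped function.
#
#   add_one = maybe(lambda x: x + 1)
#   add_one(None)  # -> None
#   add_one(2)     # -> 3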


def is_fx_tracing_test() -> bool:
    """
    Copied from the hpc trainer codebase
    """
    return torch.nn.Module.__call__ is not _orig_module_call


def has_detectron2():
    try:
        from detectron2.layers.mask_ops import _paste_masks_tensor_shape

        return _paste_masks_tensor_shape is not None
    except ImportError:
        return False


def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    # from detectron2 mask_ops.py

    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
            dtype=torch.int32
        )
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
            dtype=torch.int32
        )
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()

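# Shape sketch for _do_paste_mask (shapes assumed from detectron2's usage):
# masks is (N, 1, H_mask, W_mask) and boxes is (N, 4) as (x1, y1, x2, y2) in
# image coordinates.
#
#   masks = torch.rand(2, 1, 28, 28)
#   boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])
#   pasted, region = _do_paste_mask(masks, boxes, img_h=32, img_w=32)
#   # pasted: (2, region_h, region_w); region: (y_slice, x_slice)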

def global_fn(x):
    return torch.sin(x)


def cat(tensors, dim=0):
    # from detectron2 wrappers.py
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def shapes_to_tensor(x, device=None):
    # from detectron2 wrappers.py
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            isinstance(t, torch.Tensor) for t in x
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)


fw_graph = [None]
bw_graph = [None]


def aot_graph_capture_backend(gm, args):
    from functorch.compile import min_cut_rematerialization_partition
    from torch._functorch.aot_autograd import aot_module_simplified

    def fw_compiler(gm, _):
        fw_graph[0] = gm
        return gm

    def bw_compiler(gm, _):
        bw_graph[0] = gm
        return gm

    return aot_module_simplified(
        gm,
        args,
        fw_compiler,
        bw_compiler,
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )
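
# Illustrative use of the graph-capture backend: compile with it, run
# forward/backward, then inspect the graphs it stashed in fw_graph[0] and
# bw_graph[0].
#
#   fn = torch.compile(global_fn, backend=aot_graph_capture_backend)
#   fn(torch.randn(3, requires_grad=True)).sum().backward()
#   fw_graph[0].print_readable()  # forward aten graph
#   bw_graph[0].print_readable()  # backward aten graph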


class Boxes:
    # from detectron2 poolers.py
    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix.  Each row is (x1, y1, x2, y2).
        """
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
        self.tensor = tensor

    def __len__(self) -> int:
        return self.tensor.shape[0]

    @property
    def device(self):
        return self.tensor.device


def convert_boxes_to_pooler_format(box_lists):
    # from detectron2 structures.py
    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
    # __len__ returns Tensor in tracing.
    sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
    indices = torch.repeat_interleave(
        torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
    )
    return cat([indices[:, None], boxes], dim=1)
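
# Example (shapes assumed): the pooler format is an (M, 5) tensor whose first
# column is the index of the source box list and whose remaining columns are
# the box coordinates.
#
#   pooled = convert_boxes_to_pooler_format(
#       [Boxes(torch.rand(2, 4)), Boxes(torch.rand(3, 4))]
#   )
#   # pooled.shape == (5, 5); pooled[:, 0] == [0, 0, 1, 1, 1]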


ReformerBackwardOutput = namedtuple(
    "ReformerBackwardOutput",
    ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
)
ReformerEncoderOutput = namedtuple(
    "ReformerEncoderOutput",
    ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)


class _ReversibleFunction(torch.autograd.Function):
    # taken from modeling_reformer.py in huggingface
    @staticmethod
    def forward(
        ctx,
        hidden_states,
        layers,
        attention_mask,
        head_mask,
        num_hashes,
        all_hidden_states,
        all_attentions,
        past_buckets_states,
        use_cache,
        orig_sequence_length,
        output_hidden_states,
        output_attentions,
    ):
        all_buckets = ()

        # split duplicated tensor
        hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)

        for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
            if output_hidden_states is True:
                all_hidden_states.append(hidden_states)

            attn_output = layer(attn_output)
            all_buckets = all_buckets + (attn_output,)

        # Add last layer
        if output_hidden_states is True:
            all_hidden_states.append(hidden_states)

        # attach params to ctx for backward
        ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
        ctx.layers = layers
        ctx.all_buckets = all_buckets
        ctx.head_mask = head_mask
        ctx.attention_mask = attention_mask

        # Concatenate 2 RevNet outputs
        return torch.cat([attn_output, hidden_states], dim=-1)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        grad_attn_output, grad_hidden_states = torch.chunk(
            grad_hidden_states, 2, dim=-1
        )

        # free memory
        del grad_attn_output

        # num of return vars has to match num of forward() args
        # return gradient for hidden_states arg and None for other args
        return (
            grad_hidden_states,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class ReformerEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dropout = 0.5
        self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
        self.layers = [torch.nn.Linear(256, 256)]

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=[None] * 6,
        num_hashes=None,
        use_cache=False,
        orig_sequence_length=64,
        output_hidden_states=False,
        output_attentions=False,
    ):
        # hidden_states and attention lists to be filled if wished
        all_hidden_states = []
        all_attentions = []
        past_buckets_states = [(None, None) for _ in range(len(self.layers))]

        # concat same tensor for reversible ResNet
        hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
        hidden_states = _ReversibleFunction.apply(
            hidden_states,
            self.layers,
            attention_mask,
            head_mask,
            num_hashes,
            all_hidden_states,
            all_attentions,
            past_buckets_states,
            use_cache,
            orig_sequence_length,
            output_hidden_states,
            output_attentions,
        )

        # Apply layer norm to concatenated hidden states
        hidden_states = self.layer_norm(hidden_states)

        # Apply dropout
        hidden_states = torch.nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        return ReformerEncoderOutput(
            hidden_states=hidden_states,
            all_hidden_states=all_hidden_states,
            all_attentions=all_attentions,
            past_buckets_states=past_buckets_states,
        )
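
# Illustrative forward pass (input size assumed to match the 256-wide layers
# and the 512-wide LayerNorm):
#
#   enc = ReformerEncoder()
#   out = enc(torch.randn(1, 64, 256))
#   # out.hidden_states: (1, 64, 512) after the duplicate-concat + layer_norm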


class ListConfig:
    class ValueNode:
        def __init__(self, value):
            self.value = value

        def _dereference_node(self):
            return self

        def _is_missing(self):
            return False

        def _value(self):
            return self.value

    # Based on an example from omegaconf.listconfig
    class ListIterator(Iterator[Any]):
        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            self.index = 0

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise AssertionError

            self.index = self.index + 1
            if isinstance(x, ListConfig.ValueNode):
                return x._value()
            raise AssertionError

    def __iter__(self):
        return self._iter_ex(True)

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        try:
            return ListConfig.ListIterator(self, resolve)
        except Exception:
            raise AssertionError from None

    def __init__(self) -> None:
        self._content = [
            ListConfig.ValueNode(1),
            ListConfig.ValueNode(3),
            ListConfig.ValueNode(torch.tensor([7.0])),
        ]
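
# Iterating resolves each ValueNode to its payload:
#
#   list(ListConfig())  # -> [1, 3, tensor([7.])]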


def longformer_chunk(hidden_states, window_overlap=256):
    """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""

    # non-overlapping chunks of size = 2w
    hidden_states = hidden_states.view(
        hidden_states.size(0),
        hidden_states.size(1) // (window_overlap * 2),
        window_overlap * 2,
        hidden_states.size(2),
    )

    # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
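
# Shape example: for (B, S, D) input with S a multiple of 2 * window_overlap,
# the output is (B, S // window_overlap - 1, 2 * window_overlap, D), with
# consecutive chunks overlapping by window_overlap positions.
#
#   longformer_chunk(torch.randn(1, 1024, 8), 256).shape  # (1, 3, 512, 8)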


class PartialT5(torch.nn.Module):
    # Highly simplified T5Attention prefix
    def __init__(self) -> None:
        super().__init__()
        self.q = torch.nn.Linear(512, 512)
        self.k = torch.nn.Linear(512, 512)
        self.v = torch.nn.Linear(512, 512)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        query_length=None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, 8, 64).transpose(1, 2)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        # (truncated here)
        return scores, value_states
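
# Illustrative shapes (hidden size 512 split into 8 heads of 64):
#
#   t5 = PartialT5()
#   scores, values = t5(torch.randn(2, 16, 512))
#   # scores: (2, 8, 16, 16); values: (2, 8, 16, 64)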


class ChunkReformerFeedForward(torch.nn.Module):
    # simplified from HF modeling_reformer.py
    def __init__(self) -> None:
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
        self.dense = torch.nn.Linear(256, 256)
        self.output = torch.nn.Linear(256, 256)

    def forward(self, attention_output):
        return apply_chunking_to_forward(
            self.forward_chunk,
            attention_output + 1,
        )

    def forward_chunk(self, hidden_states):
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dense(hidden_states)
        return self.output(hidden_states)


def apply_chunking_to_forward(forward_fn, *input_tensors):
    # simplified from HF model_utils.py
    assert len(input_tensors) > 0
    tensor_shape = input_tensors[0].shape[1]
    assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError

    return forward_fn(*input_tensors)


def _validate_model_kwargs(fn, model_kwargs):
    # simplified from transformers.generation.utils._validate_model_kwargs
    unused_model_args = []
    model_args = set(inspect.signature(fn).parameters)
    for key, value in model_kwargs.items():
        if value is not None and key not in model_args:
            unused_model_args.append(key)
    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )


class FakeMamlInner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(784, 5)

    def forward(self, x, ignored=None, bn_training=False):
        return self.linear(x.view(x.shape[0], -1))


class PartialMaml(torch.nn.Module):
    # Highly simplified version of maml.meta.Meta.finetuning
    def __init__(self) -> None:
        super().__init__()
        self.net = FakeMamlInner()
        self.update_step_test = 10
        self.update_lr = 0.4

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        querysz = x_qry.size(0)

        corrects = [0 for _ in range(self.update_step_test + 1)]

        # in order not to ruin the state of running_mean/variance and bn_weight/bias,
        # we fine-tune the copied model instead of self.net
        net = deepcopy(self.net)

        # 1. run the i-th task and compute loss for k=0
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = [
            p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
        ]

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, fast_weights, bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        del net

        accs = torch.tensor(corrects) / querysz

        return accs
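
# Illustrative call (inputs flatten to the 784 features FakeMamlInner expects;
# labels in [0, 5)):
#
#   maml = PartialMaml()
#   accs = maml(
#       torch.randn(5, 784), torch.randint(5, (5,)),
#       torch.randn(5, 784), torch.randint(5, (5,)),
#   )
#   # accs has update_step_test + 1 slots; only the first two are filled here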


def softmax_backward_data(parent, grad_output, output, dim, self):
    from torch import _softmax_backward_data

    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)


class XSoftmax(torch.autograd.Function):
    # transformers.models.deberta.modeling_deberta.XSoftmax
    @staticmethod
    def forward(self, input, mask, dim):
        self.dim = dim
        rmask = ~(mask.to(torch.bool))
        output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
        output = torch.softmax(output, self.dim)
        output.masked_fill_(rmask, 0)
        self.save_for_backward(output, rmask)
        return output

    @staticmethod
    def backward(self, grad_output):
        (output, rmask) = self.saved_tensors
        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
        return inputGrad, None, None
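
# Illustrative use: masked positions receive zero probability and zero
# gradient.
#
#   x = torch.randn(2, 4, requires_grad=True)
#   mask = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 0]])
#   probs = XSoftmax.apply(x, mask, -1)
#   probs.sum().backward()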


class ModelOutput(collections.OrderedDict):
    """based on file_utils.py in HuggingFace"""

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyError if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self):
        return tuple(self[k] for k in self.keys())
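
# Keys double as attributes, and integer indexing goes through to_tuple():
#
#   out = ModelOutput()
#   out["loss"] = torch.tensor(1.0)
#   out.loss  # same tensor as out["loss"] and out[0]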


def create_rand_mask_from_inputs(
    from_blocked_mask,
    to_blocked_mask,
    rand_attn,
    num_attention_heads,
    num_rand_blocks,
    batch_size,
    from_seq_length,
    from_block_size,
):
    """taken from HF modeling_big_bird.py"""
    num_windows = from_seq_length // from_block_size - 2
    rand_mask = torch.stack(
        [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
    )
    rand_mask = rand_mask.view(
        batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
    )
    rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
    return rand_mask
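
# Shape sketch for the einsum: from_blocked_mask[:, 1:-1] is
# (batch, num_windows, from_block_size) and rand_mask is
# (batch, heads, num_windows, num_rand_blocks * from_block_size), so the
# result is (batch, heads, num_windows, from_block_size,
# num_rand_blocks * from_block_size).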


class SequentialAppendList(torch.nn.Sequential):
    """from timm/models/vovnet.py"""

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        for i, module in enumerate(self):
            if i == 0:
                concat_list.append(module(x))
            else:
                concat_list.append(module(concat_list[-1]))
        x = torch.cat(concat_list, dim=1)
        return x, concat_list


class BatchNormAct2d(torch.nn.BatchNorm2d):
    """Taken from timm"""

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        act_layer=torch.nn.ReLU,
        inplace=True,
    ):
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.act = act_layer(inplace=inplace)

    @torch.jit.ignore
    def _forward_python(self, x):
        return super().forward(x)

    def forward(self, x):
        if torch.jit.is_scripting():
            # scripted path; only reachable under TorchScript, so eager and
            # dynamo always take the Python branch below
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.act(x)
        return x
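
# Quick usage (channel count assumed):
#
#   bn = BatchNormAct2d(8)
#   y = bn(torch.randn(2, 8, 4, 4))  # batch norm followed by ReLU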
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Workerdef get_parameter_dtype(parameter):
745*da0073e9SAndroid Build Coastguard Worker    """from huggingface model_utils.py"""
746*da0073e9SAndroid Build Coastguard Worker    try:
747*da0073e9SAndroid Build Coastguard Worker        return next(parameter.parameters()).dtype
748*da0073e9SAndroid Build Coastguard Worker    except StopIteration:
749*da0073e9SAndroid Build Coastguard Worker        # For nn.DataParallel compatibility in PyTorch 1.5
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker        def find_tensor_attributes(module):
752*da0073e9SAndroid Build Coastguard Worker            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
753*da0073e9SAndroid Build Coastguard Worker            return tuples
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
756*da0073e9SAndroid Build Coastguard Worker        first_tuple = next(gen)
757*da0073e9SAndroid Build Coastguard Worker        return first_tuple[1].dtype
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker
class DummyConfig:
    attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
    lsh_attn_chunk_length = 64
    local_attn_chunk_length = 64


def _get_min_chunk_len(config):
    """from hf_Reformer"""
    attn_types = config.attn_layers
    attn_types_set = set(attn_types)
    if len(attn_types_set) == 1 and attn_types[0] == "lsh":
        return config.lsh_attn_chunk_length
    elif len(attn_types_set) == 1 and attn_types[0] == "local":
        return config.local_attn_chunk_length
    elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
        return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
    else:
        raise NotImplementedError(
            f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
            "attn layer types from ['lsh', 'local'] only."
        )


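# Illustrative sketch (ours, never invoked): DummyConfig mixes "lsh" and
# "local" layers, so _get_min_chunk_len takes the min of both chunk lengths,
# which is 64 here since both are 64.
def _example_get_min_chunk_len():
    assert _get_min_chunk_len(DummyConfig()) == 64

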
def _stable_argsort(vector, dim):
    """from hf_Reformer"""
    # this function scales the vector so that torch.argsort is stable.
    # torch.argsort is not stable on its own
    scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
    scale_offset = scale_offset.expand(vector.shape)
    scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
    return torch.argsort(scaled_vector, dim=dim)


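# Illustrative sketch (ours, never invoked): with an all-ties input, the
# scaled argsort keeps elements in their original order, which plain
# torch.argsort does not guarantee. Note the helper assumes a 3-d input, as
# used by hf_Reformer.
def _example_stable_argsort():
    ties = torch.zeros(1, 1, 6)
    idx = _stable_argsort(ties, dim=-1)
    assert idx.flatten().tolist() == [0, 1, 2, 3, 4, 5]

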
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
    """from hf_Reformer"""
    # no gradients are needed
    with torch.no_grad():
        # hash-based sort
        sorted_bucket_idx = _stable_argsort(buckets, dim=-1)

        # create simple indices to scatter to, to have undo sort
        indices = (
            torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
            .view(1, 1, -1)
            .expand(sorted_bucket_idx.shape)
        )

        # get undo sort
        undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
        undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)

    return sorted_bucket_idx, undo_sorted_bucket_idx


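# Illustrative sketch (ours, never invoked): gathering with sorted_bucket_idx
# permutes the buckets into sorted order, and gathering the result with
# undo_sorted_bucket_idx restores the original layout.
def _example_undo_sorted_bucket_idx():
    buckets = torch.tensor([[[2, 0, 1, 0]]])
    sorted_idx, undo_idx = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets)
    shuffled = torch.gather(buckets, -1, sorted_idx)
    assert torch.equal(torch.gather(shuffled, -1, undo_idx), buckets)

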
class CustomList1(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def clear(self):
        pass  # this prevents RestrictedListSubclassVariable from kicking in


class CustomList2(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def length_times_10(self):
        return len(self) * 10

    def append_twice(self, x):
        self.extend([x, x])


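# Illustrative sketch (ours, never invoked): both list subclasses act as a
# processor pipeline when called; CustomList2 additionally carries the extra
# methods Dynamo must model.
def _example_custom_lists():
    pipeline = CustomList2([lambda x: x + 1, lambda x: x * 2])
    assert pipeline(torch.ones(1)).item() == 4.0  # (1 + 1) * 2
    assert pipeline.length_times_10() == 20
    pipeline.append_twice(lambda x: x)
    assert len(pipeline) == 4

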
def _merge_criteria_processor_list(default_list, custom_list):
    # simplified transformers/generation/utils.py
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                raise ValueError
    default_list.extend(custom_list)
    return default_list


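# Illustrative sketch (ours, never invoked): custom entries are appended onto
# the default list in place; merging two entries of the same type would raise
# ValueError.
def _example_merge_criteria_processor_list():
    defaults = CustomList1([torch.nn.ReLU()])
    merged = _merge_criteria_processor_list(defaults, CustomList1([torch.nn.Tanh()]))
    assert merged is defaults and len(merged) == 2

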
class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = activation
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(
            self.linear2(self.dropout1(self.activation(self.linear1(x))))
        )


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(),
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout(x)

    # feed forward block
    def _ff_block(self, x):
        return self.ff_block(x)


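# Illustrative sketch (ours, never invoked): the encoder layer, and the
# FeedForwardLayer it wraps, are shape-preserving over
# (seq_len, batch, d_model) inputs (nn.MultiheadAttention defaults to
# batch_first=False).
def _example_transformer_encoder_layer():
    layer = TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32)
    out = layer(torch.randn(3, 2, 16))
    assert out.shape == (3, 2, 16)

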
class MockModule(torch.nn.Module):
    def inner_fn(self, left, right):
        return tuple(left) == tuple(right)

    def fn(self, tensor):
        if type(tensor) is int:
            return False

        torch.add(tensor, tensor)
        return self.inner_fn(tensor.shape, (1, 2, 3))


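# Illustrative sketch (ours, never invoked): fn short-circuits on ints and
# otherwise compares the tensor's shape against (1, 2, 3) via inner_fn.
def _example_mock_module_fn():
    m = MockModule()
    assert m.fn(1) is False
    assert m.fn(torch.ones(1, 2, 3)) is True

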
class IncByOne:
    def __init__(self, x):
        self.x = x + 1


class IncByTwo:
    def __init__(self, x):
        self.x = x + 2


class ReproTests(torch._dynamo.test_case.TestCase):
    def test_do_paste_mask(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 1,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 3,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 4,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            False,
        )
        # (dynamic shapes, static shapes)
        self.assertIn(cnt.frame_count, (5, 7))
        self.assertIn(cnt.op_count, (92, 106, 119))

    def test_convert_boxes_to_pooler_format(self):
        boxes1 = [
            Boxes(torch.arange(0, 8).reshape((2, 4))),
            Boxes(torch.arange(8, 16).reshape((2, 4))),
        ]
        boxes2 = [
            Boxes(torch.arange(16, 20).reshape((1, 4))),
            Boxes(torch.arange(20, 24).reshape((1, 4))),
        ]
        correct1 = convert_boxes_to_pooler_format(boxes1)
        correct2 = convert_boxes_to_pooler_format(boxes2)
        fn = convert_boxes_to_pooler_format
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), correct1))
        self.assertTrue(same(opt_fn(boxes2), correct2))

        # repeat_interleave is a dynamic shape operator we do not execute.
        # In the future, we could reduce the frame_count down to 1
        # by guarding on the exact values of the `Tensor repeats` arg.
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """10""")
        else:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """14""")

    def test_boxes_len(self):
        def fn(boxes):
            return len(boxes) + boxes.__len__() + boxes.tensor

        boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """1""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """2""")

    def _reformer(self, nopython):
        input = torch.randn([1, 64, 256])
        model = ReformerEncoder()
        torch.manual_seed(1337)
        correct = copy.deepcopy(model)(input)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.manual_seed(1337)
        opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model)
        self.assertTrue(same(opt_model(input), correct))
        return cnt

    @requires_cuda
    def test_sub_alpha_scalar_repro(self):
        @torch.compile(backend="aot_eager")
        def f(x):
            return x.sub(1, alpha=2)

        f(torch.ones(2, device="cuda", dtype=torch.float64))

    # https://github.com/pytorch/pytorch/issues/113010
    def test_out_overload_non_contiguous(self):
        def f(x, y):
            return torch.abs(x, out=y.T)

        f_compiled = torch.compile(f, backend="aot_eager")

        x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        out_ref = f(x_ref, y_ref)
        out_test = f_compiled(x_test, y_test)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(y_ref, y_test)

    # https://github.com/pytorch/pytorch/issues/109053
    def test_view_dtype_overload(self):
        def f(x):
            return x.view(torch.int32)

        f_compiled = torch.compile(f, backend="aot_eager")

        x1 = torch.ones(4, requires_grad=True)
        out_ref = f(x1)
        out_test = f_compiled(x1)
        self.assertEqual(out_ref, out_test)

        x2 = torch.ones(4, requires_grad=False)
        out_ref = f(x2)
        out_test = f_compiled(x2)
        self.assertEqual(out_ref, out_test)

    # https://github.com/pytorch/pytorch/issues/90552
    def test_intermediate_leaf_requires_grad(self):
        def f(x):
            leaf = torch.ones(2, requires_grad=True)
            return leaf, leaf * 2

        f_compiled = torch.compile(f, backend="aot_eager")
        x = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        leaf, out = f(x)
        leaf_test, out_test = f_compiled(x)
        out.sum().backward()
        out_test.sum().backward()
        self.assertEqual(leaf.grad, leaf_test.grad)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_dont_run_during_tracing(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        # unpack hook shouldn't run during compilation, while we trace the forward
        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            out_test = f_compiled(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
            out_test.sum().backward()
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 1)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_can_be_disabled(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
            out_test = f_compiled(x, y)
            out_test.sum().backward()

    # https://github.com/pytorch/pytorch/issues/113263
    def test_disabling_unpack_hooks_within_compiled_region(self):
        def g(z):
            with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
                return z + 5

        def f(x, y):
            z = x * y
            return g(z)

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        out_test = f_compiled(x, y)
        out_test.sum().backward()

    # See https://github.com/pytorch/pytorch/issues/97745
    def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self):
        def f(a, b):
            c = torch.ones(2, 2)
            d = torch.ones(2, 2)
            e = torch.matmul(a, c)
            g_loss = torch.abs(e - d).mean()
            g_loss.backward()
            fake_d_pred = torch.matmul(b, e.detach())
            d_loss = fake_d_pred.mean()
            d_loss.backward()

        a_ref = torch.randn(2, 2, requires_grad=True)
        b_ref = torch.randn(2, 2, requires_grad=True)
        out_ref = f(a_ref, b_ref)

        a_test = a_ref.clone().detach().requires_grad_(True)
        b_test = b_ref.clone().detach().requires_grad_(True)
        out_test = torch.compile(f, backend="aot_eager")(a_test, b_test)

        self.assertEqual(out_ref, out_test)
        self.assertEqual(a_ref.grad, a_test.grad)
        self.assertEqual(b_ref.grad, b_test.grad)

    # https://github.com/pytorch/pytorch/issues/111603
    def test_tuple_enum_as_key_dict(self):
        class MyEnum(Enum):
            A = "a"

        class SomeModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)

            def forward(self, x) -> torch.Tensor:
                return self.linear(x[MyEnum.A])

        x = {MyEnum.A: torch.rand(8, 1)}
        model_pytorch = SomeModel()
        model = torch.compile(model_pytorch)
        # Executing twice works
        model(x)
        y = model(x)
        self.assertEqual(y, model_pytorch(x))

    def test_embedding_backward_broadcasting_decomp(self):
        def f(grad_output, indices):
            num_weights = 10
            padding_idx = 1
            scale_grad_by_freq = True
            return torch.ops.aten.embedding_dense_backward(
                grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
            )

        f_compiled = torch.compile(f, backend="aot_eager")

        grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
        indices = torch.ones(2, 4, dtype=torch.int64)

        out_ref = f(grad_output, indices)
        out_test = f_compiled(grad_output, indices)

        self.assertEqual(out_ref, out_test)

    def test_reformer_eval(self):
        with torch.no_grad():
            cnt = self._reformer(nopython=True)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 11)

    def test_reformer_train(self):
        with torch.enable_grad():
            cnt = self._reformer(nopython=False)
        expected_op_count = (
            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
        )

        self.assertExpectedInline(cnt.frame_count, """1""")
        self.assertExpectedInline(cnt.op_count, expected_op_count)

    @disable_translation_validation_if_dynamic_shapes
    def test_longformer_chunk(self):
        input1 = torch.randn([1, 4096, 1])
        input2 = torch.randn([12, 4096, 64])
        correct1 = longformer_chunk(input1)
        correct2 = longformer_chunk(input2)
        fn = longformer_chunk
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))

        if torch._dynamo.config.assume_static_by_default:
            if torch._dynamo.config.automatic_dynamic_shapes:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """8""")
            else:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """4""")
        else:
            self.assertExpectedInline(cnt.frame_count, """2""")
            self.assertExpectedInline(cnt.op_count, """19""")

    def test_hf_t5_forward(self):
        input = torch.randn([1, 2048, 512])
        model = PartialT5()
        correct = model(input)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(input), correct))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")

    def test_module_in_skipfiles(self):
        model = nn.Linear(10, 10)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_function_in_skipfiles(self):
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_slicing_dynamic_shape(self):
        def fn(y):
            x = torch.ones(8)
            idx = y[0]
            out = x[idx:]
            return (out + 3) * 5

        counter = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counter)(fn)
        out = opt_fn(torch.ones(10, dtype=torch.long))
        # idx should be 1 -> slicing off [1:] of an 8-element tensor
1286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(out.shape), [7])
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.op_count, 2)
1289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker    def test_slicing_dynamic_shape_setitem(self):
1294*da0073e9SAndroid Build Coastguard Worker        def fn(input_lengths: torch.Tensor, new_ones_1):
1295*da0073e9SAndroid Build Coastguard Worker            getitem_13 = input_lengths[3]
1296*da0073e9SAndroid Build Coastguard Worker            new_ones_1[(3, slice(getitem_13, None, None))] = 0
1297*da0073e9SAndroid Build Coastguard Worker            setitem_13 = new_ones_1
1298*da0073e9SAndroid Build Coastguard Worker            return (setitem_13,)
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10).to(dtype=torch.int64)
1301*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 204)
1302*da0073e9SAndroid Build Coastguard Worker        ref = fn(x, y)
1303*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
1304*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x, y)
1305*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref, res))
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(error_on_recompile=True)
1308*da0073e9SAndroid Build Coastguard Worker    @torch.fx.experimental._config.patch(use_duck_shape=False)
1309*da0073e9SAndroid Build Coastguard Worker    def test_dynamic_shape_disable_duck_size(self):
1310*da0073e9SAndroid Build Coastguard Worker        class TestModel(nn.Module):
1311*da0073e9SAndroid Build Coastguard Worker            def __init__(
1312*da0073e9SAndroid Build Coastguard Worker                self,
1313*da0073e9SAndroid Build Coastguard Worker            ):
1314*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1315*da0073e9SAndroid Build Coastguard Worker
1316*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
1317*da0073e9SAndroid Build Coastguard Worker                return x + val
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker        main_model = TestModel().to(memory_format=torch.channels_last)
1320*da0073e9SAndroid Build Coastguard Worker        opt_model = torch.compile(main_model, backend="eager", dynamic=True)
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker        x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last)
1323*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last)
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        o1_ref = main_model(x1, 4)
1326*da0073e9SAndroid Build Coastguard Worker        o1 = opt_model(x1, 4)
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker        o2_ref = main_model(x2, 20)
1329*da0073e9SAndroid Build Coastguard Worker        o2 = opt_model(x2, 20)
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker    def test_chunk_reformer_ff(self):
1332*da0073e9SAndroid Build Coastguard Worker        input = torch.randn([1, 4096, 256])
1333*da0073e9SAndroid Build Coastguard Worker        model = ChunkReformerFeedForward()
1334*da0073e9SAndroid Build Coastguard Worker        correct = model(input)
1335*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1336*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize_assert(cnt)(model)
1337*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_model(input), correct))
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
1340*da0073e9SAndroid Build Coastguard Worker        self.assertLessEqual(cnt.op_count, 10)
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker    # see: https://github.com/pytorch/pytorch/issues/80067
1343*da0073e9SAndroid Build Coastguard Worker    # NB: When you remove the expectedFailure, don't forget to
1344*da0073e9SAndroid Build Coastguard Worker    # uncomment/adjust the assertEqual below
1345*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
1346*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(
1347*da0073e9SAndroid Build Coastguard Worker        fake_tensor_propagation=True, capture_scalar_outputs=True
1348*da0073e9SAndroid Build Coastguard Worker    )
1349*da0073e9SAndroid Build Coastguard Worker    def test_maml_item_capture(self):
1350*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 1, 28, 28)
1351*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(5, dtype=torch.int64)
1352*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(75, 1, 28, 28)
1353*da0073e9SAndroid Build Coastguard Worker        d = torch.zeros(75, dtype=torch.int64)
1354*da0073e9SAndroid Build Coastguard Worker        model = PartialMaml()
1355*da0073e9SAndroid Build Coastguard Worker        correct = model(a, b, c, d)
1356*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1357*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize(cnt)(model)
1358*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
1359*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_model(a, b, c, d), correct))
1360*da0073e9SAndroid Build Coastguard Worker
1361*da0073e9SAndroid Build Coastguard Worker        # if torch._dynamo.config.assume_static_by_default:
1362*da0073e9SAndroid Build Coastguard Worker        #     self.assertExpectedInline(cnt.frame_count, """2""")
1363*da0073e9SAndroid Build Coastguard Worker        # else:
1364*da0073e9SAndroid Build Coastguard Worker        #     self.assertExpectedInline(cnt.frame_count, """3""")
1365*da0073e9SAndroid Build Coastguard Worker        # TODO(jansel): figure out why op count depends on imports
1366*da0073e9SAndroid Build Coastguard Worker        self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker    # see: https://github.com/pytorch/pytorch/issues/80067
1369*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(capture_scalar_outputs=False)
1370*da0073e9SAndroid Build Coastguard Worker    def test_maml_no_item_capture(self):
1371*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 1, 28, 28)
1372*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(5, dtype=torch.int64)
1373*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(75, 1, 28, 28)
1374*da0073e9SAndroid Build Coastguard Worker        d = torch.zeros(75, dtype=torch.int64)
1375*da0073e9SAndroid Build Coastguard Worker        model = PartialMaml()
1376*da0073e9SAndroid Build Coastguard Worker        correct = model(a, b, c, d)
1377*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1378*da0073e9SAndroid Build Coastguard Worker        opt_model = torch._dynamo.optimize(cnt)(model)
1379*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
1380*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_model(a, b, c, d), correct))
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
1383*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """4""")
1384*da0073e9SAndroid Build Coastguard Worker        else:
1385*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """5""")
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker    def test_hf_model_output(self):
1388*da0073e9SAndroid Build Coastguard Worker        ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
1391*da0073e9SAndroid Build Coastguard Worker            return x["a"] + 1
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
1394*da0073e9SAndroid Build Coastguard Worker            return x.a + 1
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker        def fn3(x):
1397*da0073e9SAndroid Build Coastguard Worker            return x.to_tuple()[0] + 1
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker        def fn4(x):
1400*da0073e9SAndroid Build Coastguard Worker            return x[0] + 1
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1403*da0073e9SAndroid Build Coastguard Worker        for fn in (fn1, fn2, fn3, fn4):
1404*da0073e9SAndroid Build Coastguard Worker            cnt.clear()
1405*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
1406*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_fn(ex), ex.a + 1))
1407*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 1)
1408*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.op_count, 1)
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker    @disable_translation_validation_if_dynamic_shapes
1411*da0073e9SAndroid Build Coastguard Worker    def test_create_rand_mask_from_inputs(self):
1412*da0073e9SAndroid Build Coastguard Worker        args = [
1413*da0073e9SAndroid Build Coastguard Worker            torch.randn([1, 64, 64]),
1414*da0073e9SAndroid Build Coastguard Worker            torch.randn([1, 64, 64]),
1415*da0073e9SAndroid Build Coastguard Worker            torch.zeros([1, 12, 62, 3], dtype=torch.int64),
1416*da0073e9SAndroid Build Coastguard Worker            12,
1417*da0073e9SAndroid Build Coastguard Worker            3,
1418*da0073e9SAndroid Build Coastguard Worker            1,
1419*da0073e9SAndroid Build Coastguard Worker            4096,
1420*da0073e9SAndroid Build Coastguard Worker            64,
1421*da0073e9SAndroid Build Coastguard Worker        ]
1422*da0073e9SAndroid Build Coastguard Worker        correct = create_rand_mask_from_inputs(*args)
1423*da0073e9SAndroid Build Coastguard Worker        fn = create_rand_mask_from_inputs
1424*da0073e9SAndroid Build Coastguard Worker
1425*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1426*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
1427*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(opt_fn(*args), correct))
1428*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
1429*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """1""")
1430*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.op_count, """8""")
1431*da0073e9SAndroid Build Coastguard Worker        else:
1432*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.frame_count, """1""")
1433*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(cnt.op_count, """11""")
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker    def test_rng_state(self):
1436*da0073e9SAndroid Build Coastguard Worker        def fn():
1437*da0073e9SAndroid Build Coastguard Worker            state = torch.get_rng_state()
1438*da0073e9SAndroid Build Coastguard Worker            before = torch.rand(1000)
1439*da0073e9SAndroid Build Coastguard Worker            torch.set_rng_state(state)
1440*da0073e9SAndroid Build Coastguard Worker            after = torch.rand(1000)
1441*da0073e9SAndroid Build Coastguard Worker            return before, after
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1444*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
1445*da0073e9SAndroid Build Coastguard Worker
1446*da0073e9SAndroid Build Coastguard Worker        before, after = opt_fn()
1447*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(before, after))
1448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
1449*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 2)  # rand, rand
1450*da0073e9SAndroid Build Coastguard Worker        try:
1451*da0073e9SAndroid Build Coastguard Worker            graph, _ = torch._dynamo.export(fn)()
1452*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/pull/87490
1453*da0073e9SAndroid Build Coastguard Worker            self.fail("unexpected export success")
1454*da0073e9SAndroid Build Coastguard Worker        except torch._dynamo.exc.Unsupported:
1455*da0073e9SAndroid Build Coastguard Worker            pass
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    def test_threading_local(self):
1458*da0073e9SAndroid Build Coastguard Worker        import threading
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker        foo = threading.local()
1461*da0073e9SAndroid Build Coastguard Worker        foo.x = torch.rand(1)
1462*da0073e9SAndroid Build Coastguard Worker
1463*da0073e9SAndroid Build Coastguard Worker        def f(x):
1464*da0073e9SAndroid Build Coastguard Worker            return torch.cat([x, foo.x])
1465*da0073e9SAndroid Build Coastguard Worker
1466*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1467*da0073e9SAndroid Build Coastguard Worker        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(1)
1470*da0073e9SAndroid Build Coastguard Worker        out = f(inp)
1471*da0073e9SAndroid Build Coastguard Worker        opt_out = opt_f(inp)
1472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_out, out)
1473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)

    def test_seq_append_list(self):
        x = torch.randn(4, 10)
        model = SequentialAppendList(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )
        # this one is tricky because it mutates the list provided as an input
        l1 = [x]
        l2 = [x]
        correct, _ = model(x, l1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        result, l3 = opt_model(x, l2)
        self.assertTrue(same(result, correct))
        self.assertTrue(same(l1, l2))
        self.assertIs(l2, l3)
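        # The identity check matters: the compiled module must replay its
        # appends onto the caller's own list object (l2), not hand back a
        # reconstructed copy.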
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 5)

    def test_batch_norm_act(self):
        a = torch.randn(5, 1, 28, 28)
        model = BatchNormAct2d(1).eval()
        correct = model(a)
        cnt = torch._dynamo.testing.CompileCounter()
        if not torch._dynamo.config.specialize_int:
            # _local_scalar_dense causes a graph break with a 0-dim tensor
            opt_model = torch._dynamo.optimize(cnt)(model)
            self.assertTrue(same(opt_model(a), correct))
            return

        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(a), correct))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_get_parameter_dtype(self):
        model = SequentialAppendList(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

        def fn(model, x):
            return x + torch.randn(10, dtype=get_parameter_dtype(model))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)
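        # get_parameter_dtype is a helper defined earlier in this file
        # (modeled on the HuggingFace utility of the same name); it presumably
        # inspects model.parameters() to pick a dtype, which dynamo must
        # resolve at trace time to produce torch.float32 here.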

    def test_nn_parameter(self):
        def test_fn():
            a = torch.nn.Parameter(torch.randn(5, 5))
            # Checks that TensorVariable stores the type information correctly
            self.assertTrue(isinstance(a, torch.nn.Parameter))
            return a

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        out = opt_test_fn()
        self.assertTrue(isinstance(out, torch.nn.Parameter))

    def test_Size(self):
        def test_fn():
            a = torch.randn(4)
            x = torch.Size([1, 2, 3])
            # Checks that SizeVariable returns a torch.Size object
            assert isinstance(x, torch.Size)
            # Causes a graph break and checks reconstruction of the
            # SizeVariable object
            self.assertIsInstance(x, torch.Size)
            return a

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        opt_test_fn()

    # See https://github.com/pytorch/pytorch/issues/100067
    def test_copy_weird_strides(self):
        # This test requires inductor's copy() decomp to preserve strides properly.
        def test_fn(a):
            b = torch.zeros(48, 4, 256, 513)
            b[:, 0, 1:256, 1:256] = a
            c = b.view(4, 12, 1024, 513)
            d = c.transpose(2, 1)
            d.add_(1)
            return d

        sh, st, dt, dev, rg = (
            (48, 255, 255),
            (787968, 513, 1),
            torch.float16,
            "cpu",
            True,
        )
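        # rand_strided materializes a tensor with exactly this (shape, stride)
        # pair; the strides are deliberately non-contiguous, so the input
        # behaves like a view into a larger buffer, which is what stresses
        # stride preservation in the copy() decomposition.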
        a = rand_strided(sh, st, dt, dev).requires_grad_(rg)
        compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition")
        out1 = test_fn(a)
        out2 = compiled_f(a)
        self.assertEqual(out1, out2)

    def test_indexing_with_list(self):
        def test_fn():
            def run_test(tensor, *idx):
                npt = tensor.numpy()
                assert npt[idx].shape == tensor[idx].shape

            x = torch.arange(0, 10)
            cases = [
                [None, None],
                [1, None],
            ]

            for case in cases:
                run_test(x, *case)

            return torch.randn(4)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
        opt_test_fn()

    def test_reformer_min_chunk_len(self):
        def fn(cfg):
            t = torch.empty(10)
            t.fill_(_get_min_chunk_len(cfg))
            return t[0]

        cfg = DummyConfig()
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertEqual(opt_fn(cfg), 64)
        # With unspec int, the maximum computation inside _get_min_chunk_len
        # is preserved in the graph rather than constant-folded
        self.assertExpectedInline(cnt.frame_count, """1""")
        self.assertExpectedInline(cnt.op_count, """3""")
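        # DummyConfig and _get_min_chunk_len are fixtures defined earlier in
        # this file (modeled on HF Reformer's config handling); the op count
        # of 3 presumably corresponds to the empty(), the fill_(), and the
        # t[0] indexing.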

    def test_reformer_sorting(self):
        x = torch.zeros([1, 12, 4096], dtype=torch.int64)
        correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x)
        fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(x), correct))
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """14""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """16""")

    def test_recursive_map(self):
        # https://github.com/pytorch/torchdynamo/issues/132
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = v

        def toy_example(a, b, v):
            x = a / (torch.abs(a) + 1)
            if v is not None:
                _recursive_map(v)
            return x * b

        cnt = torch._dynamo.testing.CompileCounter()
        opt_toy_example = torch._dynamo.optimize(cnt)(toy_example)
        opt_toy_example(
            torch.randn(10),
            torch.randn(10),
            {"layer0": {"memory_keys": torch.randn(10)}},
        )
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 4)
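        # A single frame means the recursive dict traversal in _recursive_map,
        # including the nested-dict recursion, was fully inlined rather than
        # causing a graph break.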

    def test_issue114171(self):
        device = torch.device("cpu")

        def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU):
            layers = [
                torch.nn.Linear(in_dim, hidden_dim, device=device),
                activation(),
                torch.nn.Linear(hidden_dim, out_dim, device=device),
            ]
            return torch.nn.Sequential(*layers)

        class testmodel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.interaction_networks = torch.nn.ModuleList(
                    [fcnn(262, 1174, 400) for _ in range(4)]
                )

            def interact(self, x, cycle):
                return self.interaction_networks[cycle](x)

        model = testmodel()
        forward_aot = torch.compile(
            model.interact, fullgraph=True, dynamic=True, backend="eager"
        )

        x = torch.rand([111, 262], device=device)
        y2 = forward_aot(x, 2)  # previously failed
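        # Regression check for https://github.com/pytorch/pytorch/issues/114171:
        # with dynamic=True, indexing a ModuleList by an int argument inside a
        # fullgraph-compiled method used to fail; merely succeeding is the test.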

    def test_issue175(self):
        n_heads = 2
        d_model = 64
        model = TransformerEncoderLayer(d_model, n_heads)
        inp = torch.randn(1, d_model)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
        opt_model(inp)
        opt_model(inp)
        self.assertEqual(cnt.frame_count, 1)

        self.assertEqual(
            15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count
        )
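        # The expected op count differs presumably because
        # inline_inbuilt_nn_modules traces into the builtin nn.Module forwards,
        # exposing their individual ops, instead of treating each module call
        # as a single node.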

    def test_exec_import(self):
        def fn1():
            exec("import math")

        def fn2():
            try:
                math.sqrt(4)
                return False
            except NameError:
                return True

        def fn3():
            fn1()
            return fn2()

        self.assertTrue(fn3())
        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
        self.assertTrue(opt_fn3())
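        # In plain CPython, exec("import math") binds "math" only in fn1's
        # local exec namespace, so fn2 still raises NameError; the compiled
        # version must reproduce exactly that scoping behavior.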

    def test_exec_wildcard_import(self):
        # Test that globals are not carried over from frame to frame
        def fn1():
            exec("from torch import *")

        def fn2():
            x = torch.zeros(4)
            for i in range(5):
                x = x + i
            return x

        def fn3():
            fn1()
            return fn2()

        ref = fn3()
        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
        res = opt_fn3()
        self.assertTrue(same(ref, res))

    def test_with_on_graph_break_inst(self):
        def reversible(x):
            print("Hello world")  # Causes a graph break so inlining fails
            return torch.sin(torch.cos(x))

        def fn(x):
            with torch.enable_grad():
                a = torch.sin(x)
                b = reversible(a)
                c = torch.sigmoid(b)
                c.sum().backward()
                return x.grad

        x = torch.randn(3, requires_grad=True)
        x.grad = None
        with torch.no_grad():
            ref = fn(x)

        x.grad = None
        opt_fn = torch._dynamo.optimize("eager")(fn)
        with torch.no_grad():
            res = opt_fn(x)
        self.assertTrue(same(ref, res))
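        # The notable part: the graph break from print() happens while the
        # enable_grad() context is active, so dynamo must restore grad mode
        # correctly on both sides of the break for backward() to populate
        # x.grad the same way eager does.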

    def test_with_on_graph_break_nested(self):
        def reversible(x):
            torch._dynamo.graph_break()  # Causes a graph break so inlining fails
            return torch.sin(torch.cos(x))

        def fn(x):
            # nested context manager failed previously
            with torch.no_grad():
                with torch.enable_grad():
                    a = torch.sin(x)
                    b = reversible(a)
                    c = torch.sigmoid(b)
                    c.sum().backward()
                    return x.grad

        x = torch.randn(3, requires_grad=True)
        x.grad = None
        with torch.no_grad():
            ref = fn(x)

        x.grad = None
        opt_fn = torch._dynamo.optimize("eager")(fn)
        with torch.no_grad():
            res = opt_fn(x)
        self.assertTrue(same(ref, res))

    # https://github.com/pytorch/torchdynamo/issues/1446
    def test_grad_mode_carrying_correct_state_after_graph_break(self):
        def fn(x):
            with torch.no_grad():
                y = x * 3
                print("Break")
                z = x + 2
            return y, z

        x = torch.randn(3, requires_grad=True)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        y, z = opt_fn(x)
        self.assertFalse(y.requires_grad)
        self.assertFalse(z.requires_grad)
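        # The print() splits the no_grad() block across two graphs;
        # z.requires_grad being False is what proves the no_grad state was
        # carried across the break (the subject of issue 1446 above).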

    def test_abc_setattr(self):
        # tests that we correctly bail out of __setattr__ calls

        # TODO: does not ensure ABC classes are correctly inferred as ClassVariables
        # (doesn't test the fix for 'super()')

        class BaseModule(torch.nn.Module, ABC):
            def blah(self, x):
                return x + 1

        class Derived(BaseModule):
            def __setattr__(self, name, value) -> None:
                super().__setattr__(name, value)

            def forward(self, x):
                # expect a graph break on __setattr__
                self.foo = 0
                return self.blah(x)

            def blah(self, x):
                return super().blah(x)

        x = torch.randn(3, requires_grad=True)
        mod = Derived()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        opt_mod(x)

        # This test's exact intent is unclear. It used to graph break on
        # __dict__, which made the counter >= 2; with __dict__ support there
        # is no graph break, so only >= 1 can be asserted.
        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1)

    @torch._dynamo.config.patch("suppress_errors", True)
    def test_guard_fail_tensor_bool(self):
        @torch._dynamo.disable(recursive=False)
        def fn():
            condition_shape = (5, 5)
            dtypes = (torch.bool,)
            shapes = (
                (),
                (5,),
                (1, 5),
            )

            tensors = [
                torch.empty(shape, dtype=dtype).fill_(17)
                for shape, dtype in itertools.product(shapes, dtypes)
            ]

            x_vals = (5.0, *tensors)
            y_vals = (6.0, *tensors)

            @torch._dynamo.disable
            def get_expected(condition, x, y):
                x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
                return torch.from_numpy(
                    np.where(condition.cpu().numpy(), x_np, y_np)
                ).to(common_dtype)

            for x, y in zip(x_vals, y_vals):
                condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
                common_dtype = torch.result_type(x, y)

                def check_equal(condition, x, y):
                    # NumPy aggressively promotes to double, hence cast the
                    # output to the correct dtype
                    expected = get_expected(condition, x, y)
                    result = torch.where(condition, x, y)
                    assert torch.allclose(expected, result)

                check_equal(condition, x, y)
                check_equal(condition, y, x)

        fn()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn()

    def test_guard_fail_nested_tuple(self):
        def fn(args):
            return torch.ones(()), args[0] * 2

        # This adds a tensor check on args[1][0] and args[1][1]
        args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
        args2 = (torch.ones(1), torch.ones(1))
        opt_fn = torch._dynamo.optimize("eager")(fn)
        ref = opt_fn(args1)
        res = opt_fn(args2)

        self.assertTrue(same(ref, res))
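        # The second call changes the *structure* of args[1] (a tensor instead
        # of a tuple), so the nested-tuple guards installed by the first call
        # must fail cleanly and trigger a recompile rather than crashing.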

    def test_nullcontext1(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx:
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))

    def test_nullcontext2(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx():
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))
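        # Between the two tests, nullcontext is exercised both as an already
        # constructed instance and as a class called inside the traced region;
        # fullgraph=True means neither form may fall back to a graph break.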

    def test_no_grad_inline(self):
        @torch.no_grad()
        def a(x):
            return x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def b(x):
            return a(x).cos()

        y = torch.randn(10)
        self.assertTrue(same(b(y), y.sin().cos()))

    @skipIfWindows(
        msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function <class 'torch.LongTensor'>(*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):"  # noqa: B950
    )
    def test_longtensor_list(self):
        for partition in [0, 5, 10]:

            @torch._dynamo.disable
            def rand_gen():
                rand_vals = [random.randint(5, 10) for _ in range(10)]
                # List of tensors mixed with np.arrays
                return list(np.array(rand_vals[:partition])) + [
                    torch.tensor(val) for val in rand_vals[partition:]
                ]

            def fn(x):
                random_list = rand_gen()
                z = torch.LongTensor(random_list)
                return x * z

            x = torch.ones(10) * 2

            random.seed(0)
            ref0 = fn(x)
            ref1 = fn(x)

            random.seed(0)
            opt_fn = torch._dynamo.optimize("eager")(fn)
            res0 = opt_fn(x)
            res1 = opt_fn(x)

            self.assertTrue(same(ref0, res0))
            self.assertTrue(same(ref1, res1))
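            # rand_gen is wrapped in torch._dynamo.disable, so its output is
            # an opaque runtime value; re-seeding random before the compiled
            # runs ensures eager and compiled see identical "random" lists.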

    def test_primtorch(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            torch._refs.abs(x)

        fn(torch.randn(3))

    @unittest.expectedFailure
    # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)]
    def test_primtorch_no_graph_break(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            torch._refs.abs(x)

        fn(torch.randn(3))

    def test_torch_tensor_ops_no_graph_break(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            torch.Tensor.abs_(x)

        fn(torch.randn(3))

    @unittest.skipIf(
        not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
        "old pt doesn't work",
    )
    def test_torch_ops_aten(self):
        # Picked an op that doesn't show up in the default list
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            return torch.ops.aten.absolute(x)

        fn(torch.randn(3))

    def test_hf_gelu_inline(self):
        class GELUActivation(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.act = nn.functional.gelu

            def forward(self, input):
                return self.act(input)

        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            return GELUActivation()(x)

        y = torch.randn(10)
        self.assertTrue(same(fn(y), nn.functional.gelu(y)))

        @torch._dynamo.optimize("eager", nopython=True)
        def fn_returns(x):
            return GELUActivation(), x + 1

        act, _ = fn_returns(y)
        self.assertIsInstance(act, GELUActivation)
        self.assertIs(act.act, nn.functional.gelu)
        self.assertTrue(hasattr(act, "_buffers"))  # check that __init__ got called

    def test_dropout_inline(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            return torch.nn.Dropout(0.1)(x)

        y = torch.randn(10)
        torch.manual_seed(1337)
        ref = nn.functional.dropout(y, 0.1)
        torch.manual_seed(1337)
        res = fn(y)
        self.assertTrue(same(ref, res))

    def test_setitem_boolean_mask_diff(self):
        def fn(x, b, y):
            x = x.clone()
            x[b] = y
            return x

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        x = torch.randn(4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)
        opt_fn(x, b, y)

    def test_setitem_tuple_boolean_mask_diff(self):
        def fn(x, b, y):
            x = x.clone()
            x[:, b] = y
            return x

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        x = torch.randn(8, 4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)
        opt_fn(x, b, y)

    def test_torch_tensor_ops(self):
        def fn(x):
            return torch.Tensor.abs_(x)

        x = torch.randn(3)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        y = fn(x)
        y_ = opt_fn(x)
        self.assertTrue(same(y, y_))

    def test_guard_ordering_shape_fail(self):
        # If a function that takes a tensor has an inner compiled function
        # that generates a guard on the tensor's shape, the guards could be
        # evaluated in the wrong order. Then, if an int is passed instead of
        # a tensor on a subsequent call, guard evaluation would crash with a
        # "no attribute: shape" error.
        m = MockModule()
        opt_m = torch._dynamo.optimize("eager")(m)
        opt_m.fn(torch.ones((5, 5)))
        opt_m.fn(-3)

    def test_tensor_isinstance_tuple(self):
        @torch._dynamo.optimize("eager")
        def fn():
            t = torch.ones(5, 5)
            if not isinstance(t, (int, torch.Tensor)):
                msg = str.format(
                    "{0} is not an instance of {1}",
                    type(t),
                    (int, torch.Tensor),
                )
                raise ValueError(msg)
            return True

        fn()

    def test_isinstance_dtype(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            isinstance(torch.bfloat16, torch.dtype)
            return x

        fn(torch.randn(3))

    def test_isinstance_storage(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
            bools = torch.BoolStorage.from_buffer(f, "big")
            assert isinstance(bools, torch.BoolStorage)
            return x

        fn(torch.randn(3))

    def test_issue111522(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x, y):
            return x + y.a

        class A:
            a = 2

        self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0))

        del A.a

        # graph break on missing attr
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            f(torch.zeros(2), A())

    def test_dict_list_values(self):
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch._dynamo.optimize("eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        fn({"args": [torch.ones(5, 5)]})

    def test_dict_iter(self):
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_sort_out(self):
        dtype = torch.float32
        device = "cpu"

        def fn():
            tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
            values1 = torch.tensor(0, dtype=dtype, device=device)
            indices1 = torch.tensor(0, dtype=torch.long, device=device)
            torch.sort(tensor, out=(values1, indices1))
            self.assertEqual(values1.stride(), (1,))
            self.assertEqual(indices1.stride(), (1,))

        fn()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn()

    def test_sort_out2(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sorted = torch.nn.Buffer(torch.ones(4, 4))
                self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))

            def forward(self, x):
                torch.sort(x, out=(self.sorted, self.indices))
                return (x + 1, self.sorted, self.indices)

        x = torch.randn(4, 4)
        m = MyModule()
        ref = m(x)
        opt_m = torch._dynamo.optimize("eager")(m)
        res = opt_m(x)
        self.assertTrue(same(ref, res))
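        # torch.sort with out= pointed at module buffers overwrites state on
        # the module itself; the compiled module must replay that mutation so
        # the returned buffers match eager.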
2189*da0073e9SAndroid Build Coastguard Worker
    def test_sigmoid_out(self):
        dtype = torch.float32
        device = "cpu"

        def fn():
            inp = torch.randn((3, 5), dtype=dtype, device=device)
            out1 = torch.tensor(0, dtype=dtype, device=device)
            torch.sigmoid(inp, out=out1)
            self.assertEqual(out1.numel(), 15)

        fn()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn()

    def test_sigmoid_out2(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.base = torch.nn.Buffer(torch.ones(4, 4))

            def forward(self, x):
                torch.sigmoid(x, out=self.base)
                return x + self.base

        x = torch.randn(4, 4)
        m = MyModule()
        ref = m(x)
        opt_m = torch._dynamo.optimize("eager")(m)
        res = opt_m(x)
        self.assertTrue(same(ref, res))

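    # Repeatedly assigning into a slice taken from the input list should still
    # compile as a single frame (no graph break on the list mutation).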
    def test_slice_into_list_mutable(self):
        class Mod(torch.nn.Module):
            def forward(self, listy):
                x = listy[3:5]
                for i in range(10):
                    z = torch.abs(torch.randn(10)) + 1
                    x[0] = z
                return x

        m = Mod()
        listy = [torch.randn(10)] * 10

        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        opt_m.forward(listy)

        self.assertEqual(cnt.frame_count, 1)

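    # With capture_scalar_outputs=True, x.item() is traced, but the
    # data-dependent branch on y still splits compilation into two frames;
    # under fullgraph=True the same branch raises a UserError instead.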
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_issue111918(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def fn(x):
            x = x + 1
            y = x.item()
            if y > 2:
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 4)

        torch._dynamo.reset()
        fn = torch.compile(fn, fullgraph=True, backend="eager")
        with self.assertRaises(torch._dynamo.exc.UserError):
            fn(x)

    def test_vdd_duplicate_error(self):
        def fn(a, dt):
            keys = list(dt._jt_dict.keys())
            p = torch.cos(dt._jt_dict[keys[0]]._value)
            q = torch.sin(a)
            r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
            return p + q + r

        class Value:
            def __init__(self) -> None:
                self._value = torch.randn(4)

        class Sample:
            def __init__(self) -> None:
                self._jt_dict = {}
                self._jt_dict["POSITION_ID"] = Value()

        a = torch.randn(4)
        sample = Sample()

        ref = fn(a, sample)

        optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = optimized_fn(a, sample)

        self.assertTrue(same(ref, res))

    def test_specialized_stride(self):
        def f():
            e = torch.empty(4)
            x = e[::2]
            return x.stride()

        self.assertEqual(f(), torch._dynamo.optimize("eager")(f)())

    def test_out_none(self):
        # https://github.com/pytorch/pytorch/issues/92814
        def fn(input):
            return torch.nn.functional.normalize(input, dim=0, out=None)

        x = torch.rand([1])
        self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x))

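    # A function-local import of a third-party module (several names at once)
    # must trace through without a graph break, since nopython=True would
    # raise on any break.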
    def test_multi_import(self):
        if not has_detectron2():
            raise unittest.SkipTest("requires detectron2")

        @torch._dynamo.optimize("eager", nopython=True)
        def to_bitmasks(boxes):
            from detectron2.layers.mask_ops import (
                _paste_masks_tensor_shape,
                paste_masks_in_image,
            )

            if (
                paste_masks_in_image is not None
                and _paste_masks_tensor_shape is not None
            ):
                return boxes + 1

        self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())

    def test_multi_dot_import(self):
        def fn1(x):
            return torch.sin(x)

        def fn(x):
            import torch.fx

            _ = torch.fx.symbolic_trace(fn1)
            return x * 2

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_relative_import(self):
        try:
            from . import utils as _  # noqa: F401

            def fn(x):
                from .utils import tensor_for_import_testing

                return x * 2 * tensor_for_import_testing

        except ImportError:

            def fn(x):
                from utils import tensor_for_import_testing

                return x * 2 * tensor_for_import_testing

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_relative_import_no_modulename(self):
        try:
            from . import utils as _  # noqa: F401

            def fn(x):
                from . import utils

                return x * 2 * utils.tensor_for_import_testing

        except ImportError:

            def fn(x):
                import utils

                return x * 2 * utils.tensor_for_import_testing

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

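    # In-place unsqueeze_ on a clone, followed by cat and view, has to match
    # eager results under the aot_eager backend (BigBird repro).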
    def test_bigbird_unsqueeze_inplace(self):
        def fn(reshape_2):
            view_2 = reshape_2.clone()
            view_2.unsqueeze_(2)
            cat_11 = torch.cat([view_2], dim=2)
            view_13 = cat_11.view((2, 12, 64, -1))
            return (view_13,)

        x = torch.randn(2, 12, 64, 64, requires_grad=True)
        ref = fn(x)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_issue1466_size_aot_autograd(self):
        def fn(x):
            # do a tensor op and a size compute
            y = x * 2
            x_size = x.size()
            # trigger a graph break
            print("arf")
            # use the tensor op and size compute
            z = y.view(x_size) + 1
            return z

        x = torch.randn(2, 3, requires_grad=True)
        ref = fn(x)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_ellipsis(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lnorm = torch.nn.LayerNorm(
                    (256,), eps=1e-06, elementwise_affine=True
                )
                self.linear = torch.nn.Linear(
                    in_features=256, out_features=256, bias=True
                )

            def forward(self, cat_10):
                lnorm = self.lnorm(cat_10)
                getitem_64 = lnorm[
                    (slice(None, None, None), slice(0, 1, None), Ellipsis)
                ]
                linear = self.linear(getitem_64)
                return (linear,)

        args = [torch.randn(2, 197, 256)]

        mod = Repro()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)

        self.assertTrue(same(mod(*args), opt_mod(*args)))

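    # LayoutLM-style embedding lookups feeding chained adds, run through the
    # aot_eager_decomp_partition backend where functionalized ops can be
    # re-inplaced; compiled and eager outputs must agree.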
    def test_reinplacing(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.self_layoutlm_embeddings_x_position_embeddings = (
                    torch.nn.Embedding(1024, 768)
                )
                self.self_layoutlm_embeddings_y_position_embeddings = (
                    torch.nn.Embedding(1024, 768)
                )

            def forward(self, getitem_1, getitem_2, add):
                self_layoutlm_embeddings_x_position_embeddings = (
                    self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
                )
                self_layoutlm_embeddings_y_position_embeddings = (
                    self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
                )
                add_1 = add + self_layoutlm_embeddings_x_position_embeddings
                add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
                return (add_2,)

        mod = MockModule()
        opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod)

        args = [
            ((2, 512), (2048, 4), torch.int64, "cpu", False),
            ((2, 512), (2048, 4), torch.int64, "cpu", False),
            ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        self.assertTrue(same_two_models(mod, opt_mod, args))

    def test_optimized_deepcopy(self):
        # See https://github.com/pytorch/pytorch/pull/88629
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True)

            def forward(self, x):
                return self.fc(x)

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        args = [torch.randn(1, 2)]
        self.assertTrue(same_two_models(mod, opt_mod, args))

    def test_class_member(self):
        class Foo(torch.nn.Module):
            a = 4
            b = torch.ones(3, 4)

            def __init__(self) -> None:
                super().__init__()
                self.c = 4

            def forward(self, x):
                return x.cos() + self.a + self.b + self.c

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
        args = (torch.randn(3, 4),)
        self.assertTrue(same(mod(*args), opt_mod(*args)))

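    # Iterating named_buffers() inside forward must be traceable with
    # nopython=True.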
    def test_named_buffers(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.nn.Buffer(torch.ones(3))
                self.y = torch.nn.Buffer(torch.ones(3))

            def forward(self, inp):
                res = 0
                for name, buffer in self.named_buffers():
                    res += buffer.sum()

                return inp.cos() + res

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
        args = (torch.randn(3, 4),)
        self.assertTrue(same(mod(*args), opt_mod(*args)))

    def test_requires_grad_guards_with_grad_mode1(self):
        def f(x):
            if x.requires_grad:
                return x + 1
            else:
                return x + 2

        x = torch.ones(2, requires_grad=True)

        f_compiled = torch.compile(f)
        with torch.no_grad():
            # compile an inference graph
            f_compiled(x)

        # Test: we should fail guards and recompile (even though it's still an inference graph)
        out_ref = f(x.detach())
        out = f_compiled(x.detach())

        self.assertEqual(out_ref, out)
        self.assertEqual(out_ref.requires_grad, out.requires_grad)

    def test_requires_grad_guards_with_grad_mode2(self):
        x = torch.ones(2, requires_grad=True)
        x_ref = x.clone().detach().requires_grad_(True)

        m = torch.nn.Linear(2, 2)
        m_compiled = torch.compile(m)

        with torch.no_grad():
            # compile an inference graph
            m_compiled(x)

        # Test: we should fail guards and recompile a training graph
        out_ref = m(x_ref)
        out = m_compiled(x)
        self.assertEqual(out_ref, out)
        self.assertEqual(out_ref.requires_grad, out.requires_grad)

    def test_is_symbolic_tracing(self):
        # Ensure no graph break here
        def fn(x):
            if is_fx_tracing_test():
                return x * 2
            return x * 4

        a = torch.randn(4)
        ref = fn(a)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(a)
        self.assertTrue(same(ref, res))

    def test_tokenization(self):
        from collections import UserDict

        class BatchEncoding(UserDict):
            """
            Copied from tokenization
            """

            def __init__(
                self,
                data,
            ):
                super().__init__(data)

            def __getattr__(self, item: str):
                try:
                    return self.data[item]
                except KeyError as e:
                    raise AttributeError from e

        def tokenization(x):
            encoding = BatchEncoding({"key": x})
            return encoding["key"]

        opt_fn = torch._dynamo.optimize("eager")(tokenization)
        x = torch.rand((1, 4))
        ref = tokenization(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_modules(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 3)

            def forward(self, inp):
                res = torch.zeros(3, 3)
                for mod in self.modules():
                    res += self.fc(inp)
                return res

        mod = Foo()
        args = (torch.ones(3, 4),)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod)
        self.assertTrue(same(mod(*args), opt_mod(*args)))
        self.assertEqual(cnt.op_count, 5)
        self.assertEqual(cnt.frame_count, 1)

    def test_omegaconf_listconfig_iter(self):
        obj = ListConfig()
        x = torch.zeros(2)

        def fn():
            y = x
            for i in obj:
                y += i
            return y

        expected = fn()
        actual = torch.compile(fn, fullgraph=True, backend="eager")()
        self.assertEqual(actual, expected)

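    # A user-defined iterator (__iter__/__next__ ending in StopIteration)
    # should be fully traced: the loop adds 1 + 2 + 3, giving 6.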
    def test_user_defined_iter(self):
        class MyIter:
            def __init__(self) -> None:
                self.i = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.i < 3:
                    self.i += 1
                    return self.i
                raise StopIteration

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            for i in MyIter():
                x += i
            return x

        self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0))

    def test_stop_iteration_reconstruct(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.sin(), StopIteration(1, 2, 3)

        _, res = fn(torch.ones(1))
        self.assertEqual(str(res), str(StopIteration(1, 2, 3)))

    def test_tensor_data_kwarg(self):
        # https://github.com/pytorch/pytorch/issues/96278
        def f():
            return torch.tensor(data=[[1.0, -1.0]])

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(), opt_fn()))
        self.assertEqual(cnt.frame_count, 1)

    @requires_cuda
    def test_norm_dtype(self):
        def foo(_stack0):
            getitem = _stack0[(slice(None, None, None), -1)]
            _stack0 = None
            normalize = torch.nn.functional.normalize(getitem, p=2, dim=1)
            getitem = None
            return (normalize,)

        args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        opt_foo = torch._dynamo.optimize("aot_eager_decomp_partition")(foo)
        with torch.cuda.amp.autocast(enabled=True):
            ref = foo(*args)[0]
            res = opt_foo(*args)[0]
            self.assertEqual(ref.dtype, res.dtype)

            self.assertTrue(same(res, ref))

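    # A graph break inside the loop body should leave one compiled frame
    # holding the single sin() op rather than unrolling the loop 100 times.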
    def test_for_loop_graph_break(self):
        def inner(x):
            return torch.sin(x)

        def fn(x):
            for _ in range(100):
                inner(x)
                torch._dynamo.graph_break()
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_for_loop_graph_break_before(self):
        # Checks that the backedge is calculated correctly
        def inner(x):
            return torch.sin(x)

        def fn(x):
            torch._dynamo.graph_break()
            for _ in range(100):
                inner(x)
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 100)

    def test_avoid_dupe_specialization(self):
        def f(x, y):
            return (x + y) * 1

        opt_f = torch._dynamo.optimize("aot_eager")(f)

        for b in [True, False]:
            x = torch.randn(4, requires_grad=b)
            y = torch.randn(4, requires_grad=b)
            self.assertEqual(f(x, x), opt_f(x, x))
            self.assertEqual(f(x, y), opt_f(x, y))

    def test_validate_model_kwargs(self):
        cnt = CompileCounter()

        def f1(a, b):
            return torch.sin(a) + torch.cos(b)

        @torch.compile(backend=cnt, fullgraph=True)
        def f2(**kwargs):
            _validate_model_kwargs(f1, kwargs)
            return f1(**kwargs)

        x = torch.randn(10)
        y = torch.randn(10)

        self.assertEqual(f2(a=x, b=y), f1(x, y))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)

    def test_swin_base_tensor_attr(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # NB: not a parameter or buffer
                self.t = torch.randn(3)

            def forward(self, x):
                return x + torch.cat((self.t, self.t))

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        args = [torch.randn(6)]
        self.assertTrue(same_two_models(mod, opt_mod, args))
        opt_mod(*args)

    def test_pointless_graph_removal(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x):
            with torch.no_grad():
                torch._dynamo.graph_break()
                return x + 1

        fn(torch.randn(4))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)

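    # When one output is a view of another, AOTAutograd must preserve the
    # aliasing: an in-place mul_ on the view stays visible through the base,
    # with and without requires_grad.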
2803*da0073e9SAndroid Build Coastguard Worker    def test_output_aliases_intermediate(self):
2804*da0073e9SAndroid Build Coastguard Worker        def f(x):
2805*da0073e9SAndroid Build Coastguard Worker            intermediate = x.mul(2)
2806*da0073e9SAndroid Build Coastguard Worker            return intermediate.view(-1), intermediate
2807*da0073e9SAndroid Build Coastguard Worker
2808*da0073e9SAndroid Build Coastguard Worker        opt_f = torch._dynamo.optimize("aot_eager")(f)
2809*da0073e9SAndroid Build Coastguard Worker
2810*da0073e9SAndroid Build Coastguard Worker        for b in [True, False]:
2811*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, requires_grad=b)
2812*da0073e9SAndroid Build Coastguard Worker            out = f(x)
2813*da0073e9SAndroid Build Coastguard Worker            out_test = opt_f(x)
2814*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[0], out_test[0])
2815*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[1], out_test[1])
2816*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[0].requires_grad, out_test[0].requires_grad)
2817*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[1].requires_grad, out_test[1].requires_grad)
2818*da0073e9SAndroid Build Coastguard Worker            # test that the aliasing relationship of outputs is preserved
2819*da0073e9SAndroid Build Coastguard Worker            out[0].mul_(2)
2820*da0073e9SAndroid Build Coastguard Worker            out_test[0].mul_(2)
2821*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[0], out_test[0])
2822*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out[1], out_test[1])
2823*da0073e9SAndroid Build Coastguard Worker
2824*da0073e9SAndroid Build Coastguard Worker    def test_while_loop_graph_break(self):
2825*da0073e9SAndroid Build Coastguard Worker        # Repro of tacotron2 cache_size_recompilation
2826*da0073e9SAndroid Build Coastguard Worker        def inner(x):
2827*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
2828*da0073e9SAndroid Build Coastguard Worker
2829*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2830*da0073e9SAndroid Build Coastguard Worker            i = 20
2831*da0073e9SAndroid Build Coastguard Worker            while i > 10:
2832*da0073e9SAndroid Build Coastguard Worker                x = inner(x)
2833*da0073e9SAndroid Build Coastguard Worker                i -= 1
2834*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
2835*da0073e9SAndroid Build Coastguard Worker            return x
2836*da0073e9SAndroid Build Coastguard Worker
2837*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2838*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
2839*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2840*da0073e9SAndroid Build Coastguard Worker        opt_fn(x)
2841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 1)
2843*da0073e9SAndroid Build Coastguard Worker
2844*da0073e9SAndroid Build Coastguard Worker    def test_nested_while_loop_graph_break(self):
2845*da0073e9SAndroid Build Coastguard Worker        def inner_loop(x):
2846*da0073e9SAndroid Build Coastguard Worker            i = 3
2847*da0073e9SAndroid Build Coastguard Worker            while i > 0:
2848*da0073e9SAndroid Build Coastguard Worker                i -= 1
2849*da0073e9SAndroid Build Coastguard Worker                x += 1
2850*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
2851*da0073e9SAndroid Build Coastguard Worker            return x
2852*da0073e9SAndroid Build Coastguard Worker
2853*da0073e9SAndroid Build Coastguard Worker        def inner(x):
2854*da0073e9SAndroid Build Coastguard Worker            inner_loop(x)
2855*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
2856*da0073e9SAndroid Build Coastguard Worker
2857*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2858*da0073e9SAndroid Build Coastguard Worker            i = 20
2859*da0073e9SAndroid Build Coastguard Worker            while i > 10:
2860*da0073e9SAndroid Build Coastguard Worker                x = inner(x)
2861*da0073e9SAndroid Build Coastguard Worker                i -= 1
2862*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
2863*da0073e9SAndroid Build Coastguard Worker            return x
2864*da0073e9SAndroid Build Coastguard Worker
2865*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2866*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
2867*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2868*da0073e9SAndroid Build Coastguard Worker        opt_fn(x)
2869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
2870*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 1)
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker    def test_while_loop_graph_break_inside_call_function(self):
2873*da0073e9SAndroid Build Coastguard Worker        # Repro of huggingface graph break inside loop in `get_parameter_dtype`.
2874*da0073e9SAndroid Build Coastguard Worker        # Skip only the inner frame that has loop that contains graph break.
2875*da0073e9SAndroid Build Coastguard Worker        def inner(x):
2876*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
2877*da0073e9SAndroid Build Coastguard Worker                x += 1
2878*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
2879*da0073e9SAndroid Build Coastguard Worker            return x
2880*da0073e9SAndroid Build Coastguard Worker
2881*da0073e9SAndroid Build Coastguard Worker        def fn(x):
2882*da0073e9SAndroid Build Coastguard Worker            x += 2
2883*da0073e9SAndroid Build Coastguard Worker            inner(x)
2884*da0073e9SAndroid Build Coastguard Worker            x += 3
2885*da0073e9SAndroid Build Coastguard Worker            return x
2886*da0073e9SAndroid Build Coastguard Worker
2887*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2888*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt)(fn)
2889*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
2890*da0073e9SAndroid Build Coastguard Worker        opt_fn(x)
2891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
2892*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.op_count, 2)
2893*da0073e9SAndroid Build Coastguard Worker
2894*da0073e9SAndroid Build Coastguard Worker    def test_exception_in_dynamo_handling(self):
2895*da0073e9SAndroid Build Coastguard Worker        hit_handler = False
2896*da0073e9SAndroid Build Coastguard Worker
2897*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/pull/96488
2898*da0073e9SAndroid Build Coastguard Worker        @contextlib.contextmanager
2899*da0073e9SAndroid Build Coastguard Worker        def ctx():
2900*da0073e9SAndroid Build Coastguard Worker            try:
2901*da0073e9SAndroid Build Coastguard Worker                yield
2902*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
2903*da0073e9SAndroid Build Coastguard Worker                nonlocal hit_handler
2904*da0073e9SAndroid Build Coastguard Worker                hit_handler = True
2905*da0073e9SAndroid Build Coastguard Worker
2906*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize("eager")
2907*da0073e9SAndroid Build Coastguard Worker        def f():
2908*da0073e9SAndroid Build Coastguard Worker            with ctx():
2909*da0073e9SAndroid Build Coastguard Worker                h()
2910*da0073e9SAndroid Build Coastguard Worker
2911*da0073e9SAndroid Build Coastguard Worker        def h():
            raise RuntimeError("boof")

        # Should not error
        f()
        self.assertTrue(hit_handler)

    def test_generator_dealloc(self):
        # See https://github.com/pytorch/pytorch/pull/96488
        #
        # NB: yes, [(...)] is intentional, this is a list containing a
        # generator
        generator_box = [(x for x in [1, 2, 3])]

        counter = torch._dynamo.testing.CompileCounter()

        def g(x):
            return x + 2

        # TODO: This test is pretty delicate.  To test if it's actually doing
        # anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
        # and then look at the logs for:
        #
        # TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
        # TRACE[_custom_eval_frame:664] throw <genexpr>
        #
        # This means we're actually hitting the relevant codepath

        # NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
        # this frame, Dynamo actually DOES understand list.clear and will
        # arrange for the generator deallocation to happen when the eval frame
        # handler is disabled, which will prevent the bug from happening (we
        # specifically want to trigger the generator deallocation WHILE the
        # dynamo eval frame handler is active), as that will cause the
        # generator to become exhausted and trigger the throw_flag == TRUE
        # case.
        @torch._dynamo.disable(recursive=False)
        def f(x):
            generator_box.clear()
            return g(x)

        self.assertNoUnraisable(
            lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
        )

        # Make sure the x + 2 is captured (a previous incorrect implementation
        # of this fix would have disabled the eval frame callback, which means
        # g wouldn't get traced)
        self.assertEqual(counter.op_count, 1)

    def test_error_return_without_exception_set(self):
        # https://github.com/pytorch/pytorch/issues/93781
        @torch.compile
        def f():
            _generator_type = type(_ for _ in ())

        self.assertNoUnraisable(f)

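    # NB: CustomList1/CustomList2 and _merge_criteria_processor_list are
    # helpers defined earlier in this file, modeled on HuggingFace's
    # generation utility that merges a default and a custom
    # LogitsProcessorList/StoppingCriteriaList. For orientation only, they
    # presumably behave roughly like this hypothetical sketch (the real
    # definitions live above and may differ):
    #
    #     class CustomList2(list):
    #         def __call__(self, x):
    #             for processor in self:
    #                 x = processor(x)
    #             return x
    #
    #         def length_times_10(self):
    #             return len(self) * 10
    #
    #         def append_twice(self, x):
    #             self.extend([x, x])
    #
    #     def _merge_criteria_processor_list(default_list, custom_list):
    #         if len(custom_list) == 0:
    #             return default_list
    #         default_list.extend(custom_list)
    #         return default_list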
    def common_merge_criteria_processor_list(self, list_cls, fullgraph):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=fullgraph)
        def f(x, left, right):
            combined = _merge_criteria_processor_list(left, right)
            return combined(x)

        l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()])
        l2 = list_cls([])
        input = torch.randn(16)
        result = f(input, l1, l2)
        self.assertEqual(result, l1(input))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

        cnt.clear()
        l3 = list_cls([torch.nn.SiLU()])
        expected = l3(l1(input))
        result = f(input, l1, l3)
        self.assertEqual(len(l1), 3)
        self.assertEqual(result, expected)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)

    def test_merge_criteria_processor_list1(self):
        self.common_merge_criteria_processor_list(CustomList1, False)

    def test_merge_criteria_processor_list2(self):
        self.common_merge_criteria_processor_list(CustomList2, True)

    def test_restricted_list_subclass1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            l = CustomList2()
            l.extend([True])
            l.append(a)
            l.extend([b])
            l.pop(0)
            l.append(l.length_times_10())
            return sum(l)

        x = torch.randn(10)
        y = torch.randn(10)
        self.assertEqual(fn(x, y), x + y + 20)
        self.assertEqual(cnt.op_count, 3)

    def test_restricted_list_subclass2(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            l1 = CustomList2([a + 1])
            l2 = CustomList2([b + 2])
            l1.extend(l2)
            return l1

        x = torch.randn(10)
        y = torch.randn(10)
        z = fn(x, y)
        self.assertEqual(type(z), CustomList2)
        self.assertEqual(len(z), 2)
        self.assertEqual(z.length_times_10(), 20)
        self.assertEqual(list(z), [x + 1, y + 2])

    def test_restricted_list_subclass3(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a: CustomList2, b: CustomList2):
            a.extend(b)
            a.append_twice(b[2] + 1)
            a.append(b[3] + 2)
            return b

        x = torch.randn(10)
        y = torch.randn(10)
        l = CustomList2([x, y])
        self.assertIs(fn(l, l), l)
        self.assertEqual(len(l), 7)
        self.assertIs(l[0], x)
        self.assertIs(l[1], y)
        self.assertIs(l[2], x)
        self.assertIs(l[3], y)
        self.assertEqual(l[4], x + 1)
        self.assertIs(l[5], l[4])
        self.assertEqual(l[6], y + 2)

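    # NB: Dynamo rewrites a plain Python `assert cond, msg` into a
    # torch._assert call (controlled by
    # torch._dynamo.config.rewrite_assert_with_torch_assert), so a
    # data-dependent assert can stay inside one graph instead of forcing a
    # graph break. The op_count of 6 below includes the ops introduced by the
    # rewritten assert in addition to sin/cos/add; the exact number is an
    # implementation detail.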
    def test_rewrite_assert_with_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3, "First dim needs to be 3"
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        cnt = torch._dynamo.testing.CompileCounter()

        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

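    # NB: Appending to an input list inside the compiled region must be
    # replayed on the caller's original object: Dynamo records the mutation
    # as a side effect and re-applies it when the compiled frame returns, so
    # the aliasing (`fn(l) is l`) is preserved.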
    def test_list_aliasing(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a):
            a.append(torch.sin(a[0]))
            return a

        x = torch.randn(10)
        l = [x]
        self.assertIs(fn(l), l)
        self.assertEqual(len(l), 2)
        self.assertIs(l[0], x)
        self.assertEqual(l[1], torch.sin(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_not_rewrite_assert_for_other_errors(self):
        def f(x):
            b = x.sin()
            if not x.sum() <= 3:
                raise ValueError("input sum needs to be <= 3")
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        opt_fn = torch._dynamo.optimize("eager")(f)
        with self.assertRaisesRegex(ValueError, "input sum needs to be <= 3"):
            opt_fn(*args)

    def test_rewrite_assert_dont_change_bytecode(self):
        def fn(x):
            with torch.no_grad():
                assert x.max() < 5, f"invalid max {x.max()}"
                x = torch.sin(x)
            return x

        x = torch.ones(4)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertTrue(same(fn(x), opt_fn(x)))

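    # NB: When an exported graph's rewritten assert fails at runtime, it
    # surfaces as a RuntimeError whose message contains "assertion error"
    # (raised by torch._assert), not as a Python AssertionError.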
    def test_rewrite_assert_without_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        with self.assertRaisesRegex(RuntimeError, "assertion error"):
            exported(torch.Tensor([5, 6, 7]))

    def test_rewrite_assert_with_non_string_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 2, x.size()
            return x.cos() + b

        torch._dynamo.utils.counters.clear()
        args = torch.Tensor([3, 4, 5])
        opt_f = torch._dynamo.optimize("eager")(f)
        with self.assertRaisesRegex(AssertionError, "torch.Size"):
            opt_f(args)
        self.assertEqual(
            torch._dynamo.utils.counters["graph_break"][
                "assert with non-string message"
            ],
            1,
        )

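    # NB: Asserts whose condition is statically known to be true at trace
    # time (a literal True, or a dtype check on a traced tensor) should be
    # elided entirely; no torch._assert node lands in the graph, so only the
    # sin/cos/add ops are counted below.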
    def test_rewrite_assert_noop(self):
        def f(x):
            b = x.sin()
            assert True
            assert x.dtype == torch.float32
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        # torch._assert shouldn't be in the graph
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

    def test_size_typematch(self):
        def f(x, y):
            if isinstance(x, torch.Size):
                return y + 1
            else:
                return y + 2

        y = torch.zeros(1)
        x1 = torch.Size((3,))
        x2 = (3,)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(x1, y), opt_f(x1, y)))
        self.assertTrue(same(f(x2, y), opt_f(x2, y)))
        self.assertEqual(cnt.frame_count, 2)

    def test_dict_subclass_contains(self):
        # pattern from huggingface
        class ClassInstantier(collections.OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))

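    # NB: Unlike the plain subclass above, this version overrides __getitem__
    # to instantiate the stored activation class on lookup (the pattern in
    # huggingface/transformers activations.py), so Dynamo must trace through
    # the user-defined __getitem__ rather than constant-folding the dict
    # access.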
    def test_hf_classinstantier(self):
        # hf activations.py
        class ClassInstantier(collections.OrderedDict):
            def __getitem__(self, key):
                content = super().__getitem__(key)
                cls, kwargs = content if isinstance(content, tuple) else (content, {})
                return cls(**kwargs)

        ACT2CLS = ClassInstantier(
            {
                "relu": (nn.ReLU, {"inplace": False}),
                "tanh": nn.Tanh,
            }
        )

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, act):
            return ACT2CLS[act](x)

        y = torch.randn(10)
        self.assertTrue(same(f(y, "tanh"), torch.tanh(y)))
        self.assertTrue(same(f(y, "relu"), torch.relu(y)))

    def test_ephemeral_module(self):
        # hf activations.py
        class ReLUSquaredActivation(nn.Module):
            def forward(self, input):
                relu_applied = torch.nn.functional.relu(input)
                squared = torch.square(relu_applied)
                return squared

        @torch.compile(fullgraph=True, backend="eager")
        def f(x):
            x = x + 0.2
            x = ReLUSquaredActivation()(x)
            x = x + 1
            return x

        y = torch.randn(10)
        self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1))

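    # NB: An in-place metadata mutation on a graph input (unsqueeze_) has to
    # be visible both in the example_inputs the backend receives and on the
    # caller's tensor after the compiled call returns.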
    def test_inplace_unsqueeze_input(self):
        def backend(gm, example_inputs):
            self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4]))
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            x.unsqueeze_(0)
            return x + 1

        inputs = [torch.randn(3, 4)]
        self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4]))
        self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4]))

    def test_batchnorm_e2e(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(
                    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
                )
                self.conv1 = torch.nn.Conv2d(
                    64,
                    64,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=False,
                )

            def forward(self, x):
                x1 = self.bn(x)
                x2 = self.conv1(x1)
                out = torch.nn.functional.relu(x2)
                return (out,)

        torch.manual_seed(1337)

        m_ref = Repro()
        m_test = deepcopy(m_ref)

        @torch._dynamo.optimize("aot_eager_decomp_partition")
        def compiled_fn(x):
            return m_test(x)

        x_ref = torch.randn(2, 64, 32, 32, requires_grad=True)
        x_test = x_ref.clone()

        # Loop multiple times: on each iteration batchnorm's running_mean/var
        # update, which changes the output of the next iteration
        for _ in range(3):
            ref = m_ref(x_ref)
            res = compiled_fn(x_test)

            self.assertTrue(same(ref, res))

            for r in ref:
                if r.requires_grad:
                    r.sum().backward()
            for r in res:
                if r.requires_grad:
                    r.sum().backward()

            for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()):
                self.assertTrue(same(param_ref, param_test))
            # Check that the running_mean/var buffers also stay in sync
            for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
                self.assertTrue(same(buffer_ref, buffer_test))

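    # NB: With assume_static_by_default off, x.shape[0] is traced
    # symbolically, so the exported graph computes torch.ones(5 * s0) and
    # generalizes from the 4-row example input to the 6-row one.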
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_dynamic_shapes_right_side(self):
        def f(x):
            return torch.ones(5 * x.shape[0])

        inp = torch.randn(6, 5)

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
        self.assertEqual(gm(inp).shape, f(inp).shape)

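    # NB: With specialize_int off, the Python int `d` is traced as a SymInt
    # rather than baked into the graph as a constant, so the compiled
    # function is expected to produce correct results for both alpha=5 and
    # alpha=6.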
    @torch._dynamo.config.patch("specialize_int", False)
    def test_maybe_multiply_symint(self):
        # https://github.com/pytorch/pytorch/issues/97346
        from torch._functorch.aot_autograd import aot_module_simplified

        def my_aot_compiler(gm, example_inputs):
            def my_compiler(gm, example_inputs):
                return gm.forward

            # Invoke AOTAutograd
            return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)

        def my_example(t1, t2, d):
            out = torch.add(t1, t2, alpha=d)
            return out

        compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example)

        t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
        t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True)

        ra = compiled_fn(t1, t2, 5)
        self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0]))

        ra = compiled_fn(t1, t2, 6)
        self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0]))

    def test_build_map_unpack_with_call(self):
        def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2):
            return x.sin() + t + cond_scale + self_cond + other1 + other2

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            d1 = dict(other1=5)
            d2 = dict(other2=4)
            text_cond = {**d1, **d2}
            return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond)

        self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15))

    @torch._dynamo.config.patch(verbose=True)
    def test_graph_break_unsupported_fake(self):
        counter = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(counter)
        def f(x):
            return torch.ops.test_sample.foo(x + 1) + 1

        f(torch.randn(3))

        self.assertEqual(counter.op_count, 2)
        self.assertEqual(counter.frame_count, 2)

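    # NB: `del obj.a` inside a compiled function is tracked as a side effect
    # and replayed on the caller's object, so hasattr checks both inside and
    # after the compiled region must observe the deletion.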
    def test_delattr(self):
        class MyObj:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, obj):
            del obj.a
            obj.c = x + 1
            del obj.c
            tmp = MyObj(x + 2, x + 3)
            del tmp.b
            if hasattr(obj, "a"):
                return x + 1
            return tmp

        x = torch.zeros([])
        obj1 = MyObj(x, x)
        obj2 = fn(x, obj1)
        self.assertFalse(hasattr(obj1, "a"))
        self.assertFalse(hasattr(obj1, "c"))
        self.assertFalse(hasattr(obj2, "b"))
        self.assertEqual(obj1.b.item(), 0)
        self.assertEqual(obj2.a.item(), 2)

    def test_delattr_raises(self):
        class MyObj:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @torch.compile(backend="eager")
        def fn(x, obj):
            del obj.a
            x = x + 1
            obj.a  # raises AttributeError; the attribute was deleted above
            return x

        x = torch.zeros([])
        obj1 = MyObj(x, x)
        self.assertRaises(AttributeError, lambda: fn(x, obj1))

    def test_delsubscr(self):
        @torch.compile(backend="eager")
        def fn(x):
            del x["a"]
            y = x["b"] + 1
            return y

        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
        result = fn(x)
        # the deletion must be visible on the caller's dict
        self.assertNotIn("a", x)
        self.assertEqual(result.item(), 2)

    def test_delsubscr_raises(self):
        @torch.compile(backend="eager")
        def fn(x):
            del x["a"]
            y = x["a"] + 1  # should raise KeyError
            return y

        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
        self.assertRaises(KeyError, lambda: fn(x))

    def test_attached_attribute_in_dir(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.linear(x))

        mod = torch.compile(MyModule(), backend="eager")
        mod.is_compiled = True
        self.assertTrue("is_compiled" in dir(mod))

    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    def test_dynamic_shapes_implicit_guard(self):
        def f(x):
            y = x * x.size(x.shape[0])
            torch.sum(y, [y.shape[0]])
            return y

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        opt_fn(torch.randn(3, 1, 1, 1, 1))
        self.assertEqual(cnt.frame_count, 1)

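    # NB: `maybe` is a helper defined earlier in this file, following the
    # DALLE2-pytorch idiom; presumably something like this hypothetical
    # sketch, which applies fn only when its argument is present:
    #
    #     def maybe(fn):
    #         @wraps(fn)
    #         def inner(x, *args, **kwargs):
    #             if x is None:
    #                 return x
    #             return fn(x, *args, **kwargs)
    #         return inner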
    def test_dalle2_maybe(self):
        def normalize(x):
            return x.cos()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, normalize_img):
            lowres_cond_img = x.sin()
            lowres_cond_img = maybe(normalize_img)(lowres_cond_img)
            return lowres_cond_img

        self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos())

    def test_functools_wraps(self):
        def cool_name(x):
            return x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            y = x.cos()

            @functools.wraps(cool_name)
            def uncool_name():
                return cool_name(y)

            return uncool_name

        result = fn(torch.ones([]))
        self.assertEqual(result.__name__, "cool_name")
        self.assertEqual(result(), torch.ones([]).cos().sin())

    def test_dynamic_shapes_float_guard(self):
        def f(x):
            return torch.nn.functional.dropout(x, x.shape[0] / 6)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        opt_fn(torch.randn(3))
        self.assertEqual(cnt.frame_count, 1)

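    # NB: capture_scalar_outputs lets Dynamo trace y.item() into the graph as
    # a captured scalar instead of graph-breaking, so export succeeds and the
    # exported graph tracks the runtime value of y (1 vs 2 below).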
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_tensor_item(self):
        def f(x, y):
            val = y.item()
            return x.sum() + val

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.zeros(6, 4),
            torch.tensor(1),
        )
        self.assertEqual(
            f(torch.zeros(6, 4), torch.tensor(1)),
            gm(torch.zeros(6, 4), torch.tensor(1)),
        )
        self.assertEqual(
            f(torch.zeros(6, 4), torch.tensor(2)),
            gm(torch.zeros(6, 4), torch.tensor(2)),
        )

    def test_dataclass_init_with_default_factory_with_inputs(self):
        @dataclasses.dataclass
        class DClass:
            sharding_contexts: Any = dataclasses.field(default_factory=list)
            a: int = 1

        def fn(x, inp_list):
            d = DClass(inp_list)
            d.sharding_contexts.append(x.sin() + d.a)
            return d

        x = torch.randn(4)
        inp_list1 = [1, 2, 3]
        inp_list2 = [2, 3, 4]
        inp_list3 = [1, 2]
        ref1 = fn(x, inp_list1)
        ref2 = fn(x, inp_list2)
        ref3 = fn(x, inp_list3)

        opt_fn = torch.compile(fn, fullgraph=True)

        opt_ret1 = opt_fn(x, inp_list1)
        opt_ret2 = opt_fn(x, inp_list2)
        opt_ret3 = opt_fn(x, inp_list3)
        self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts)
        self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts)
        self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts)

    def test_list_index(self):
        for i, list_type in enumerate(
            (
                list,
                tuple,
                torch.Size,
                collections.deque,
                namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
            )
        ):
            torch._dynamo.reset()
            for index in ([], [2], [0, 3]):

                def f(t):
                    if i == 4:  # namedtuple
                        xs = list_type(1, 2, 3, 4)
                    else:
                        xs = list_type([1, 2, 3, 4])
                    res = xs.index(3, *index)
                    return t + res

                res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
                    torch.zeros(1)
                )

                self.assertEqual(res, torch.tensor([2.0]))

    def test_list_index_not_found(self):
        def f(t):
            xs = ["bar", "foo", "baz", "buzz"]
            res = xs.index("non-existent")
            return t + res

        # Raising ValueError when the item is not found is unsupported
        with self.assertRaises(
            torch._dynamo.exc.Unsupported,
        ):
            torch._dynamo.optimize(backend="eager", nopython=True)(f)(torch.zeros(1))

    def test_list_index_tensor_unsupported(self):
        for index in ([], [2], [0, 3]):

            def f(t):
                xs = [torch.tensor([i]) for i in range(4)]
                res = xs.index(torch.tensor([2]), *index)
                return t + res

            with self.assertRaisesRegex(
                torch._dynamo.exc.UserError, "Dynamic control flow is not supported"
            ):
                torch._dynamo.optimize(backend="eager", nopython=True)(f)(
                    torch.zeros(1)
                )

    def test_hf_xsoftmax_inference(self):
        def fn(input, mask):
            return XSoftmax.apply(input + 1, mask, 1) + 2

        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)

        inputs = [
            torch.randn(4, 10),
            torch.randn(4, 10) < 0,
        ]
        expected = fn(*inputs)
        actual = fn_opt(*inputs)
        self.assertTrue(same(actual, expected))

    @mock.patch("torch._dynamo.config.guard_nn_modules", True)
    def test_hf_xsoftmax_training(self):
        from torch._dynamo.utils import counters

        counters.clear()

        def fn(input, mask):
            return XSoftmax.apply(input, mask, 1)

        cnt = torch._dynamo.testing.CompileCounter()
        fn_opt = torch.compile(fn, backend=cnt, fullgraph=False)

        torch.manual_seed(1234)
        inputs1 = [
            torch.randn(4, 10, requires_grad=True),
            torch.randn(4, 10) < 0,
        ]
        torch.manual_seed(1234)
        inputs2 = [
            torch.randn(4, 10, requires_grad=True),
            torch.randn(4, 10) < 0,
        ]

        expected = fn(*inputs1)
        actual = fn_opt(*inputs2)
        self.assertTrue(same(actual, expected))
        self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(cnt.frame_count, 1)
        cnt.clear()
        counters.clear()

        expected.sum().backward()
        actual.sum().backward()
        self.assertTrue(same(inputs1[0].grad, inputs2[0].grad))

        # currently we don't capture the backward frame
        self.assertEqual(cnt.frame_count, 0)
        self.assertEqual(cnt.op_count, 0)
        self.assertEqual(dict(counters["frames"]), {})
        self.assertEqual(dict(counters["graph_break"]), {})

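    # NB: A graph break inside an autograd.Function.forward must not corrupt
    # autograd: Dynamo falls back around the break, and the custom backward
    # still fires correctly (gx == x.cos()).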
3669*da0073e9SAndroid Build Coastguard Worker    def test_autograd_function_graph_break(self):
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.save_for_backward(x)
                return x.sin()

            @staticmethod
            def backward(ctx, gx):
                (x,) = ctx.saved_tensors
                return gx * x.cos()

        x = torch.randn([], requires_grad=True)

        @torch.compile(backend="eager")
        def fn(x):
            return MySin.apply(x)

        y = fn(x)
        self.assertEqual(y, x.sin())

        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())

    def test_jit_trace_errors(self):
        @torch.compile(backend="eager", dynamic=True)
        def f(x):
            return x + 1

        with self.assertRaises(RuntimeError):
            torch.jit.trace(f, torch.randn(3))

        with torch._dynamo.config.patch(error_on_nested_jit_trace=False):
            torch.jit.trace(f, torch.randn(3))

    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_tensor_split(self):
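        # With assume_static_by_default disabled, x.shape[0] stays symbolic,
        # so the exported aten graph should keep the split size as an
        # expression (roughly s0 // 2) rather than baking in 3. That is why a
        # graph exported from a (6, 4) input is expected to also handle (8, 4).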
        def f(x):
            return torch.split(x, x.shape[0] // 2, dim=0)[0]

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.zeros(6, 4),
        )

        self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))

    def test_optim_state_references_cleared(self):
        model = torch.nn.Linear(2048, 2048, bias=False)
        x = torch.ones(2048)
        state_ref = 0

        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

        def opt_step():
            optimizer.step()

        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)

        def compiled_model_step(x):
            optimizer.zero_grad()
            y = model(x)
            torch.sum(y).backward()
            compiled_opt_step()

        compiled_model_step(x)

        # Picked "square_avg" arbitrarily to check that
        # optimizer state tensors are deallocated
        state_ref = weakref.ref(
            optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
        )
        optimizer = None

        self.assertIsNone(state_ref())

    def test_grad_references_cleared(self):
        model = torch.nn.Linear(2048, 2048, bias=False)
        x = torch.ones(2048)
        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

        def opt_step():
            optimizer.step()

        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)

        def compiled_model_step(x):
            optimizer.zero_grad(True)
            y = model(x)
            torch.sum(y).backward()
            compiled_opt_step()

        compiled_model_step(x)
        param_grad_ref = weakref.ref(next(iter(model.parameters())).grad)
        optimizer.zero_grad(True)
        self.assertIsNone(param_grad_ref())

    def test_batch_encoding_clone_inputs(self):
        class BatchEncoding(dict):
            """
            Copied from test_tokenization
            """

            def __init__(
                self,
                data,
            ):
                super().__init__(data)

            def __getattr__(self, item: str):
                try:
                    return self.data[item]
                except KeyError as e:
                    raise AttributeError from e

        encoding = BatchEncoding({"key": torch.rand((1, 4))})
        cloned_encoding = torch._dynamo.utils.clone_inputs(encoding)
        self.assertTrue(type(cloned_encoding) is not dict)

    def test_iadd_graph_break(self):
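        # On a tuple, `a += (x,)` is a concat-and-rebind (equivalent to
        # `a = a + (x,)`) rather than a true in-place add, e.g.:
        #   a = ()
        #   a += (1,)  # rebinds a to the new tuple (1,)
        # Despite the test name, nopython=True below asserts this now traces
        # without a graph break.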
        def fn(x):
            a = ()
            x = torch.sin(x)
            a += (x,)
            return a

        x = torch.randn(4)
        ref = fn(x)

        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_odict_get_item_index_name(self):
        d = {float: torch.float32, np.float16: torch.float16}

        @torch.compile(backend="eager")
        def f(x, y1, y2):
            return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])

        f(torch.zeros(4), float, np.float16)

    def test_dedup_global(self):
        @torch.compile()
        def f():
            return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR

        self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR)

    def test_randint_out_dynamic(self):
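        # torch.randint with an out= tensor: the test only checks that calling
        # the compiled function with two different out sizes (10, then 12)
        # does not error; typically the second size triggers a recompile,
        # possibly with a dynamic size.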
        def randint_fn(high, size, out):
            return torch.randint(high, size, out=out)

        opt_model = torch.compile(randint_fn)

        out1 = torch.empty(10, dtype=torch.int32)
        opt_model(17, (10,), out1)

        out2 = torch.empty(12, dtype=torch.int32)
        opt_model(17, (12,), out2)

    @requires_cuda
    def test_guard_default_device(self):
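        # Dynamo guards on the current default device, so flipping it from
        # "cuda" to "cpu" between calls is expected to invalidate the first
        # compiled graph and recompile (frame_count goes from 1 to 2).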
        try:
            torch.set_default_device("cuda")

            counter = torch._dynamo.testing.CompileCounter()

            @torch._dynamo.optimize(counter)
            def f():
                x = torch.randn(3)
                return x * 2

            self.assertEqual(f().device.type, "cuda")
            self.assertEqual(counter.frame_count, 1)

            torch.set_default_device("cpu")

            self.assertEqual(f().device.type, "cpu")
            self.assertEqual(counter.frame_count, 2)

        finally:
            torch.set_default_device(None)

    def test_list_self_reference(self):
        # Issue - https://github.com/pytorch/pytorch/issues/100150
        root = []
        root[:] = [root, root, None, None]

        @torch._dynamo.optimize("eager")
        def test_bug():
            return root

        test_bug()

    def test_hf_bigbird_unsqueeze(self):
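        # Each call to torch_bmm_nd hits an explicit graph break, so Dynamo
        # compiles fn in pieces; the frame_count == 3 assertion below pins
        # down how many compiled frames result (roughly: before the first
        # break, between the two breaks, and after the second one).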
        def torch_bmm_nd(inp_1, inp_2, ndim=None):
            torch._dynamo.graph_break()
            return torch.bmm(inp_1, inp_2)

        def fn(inp1, inp2, inp3, inp4, c):
            a = torch_bmm_nd(inp1, inp2, 4)
            a.unsqueeze_(2)
            a = a * 2

            b = torch_bmm_nd(inp3, inp4, 4)
            b.unsqueeze_(2)
            l = a + b

            out = torch.cat([a, b, c], dim=2)
            return out, l

        inp1 = torch.rand(1, 64, 448)
        inp2 = torch.rand(1, 448, 64)
        inp3 = torch.rand(1, 64, 448)
        inp4 = torch.rand(1, 448, 64)
        c = torch.rand(1, 64, 1, 64)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(inp1, inp2, inp3, inp4, c)
        self.assertEqual(cnt.frame_count, 3)

    def test_torch_variable_type(self):
        # from torchvision
        def check_type(obj, types_or_checks):
            for type_or_check in types_or_checks:
                if (
                    isinstance(obj, type_or_check)
                    if isinstance(type_or_check, type)
                    else type_or_check(obj)
                ):
                    return True
            return False

        opt_check_type = torch._dynamo.optimize("eager")(check_type)
        ref = check_type(torch.randn(4), [torch.Tensor])
        res = opt_check_type(torch.randn(4), [torch.Tensor])
        self.assertEqual(ref, res)

    # Test for https://github.com/pytorch/pytorch/issues/103132
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_inference_mode_dynamic_shapes(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, param):
                z = torch.matmul(param, param)
                return z

        model = Repro()
        # Need a 3d tensor to actually cause the error:
        # we go down a path of the C++ matmul decomp that calls sizes().
        inp = torch.randn(4, 4, 4, requires_grad=True)
        model = torch.compile(model, backend="aot_eager", dynamic=True)
        with torch.inference_mode():
            model(inp)

    def test_kwargs_out_list_variable(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, param):
                z = torch.frexp(**param)
                return z

        model = Repro()
        params = {"input": torch.tensor([[0.0, 1, 2, 4]])}
        params["out"] = [
            torch.empty(0, dtype=torch.float32),  # mantissa
            torch.empty(0, dtype=torch.int32),  # exponent
        ]

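        # torch.frexp decomposes input as mantissa * 2**exponent with the
        # mantissa in [0.5, 1) (and 0 mapping to (0, 0)); the zero-sized out=
        # tensors are resized by the op. For example 4.0 -> 0.5 * 2**3, which
        # matches the reference values asserted below.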
        model = torch.compile(model, backend="eager")
        mantissa, exponent = model(params)
        ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]])
        ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
        self.assertEqual(ref_mantissa, mantissa)
        self.assertEqual(ref_exponent, exponent)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_with_sizes_aot_autograd(self):
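        # split_sizes.tolist() produces data-dependent (unbacked) ints; with
        # capture_scalar_outputs=True the fullgraph compile is expected to
        # carry them through aot_eager instead of graph-breaking on .tolist().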
        def fn(result, split_sizes):
            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
            return rs

        example_inputs = (
            torch.randn(32, requires_grad=True),
            torch.tensor((7, 16, 9)),
        )
        actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
        expected = fn(*example_inputs)
        self.assertEqual(actual, expected)

    def test_unspecialized_nn_module_with_torch_variable_attribute(self):
        """
        In this case self.fn = something that should be a TorchVariable.
        When it's not a TorchVariable, dynamo tries to trace through and fails.
        This makes sure that self.fn is handled as a TorchVariable.
        """

        class UserModule(torch.nn.Module):
            torchdynamo_force_dynamic = True  # force this to be an UnspecializedNNModule

            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, **inp):
                return self.fn(**inp)

        inputs = {
            "input": torch.randn([2, 9]).uniform_(0, 1),
            "target": torch.randn([2, 9]).uniform_(0, 1),
            "reduction": "mean",
        }

        mod = UserModule(torch.nn.functional.binary_cross_entropy)
        ref = mod(**inputs)
        res = torch._dynamo.optimize("eager", nopython=True)(mod)(**inputs)
        self.assertEqual(ref, res)

    def test_call_finally_python_3_8(self):
        # Issue - https://github.com/pytorch/pytorch/issues/97811
        def make_fn(g):
            def fn():
                while True:
                    try:
                        print(g)
                        break
                    except Exception as _:
                        break

            return torch.compile(fn, backend="eager")

        make_fn(None)()

    def test_call_finally_python_3_8_2(self):
        def f(x):
            while x:
                try:
                    pass
                except Exception as _:
                    continue

        torch.compile(f, backend="eager")(0)

    def test_call_finally_opcode_python_3_8(self):
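        # Python semantics: a `return` in a `finally` block replaces whatever
        # the `try` block was returning, so fn() returns torch.ones(4); the
        # compiled version below should preserve that.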
        def fn():
            try:
                return torch.zeros(4)
            finally:
                return torch.ones(4)  # noqa: SIM107, B012

        result = torch.compile(fn, backend="aot_eager")()
        self.assertEqual(result, torch.ones(4))

    def test_string_format(self):
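        # s.format(i=4) on a constant string should be folded at trace time,
        # so fullgraph=True is expected to succeed and take the sin branch.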
        s = "temp{i}"

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            if s.format(i=4) == "temp4":
                return torch.sin(x)
            return torch.cos(x)

        x = torch.randn(4)
        self.assertEqual(fn(x), torch.sin(x))

    # Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards'
    # due to bad empty list handling
    def test_empty_list_contains_with_jump(self):
        def fn(x, l):
            if x in l:
                return x.cos()
            return x.sin()

        counter = CompileCounter()
        compiled_fn = torch._dynamo.optimize(counter)(fn)(torch.randn([2, 2]), [])
        self.assertEqual(counter.frame_count, 1)

    def test_graph_break_on_jit_isinstance(self):
        @torch.compile(backend="eager")
        def fn(x):
            if torch.jit.isinstance(x, List[str]):
                return x * 2
            return x

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.rand(4)
        self.assertTrue(same(fn(x), opt_fn(x)))

    def test_add_sub_alpha_out(self):
        inp = torch.randn(2, 3, 4)
        other = 1
        alpha = 2
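        # torch.add / torch.sub with alpha scale the second operand:
        # add(inp, other, alpha=2) == inp + 2 * other (and inp - 2 * other for
        # sub); the out= results from eager and compiled runs should match.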
        for op in [torch.add, torch.sub]:
            out = torch.zeros(2, 3, 4)
            compile_out = torch.zeros(2, 3, 4)
            op(inp, other, alpha=alpha, out=out)
            compiled_fn = torch.compile(op, dynamic=True)
            compiled_fn(inp, other, alpha=alpha, out=compile_out)
            self.assertTrue(same(out, compile_out))

    def test_negative_shape_guard(self):
        def fn(x):
            if x.size() != (5, 1, 2, 3):
                return x.cos()
            return x.sin()

        counter = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=counter, dynamic=True)

        x = torch.ones(5, 1, 3, 4)
        x2 = torch.ones(5, 1, 2, 3)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x2), opt_fn(x2))
        self.assertEqual(counter.frame_count, 2)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_deferred_runtime_asserts(self):
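        # torch._check_is_size(y) installs a deferred runtime assert that y is
        # a valid (non-negative) size; compiling with a positive sample should
        # succeed, while the negative input is expected to trip the assert as
        # a RuntimeError.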
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            torch._check_is_size(y)
            if y >= 0:
                return x * 2
            else:
                return x * 3

        f(torch.tensor([3]))
        self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2])))

    def test_addr_alpha_beta_out(self):
        inp = torch.randn(2, 3)
        vec1 = torch.randn(2)
        vec2 = torch.randn(3)
        alpha = 2
        beta = 5

        out = torch.zeros(2, 3)
        compile_out = torch.zeros(2, 3)

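        # torch.addr computes beta * inp + alpha * outer(vec1, vec2); the
        # compiled out= variant should produce the same (2, 3) result as eager.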
        torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
        compiled_fn = torch.compile(torch.addr, dynamic=True)
        compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
        self.assertTrue(same(out, compile_out))

    def test_setattr_requires_grad_graph_breaks(self):
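        # Assigning x.requires_grad inside the traced function is an input
        # mutation Dynamo cannot keep in one graph, so a graph break is
        # expected: the CompileCounter backend should report two frames.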
        def fn(x):
            z = x + 4
            x.requires_grad = True
            y = x * z
            return y

        for backend in ["count", "eager", "aot_eager"]:
            if backend == "count":
                backend = CompileCounter()
            opt_fn = torch.compile(fn, backend=backend)

            eager = torch.zeros(5)
            compiled = eager.clone()

            out_eager = fn(eager)
            out_opt = opt_fn(compiled)

            self.assertEqual(out_eager, out_opt)

            out_eager.sum().backward()
            out_opt.sum().backward()

            self.assertEqual(eager, compiled)
            if isinstance(backend, CompileCounter):
                self.assertEqual(backend.frame_count, 2)  # graph breaks

    def test_dynamic_shapes_double_not_equal(self):
        # https://github.com/pytorch/pytorch/issues/113393
        def fn(x):
            if x.size() != (5, 1, 2, 3):
                return x.cos()
            return x.sin()

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.ones(5, 1, 2, 3)
        x2 = torch.ones(5, 1, 3, 4)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x2), opt_fn(x2))

    def test_inductor_no_recursionerror_on_for_loops(self):
        def forward(x):
            for _ in range(1000):
                x = 1.0 * x
            return x

        self.assertTrue(
            same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0]))
        )

    def test_user_defined_object_callable(self):
        # https://github.com/pytorch/pytorch/issues/114019
        class MyCallable:
            def __call__(self, x):
                return x + 1

        def fn(x):
            # Create in graph - will not have source
            return MyCallable()(x)

        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))

    @torch._dynamo.config.patch(log_compilation_metrics=True)
    def test_many_views_with_mutation(self):
        # When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap
        # began adding shape guards - a quadratic amount relative to the number of inputs.
        # Test this configuration, and test that a reasonable number of guards are added.
        # Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards.
        def fn(x):
            x[0].relu_()
            return torch.cat(x).sum()

        AMT = 32
        src = torch.rand(16 * (AMT + 1))

        x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)]

        torch._dynamo.reset()
        torch._dynamo.utils.clear_compilation_metrics()

        res = torch.compile(fn, backend="aot_eager")(x)

        all_metrics = torch._dynamo.utils.get_compilation_metrics()

        total_guards = sum(metric.guard_count for metric in all_metrics)
        self.assertLess(total_guards, AMT * 8)

        total_shape_env_guards = sum(
            metric.shape_env_guard_count for metric in all_metrics
        )
        self.assertLess(total_shape_env_guards, AMT * 8)

    # https://github.com/pytorch/pytorch/issues/118799
    def test_subclass_graph_output_repro(self):
        @torch._dynamo.allow_in_graph
        def to_subclass(x):
            return TwoTensor(x.clone(), x.clone())

        def f(x):
            tmp_subclass = to_subclass(x)
            return tmp_subclass.view(-1)

        x = torch.ones(2)
        out_ref = f(x)
        out_test = torch.compile(f, backend="aot_eager")(x)
        self.assertEqual(out_ref, out_test)

    def test_numpy_tobytes_no_error(self):
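        # ndarray.tobytes() returns raw bytes that Dynamo cannot represent in
        # a graph, so tracing should split around it; the two `x += 1` halves
        # land in separate frames, matching the frame_count == 2 check below.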
        def fn(x):
            x += 1
            z = x.tobytes()
            x += 1
            return z

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_arg, arg = np.array([1, 2]), np.array([1, 2])
        self.assertEqual(opt_fn(opt_arg), fn(arg))
        self.assertEqual(cnt.frame_count, 2)

    def test_numpy_not_ndarray_recompiles(self):
        import torch

        def fn(x=None):
            if x is None:
                x = np.ones(3)
            elif isinstance(x, int):
                x = np.ones(6)
            elif isinstance(x, str):
                x = np.ones(9)
            return x**2

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)

        x = np.zeros((2, 2))

        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(opt_fn(), fn())
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(opt_fn(10), fn(10))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(opt_fn("10"), fn("10"))
        self.assertEqual(cnt.frame_count, 4)

    @parametrize(
        "backend",
        ["eager", "aot_eager", "inductor"],
    )
    @parametrize(
        "func_name",
        ["func1", "func2", "func3"],
    )
    def test_tensor_set_data(self, backend, func_name):
        # https://github.com/pytorch/pytorch/issues/113030
        def func1(x, y):
            x.data = y
            x.add_(1)
            return x

        def func2(x, y):
            x.data = y
            y.data = torch.zeros([0])
            return x

        def func3(x, y):
            z = x
            x.data = y
            y.data = torch.zeros([0])
            return torch.tensor(x is z)

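        # Note on func3: `x.data = y` swaps the tensor's underlying data in
        # place and does not rebind the Python name, so `x is z` stays True
        # and func3 should return tensor(True) in both eager and compiled runs.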
        funcs = {"func1": func1, "func2": func2, "func3": func3}
        func = funcs[func_name]

        if backend != "eager" and func is func1:
            # add_ not working w/ aot_autograd?
            return

        torch._dynamo.reset()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
        requires_grad = func is not func1
        for _ in range(5):
            # Inputs
            eager_a = torch.ones([6], requires_grad=requires_grad)
            compiled_a = torch.ones([6], requires_grad=requires_grad)

            eager_b = torch.ones([6], requires_grad=requires_grad)
            compiled_b = torch.ones([6], requires_grad=requires_grad)

            # Eager
            out_eager = func(eager_a, eager_b)
            # Compiled
            out_compiled = compiled_fn(compiled_a, compiled_b)
            self.assertEqual(eager_a, compiled_a)
            self.assertEqual(eager_b, compiled_b)
            self.assertTrue(torch.equal(out_eager, out_compiled))

            # func1 would hit "a leaf Variable that requires grad is being used in an in-place operation"
            if requires_grad:
                bwd_inp_eager = torch.randn([6])
                bwd_inp_compiled = torch.clone(bwd_inp_eager)
                eager_a.backward(bwd_inp_eager)
                compiled_a.backward(bwd_inp_compiled)
                self.assertEqual(eager_a.grad, compiled_a.grad)

        # Prove guarding works - we run the compiled_fn 5 times
        # frame_count should stay at 1.
        self.assertEqual(cnt.frame_count, 1)

    @unittest.skipIf(
        TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "flash attention not supported",
    )
    def test_flash_attn_backward_mixed_strides(self):
        # In this repro, "grad_out" and "value" are transposed tensors, but
        # "query" and "key" are contiguous; the meta kernel should produce
        # outputs with the same strides as the real CUDA kernel.
        def gen_inputs(device):
            return (
                torch.randn(
                    2, 513, 16, 64, dtype=torch.float16, device=device
                ).transpose(1, 2),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(
                    2, 513, 16, 64, dtype=torch.float16, device=device
                ).transpose(1, 2),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(2, 16, 513, device=device),
                None,
                None,
                513,
                513,
                0.0,
                False,
                torch.tensor(1, dtype=torch.int64),
                torch.tensor(1, dtype=torch.int64),
            )

        inps_cuda = gen_inputs("cuda")
        inps_meta = gen_inputs("meta")
        (
            out1_ref,
            out2_ref,
            out3_ref,
        ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
            *inps_cuda, scale=0.125
        )
        from torch._meta_registrations import meta__scaled_dot_product_flash_backward

        out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
            *inps_meta, scale=0.125
        )

        self.assertEqual(out1_ref.shape, out1_test.shape)
        self.assertEqual(out1_ref.stride(), out1_test.stride())
        self.assertEqual(out2_ref.shape, out2_test.shape)
        self.assertEqual(out2_ref.stride(), out2_test.stride())
        self.assertEqual(out3_ref.shape, out3_test.shape)
        self.assertEqual(out3_ref.stride(), out3_test.stride())

    def test_user_ctor_ctx_manager(self):
        class UserCtxManager:
            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, y):
            ucm = UserCtxManager()
            return x * x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        x = torch.rand([2, 2])
        opt_fn(x, x)
        self.assertExpectedInline(cnt.frame_count, """1""")

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_arange_in_bounds(self):
        # see https://github.com/pytorch/pytorch/issues/113002
        class PaddingNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, lengths):
                max_seq_len = lengths.max().item()
                row_vector = torch.arange(0, max_seq_len, 1)
                matrix = torch.unsqueeze(lengths, dim=-1)
                mask = row_vector < matrix
                mask = mask.type(torch.float32)
                mask_3d_btd = mask[:, :, None]
                return mask_3d_btd

        model = PaddingNet()
        lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32)
4423*da0073e9SAndroid Build Coastguard Worker
4424*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
4425*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(model)
4426*da0073e9SAndroid Build Coastguard Worker        opt_fn(lengths)
4427*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
4428*da0073e9SAndroid Build Coastguard Worker
    def test_overlapping_inputs_with_dynamic_shapes_error(self):
        @torch.compile(backend="aot_eager")
        def fn(a, b, c, d, e, f):
            a.mul_(2)
            b.mul_(2)
            c.mul_(2)
            d.mul_(2)
            e.mul_(2)
            f.mul_(2)

        base = torch.ones(2, 20)
        a = base[:, 0:2]
        b = base[:, 2:4]
        c = base[:, 4:6]
        d = base[:, 6:8]
        e = base[:, 8:10]
        f = base[:, 10:12]
        f2 = base[:, 10:14]
        out = fn(a, b, c, d, e, f)
        with self.assertRaisesRegex(
            AssertionError, "is being compiled with dynamic shapes"
        ):
            out2 = fn(a, b, c, d, e, f2)

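    # Same shape of test, but the context manager has a custom __init__ that
    # mutates its argument; the side effect must be applied during inlining.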
    def test_user_ctor_ctx_manager_custom_init(self):
        class UserCtxManager:
            def __init__(self, x):
                x[0] = 10

            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, y):
            ucm = UserCtxManager(y)
            return x * y[0]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        x = torch.rand([2, 2])
        self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
        self.assertExpectedInline(cnt.frame_count, """1""")

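    # Here the frame graph-breaks, so the constructor's side effect (bumping
    # the counter) must still be observed on every call, compiled or not.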
    def test_user_ctor_ctx_manager_custom_init_graph_break(self):
        counter = [0]

        class UserCtxManager:
            def __init__(self, k):
                k[0] += 1

            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, counter):
            x = x * x
            ucm = UserCtxManager(counter)
            return x * x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.rand([2, 2])
        self.assertEqual(opt_fn(x, counter), fn(x, counter))
        self.assertEqual(counter[0], 2)
        for i in range(0, 10):
            opt_fn(x, counter)
        self.assertEqual(counter[0], 12)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """2""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")

    @unittest.expectedFailure
    def test_many_overlapping_inputs_does_not_explode_guards(self):
        from torch._dynamo.backends.common import aot_autograd

        # Before, this was (9702, 0)
        num_shape_guards = None
        num_aot_guards = None
        num_compiles = 0

        def guard_count_backend(gm, *args):
            nonlocal num_shape_guards
            nonlocal num_aot_guards
            nonlocal num_compiles
            num_shape_guards = len(
                torch._guards.TracingContext.try_get().fake_mode.shape_env.guards
            )
            num_aot_guards = len(
                torch._guards.TracingContext.try_get().guards_context.aotautograd_guards
            )
            num_compiles += 1
            return gm

        aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend)

        @torch.compile(backend=aot_guard_counter, dynamic=True)
        def f(*args):
            for a in args:
                a.add_(1)

        x = torch.ones(1000, requires_grad=True)
        args = x.split(10)

        with torch.no_grad():
            f(*args)
        # In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2),
        # because every pair of aliased inputs needs a guard.
        self.assertTrue(num_aot_guards < 5000)
        # But there are no dynamic shape guards.
        self.assertEqual(num_shape_guards, 0)
        # don't recompile
        with torch.no_grad():
            f(*args)
        self.assertEqual(num_compiles, 1)

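    # Unpacking a 3-tuple into two names raises ValueError in eager Python;
    # the compiled function is expected to surface the same exception type.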
    def test_invalid_seq_unpack(self):
        def myfn(arg):
            (a, b) = arg

        def fn():
            return myfn((1, 2, 3))

        try:
            torch.compile(fn)()
        except ValueError:
            pass
        else:
            self.fail("expected exception")

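    # Integration smoke test: a megablocks mixture-of-experts layer should give
    # the same output eagerly and under torch.compile (skipped when megablocks
    # is not installed; requires CUDA).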
    def test_megablocks_moe(self):
        try:
            from megablocks.layers import moe
            from megablocks.layers.arguments import Arguments
        except ImportError as e:
            raise unittest.SkipTest("requires megablocks") from e
        bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1)
        args = Arguments(
            hidden_size=hs,
            ffn_hidden_size=hs * 2,
            moe_num_experts=num_experts,
            moe_capacity_factor=1,
            moe_top_k=top_k,
        )
        moe_mlp = moe.MoE(args)
        moe_mlp.cuda(torch.cuda.current_device()).half()
        x = torch.randn(sl, bs, hs).cuda().half()
        out1, _ = moe_mlp(x)
        out2, _ = torch.compile(moe_mlp, backend="eager")(x)
        self.assertEqual(out1, out2)

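    # T is rebound between calls (IncByOne -> IncByTwo, helper classes assumed
    # to be defined at module scope in this file); the guard on the captured
    # class should force a recompile that reconstructs the new class.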
    def test_udf_classes_reconstruction(self):
        def fn(x):
            o = T(5)
            return o.x + x

        opt_fn = torch.compile(fn, backend="eager")
        T = IncByOne

        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

        # This should recompile
        T = IncByTwo
        self.assertEqual(fn(x), opt_fn(x))

    def test_contains_range_constprop(self):
        def fn(x):
            # dynamo should const prop `3 in range(0, 10)` to True
            if 3 in range(0, 10):
                return x + 1
            else:
                return x + 2

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.zeros(4)
        self.assertEqual(fn(x), opt_fn(x))

    # https://github.com/pytorch/pytorch/issues/104505
    def test_as_strided_on_base_with_mutation_works(self):
        def foo(a):
            f = a.as_strided((2,), (1,), 0)
            f.add_(1.0)
            return a

        a = torch.randn(2, 4)
        a_ref = a.clone()
        out_ref = foo(a_ref)
        f_compiled = torch.compile(foo, backend="aot_eager")
        out = f_compiled(a)
        self.assertEqual(out_ref, out)
        self.assertEqual(a_ref, a)

    # https://github.com/pytorch/pytorch/issues/104505
    def test_as_strided_on_existing_view_banned(self):
        def foo(a):
            e = a.diagonal()
            f = e.as_strided((2,), (1,), 0)
            f.add_(1.0)
            return a

        a = torch.randn(2, 4)
        a_ref = a.clone()
        out_ref = foo(a_ref)
        f_compiled = torch.compile(foo, backend="aot_eager")
        with self.assertRaisesRegex(
            RuntimeError,
            "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
        ):
            out = f_compiled(a)

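    # Asserts that reduce to backed symbolic-shape checks should be dropped
    # from the graph; only the data-dependent assert on x[0].sum() is expected
    # to survive as a torch._assert_async call.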
    def test_dont_aggressively_write_assert(self):
        record_graph = torch._dynamo.testing.EagerAndRecordGraphs()

        @torch.compile(dynamic=True, backend=record_graph)
        def f(x):
            assert x.shape[0] > 3
            assert x[0].sum() > 0
            assert 1 % (x.shape[0] // 2) != 0
            assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
            return x.cos()

        f(torch.ones(6, 4))
        graph = record_graph.graphs[0]
        # It is a bit annoying that we generate useless statements for shape
        # guards, but DCE should be able to remove them, since there is no
        # backed assert on them. This is ok because dynamo only skips the
        # assert statement itself, not the instructions before it.
        self.assertExpectedInline(
            str(graph.code).strip(),
            """\
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
    l_x_ = L_x_
    getitem_2 = l_x_[0]
    sum_1 = getitem_2.sum();  getitem_2 = None
    gt_1 = sum_1 > 0;  sum_1 = None
    _assert_async = torch._assert_async(gt_1, 'assertion error');  gt_1 = _assert_async = None
    cos = l_x_.cos();  l_x_ = None
    return (cos,)""",
        )
        for node in graph.graph.nodes:
            if "example_value" in node.meta and isinstance(
                node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
            ):
                shape_env = node.meta["example_value"].fake_mode.shape_env
                lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
                self.assertTrue(lower_ranges == [4, 2])

        @torch.compile(dynamic=True, backend=record_graph)
        def f_fail(x):
            assert x.shape[0] < 3

        # We graph-break here, so the failure should be eager
        with self.assertRaisesRegex(AssertionError, ""):
            f_fail(torch.ones(6, 4))

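    # Repro of detectron2's Instances container: __getattr__/__setattr__
    # indirection plus a staticmethod cat() must trace with fullgraph=True.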
    def test_detectron2_instances_cat(self):
        class Instances:
            def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
                self._image_size = image_size
                self._fields: Dict[str, Any] = {}
                for k, v in kwargs.items():
                    self.set(k, v)

            @property
            def image_size(self) -> Tuple[int, int]:
                return self._image_size

            def __setattr__(self, name: str, val: Any) -> None:
                if name.startswith("_"):
                    super().__setattr__(name, val)
                else:
                    self.set(name, val)

            def __getattr__(self, name: str) -> Any:
                if name == "_fields" or name not in self._fields:
                    raise AttributeError(
                        f"Cannot find field '{name}' in the given Instances!"
                    )
                return self._fields[name]

            def __len__(self) -> int:
                for v in self._fields.values():
                    # use __len__ because len() has to be int and is not friendly to tracing
                    return v.__len__()
                raise NotImplementedError("Empty Instances does not support __len__!")

            def set(self, name: str, value: Any) -> None:
                with warnings.catch_warnings(record=True):
                    data_len = len(value)
                if len(self._fields):
                    assert (
                        len(self) == data_len
                    ), f"Adding a field of length {data_len} to an Instances of length {len(self)}"
                self._fields[name] = value

            def get(self, name: str) -> Any:
                return self._fields[name]

            @staticmethod
            def cat(instance_lists: List["Instances"]) -> "Instances":
                assert all(isinstance(i, Instances) for i in instance_lists)
                assert len(instance_lists) > 0
                if len(instance_lists) == 1:
                    return instance_lists[0]

                image_size = instance_lists[0].image_size
                if not isinstance(
                    image_size, torch.Tensor
                ):  # could be a tensor in tracing
                    for i in instance_lists[1:]:
                        assert i.image_size == image_size
                ret = Instances(image_size)
                for k in instance_lists[0]._fields.keys():
                    values = [i.get(k) for i in instance_lists]
                    v0 = values[0]
                    if isinstance(v0, torch.Tensor):
                        values = torch.cat(values, dim=0)
                    elif isinstance(v0, list):
                        values = list(itertools.chain(*values))
                    elif hasattr(type(v0), "cat"):
                        values = type(v0).cat(values)
                    else:
                        raise ValueError(
                            f"Unsupported type {type(v0)} for concatenation"
                        )
                    ret.set(k, values)
                return ret

        instances = [
            Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
            for _ in range(3)
        ]

        @torch.compile(backend="eager", fullgraph=True)
        def fn(instances):
            return instances[0].cat(instances)

        actual = fn(instances)
        expected = instances[0].cat(instances)
        self.assertEqual(type(actual), type(expected))
        self.assertEqual(actual.__dict__, expected.__dict__)

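    # A weakref can be passed into a fullgraph-compiled function, called, and
    # identity-compared against its referent.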
    def test_weakref(self):
        def fn(x_weak, weight, y):
            if x_weak is not None and x_weak() is not weight:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        y = torch.randn(4)
        x_weak = weakref.ref(weight)

        ref = fn(x_weak, weight, y)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x_weak, weight, y)
        self.assertEqual(ref, res)

    def test_weakref_reconstruct(self):
        def fn(x_weak, weight, y):
            y = torch.sin(y)
            referent = x_weak()
            torch._dynamo.graph_break()
            if referent is not weight:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        y = torch.randn(4)
        x_weak = weakref.ref(weight)

        ref = fn(x_weak, weight, y)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnt)
        res = opt_fn(x_weak, weight, y)
        self.assertEqual(ref, res)
        self.assertEqual(cnt.frame_count, 2)

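    # Once the referent is gc'd, calling the weakref should return None in the
    # compiled function exactly as in eager mode.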
    def test_weakref_del(self):
        def fn(x_weak, y):
            x = x_weak()
            if x is not None:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        x_weak = weakref.ref(weight)
        y = torch.randn(4)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        ref = fn(x_weak, y)
        res = opt_fn(x_weak, y)
        self.assertEqual(ref, res)

        del weight
        gc.collect()
        ref = fn(x_weak, y)
        res = opt_fn(x_weak, y)
        self.assertEqual(ref, res)

    #     @torch._functorch.config.patch(
    #         recompute_views=True,
    #     )
    #     def test_storage_resize_forward_full_graph(self):
    #         class TestModule(torch.nn.Module):
    #             def __init__(self) -> None:
    #                 super().__init__()
    #                 self.param = torch.nn.Parameter(torch.randn(4, 4))

    #             def forward(self, x):
    #                 self.param.untyped_storage().resize_(
    #                     self.param.numel() * self.param.itemsize
    #                 )
    #                 with torch.no_grad():
    #                     torch._foreach_copy_([self.param], [x])
    #                 out = torch.matmul(self.param, self.param)
    #                 self.param.untyped_storage().resize_(0)
    #                 return out

    #         def post_accumulate_grad_hook(param):
    #             param.untyped_storage().resize_(0)

    #         # Beginning of backward, resize and put data into the param
    #         def pre_backward_hook(module, grad) -> None:
    #             module.param.untyped_storage().resize_(
    #                 self.param.numel() * self.param.itemsize
    #             )
    #             with torch.no_grad():
    #                 # simulates loading data into param from allgather
    #                 module.param.fill_(2)

    #         def post_forward_hook(module, args, output):
    #             output.register_hook(functools.partial(pre_backward_hook, module))

    #         x = torch.randn(4, 4)

    #         mod_ref = TestModule()
    #         mod_test = deepcopy(mod_ref)

    #         # Start the param off with zero storage size to mimic fsdp
    #         mod_ref.param.untyped_storage().resize_(0)
    #         mod_test.param.untyped_storage().resize_(0)

    #         # Resize storage at beginning of backward
    #         # Free storage at end of backward
    #         mod_ref.register_forward_hook(post_forward_hook, prepend=False)
    #         mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
    #         mod_test.register_forward_hook(post_forward_hook, prepend=False)
    #         mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)

    #         mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)

    #         out_ref = mod_ref(x)
    #         out_test = mod_test(x)
    #         self.assertExpectedInline(
    #             str(fw_graph[0].code.strip()),
    #             """\
    # def forward(self, primals_1, primals_2):
    #     _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_1 = primals_2 = None
    #     getitem = _foreach_copy[0];  _foreach_copy = None
    #     mm = torch.ops.aten.mm.default(getitem, getitem)
    #     return [mm, getitem]""",
    #         )
    #         self.assertEqual(out_ref, out_test)

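    # Zero-argument super() inside a @staticmethod fails at runtime with
    # "super(): no arguments"; the compiled error message should match eager.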
    def test_super_in_staticmethod(self):
        class A:
            @staticmethod
            def foo():
                return super().__init__()

        def fn(obj):
            return obj.foo()

        obj = A()

        try:
            fn(obj)
        except Exception as e:
            orig_str = str(e)
        self.assertIn("no arguments", orig_str)

        try:
            torch.compile(backend="eager")(fn)(obj)
        except Exception as e:
            compiled_str = str(e)
        self.assertEqual(orig_str, compiled_str)

    def test_super_staticmethod(self):
        class Parent:
            @staticmethod
            def greet():
                return 5

        class Child(Parent):
            @staticmethod
            def greet(x):
                return x * super(Child, Child).greet()

        child = Child()

        def fn(x):
            return child.greet(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.ones(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

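    # Diamond-style MRO: B's super().__init__() chains through Nothing to A,
    # so both self.a and self.b must be set by the time run() executes.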
    def test_super_diamond(self):
        class A:
            def __init__(self):
                super().__init__()
                self.a = 5

        class Nothing:
            pass

        class B(Nothing, A):
            def __init__(self):
                super().__init__()
                self.b = 10

            def run(self, x):
                return self.a * self.b * x

        def fn(x):
            b = B()
            return b.run(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

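    # An in-place mul_ captured in an inference graph must still bump the
    # tensor's version counter so autograd can detect the mutation later.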
    def test_vc_bumped_in_inference_graph(self):
        @torch.compile
        def f(x):
            return x.mul_(2)

        x = torch.randn(4)
        vc_before = x._version
        f(x)
        vc_after = x._version
        self.assertTrue(vc_after > vc_before)

    def test_nn_module_callable(self):
        class M(nn.Module):
            def forward(self, x):
                return x.sin()

        def f(m):
            return callable(m)

        res = torch.compile(f, fullgraph=True)(M())
        self.assertTrue(res)

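    # Modeled on stk's sparse SDD kernel (per the test name): a custom
    # autograd.Function whose backward branches on whether the saved tensors
    # are transposed. Flipping trigger_graph_break adds a data-dependent branch
    # on the incoming grad, which is expected to be Unsupported under fullgraph.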
    def test_stk_sdd_is_transposed(self):
        trigger_graph_break = False

        def _is_transposed(x):
            return (
                not x.is_contiguous()
                and x.stride()[0] == 1
                and x.stride()[1] == x.size()[0]
            )

        class SDD(torch.autograd.Function):
            @staticmethod
            def forward(ctx, lhs, rhs):
                ctx.save_for_backward(lhs, rhs)
                out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device)
                return out

            @staticmethod
            def backward(ctx, dy):
                saved_tensors = ctx.saved_tensors
                lhs, rhs = saved_tensors[:2]
                trans_a = _is_transposed(lhs)
                trans_b = _is_transposed(rhs)
                dlhs = None
                if ctx.needs_input_grad[0]:
                    dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0)
                drhs = None
                if ctx.needs_input_grad[1]:
                    drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0)
                if trigger_graph_break:
                    if _is_transposed(dy):
                        return dlhs + 1, drhs + 1, None, None
                return dlhs, drhs, None, None

        x1 = torch.randn((8, 8), requires_grad=True)
        y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
        x2 = torch.randn((8, 8), requires_grad=True)
        y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)

        SDD.apply(x1, y1).sum().backward()

        @torch.compile(backend="eager", fullgraph=True)
        def fn():
            return SDD.apply(x2, y2)

        fn().sum().backward()

        self.assertEqual(x1.grad, x2.grad)
        self.assertEqual(y1.grad, y2.grad)

        trigger_graph_break = True
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            fn().sum().backward()

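    # Module __init__ reads a property chain (blocking -> data) before
    # construction finishes; tracing the constructor must resolve both
    # properties against the partially initialized instance.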
    def test_partially_initialized_module_property(self):
        class Matrix(torch.nn.Module):
            def __init__(self, data):
                super().__init__()
                self._data = data
                self.foo = 10 * self.blocking

            @property
            def data(self):
                return self._data

            @property
            def blocking(self):
                return self.data.shape[1]

        @torch.compile(backend="eager", fullgraph=True)
        def fn():
            return Matrix(torch.randn(10, 20))

        v = fn()
        self.assertEqual(v.foo, 200)
        self.assertEqual(v.data.shape, (10, 20))
        self.assertEqual(type(v), Matrix)

    def test_classmethod_with_slots(self):
        class Mock:
            __slots__ = ("_a",)

            def __init__(self):
                self._a = 2

            @classmethod
            def _m(cls):
                return 3

            def run(self, x):
                return torch.sin(x) * self._a * self._m()

        def fn(x):
            mock = Mock()
            return mock.run(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

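    # nn.utils.parametrize routes parameter access through a module call; the
    # parametrization's torch.sin should appear in the captured graph and
    # vanish again after remove_parametrizations.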
    def test_nn_parametrize(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(10, 10))

            def forward(self, x):
                return self.param @ x

        class Parametrization(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        m = Module()
        torch.nn.utils.parametrize.register_parametrization(
            m, "param", Parametrization()
        )

        sin_found = False

        def backend(gm, _):
            nonlocal sin_found
            for node in gm.graph.nodes:
                if node.target is torch.sin:
                    sin_found = True
            return gm

        opt_m = torch.compile(m, backend=backend, fullgraph=True)
        inp = torch.randn(10, 10)
        self.assertEqual(m(inp), opt_m(inp))
        self.assertTrue(sin_found)

        torch.nn.utils.parametrize.remove_parametrizations(m, "param")
        sin_found = False
        self.assertEqual(m(inp), opt_m(inp))
        self.assertFalse(sin_found)

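    # The property closes over a tensor from the enclosing scope; the
    # closed-over value must be read correctly while tracing forward().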
    def test_nn_module_property_closure(self):
        x = torch.randn(10, 10)

        class Mod(torch.nn.Module):
            @property
            def y(self):
                return torch.ones(10, 10) + x

            def forward(self, x):
                return x @ self.y

        mod = Mod()

        def fn(x):
            return mod(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        inp = torch.randn(10, 10)
        self.assertEqual(fn(inp), opt_fn(inp))

5154*da0073e9SAndroid Build Coastguard Worker    def test_global_fn_mutation(self):
5155*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
5156*da0073e9SAndroid Build Coastguard Worker            return global_fn(x) + y
5157*da0073e9SAndroid Build Coastguard Worker
5158*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
5159*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(1)
5160*da0073e9SAndroid Build Coastguard Worker
5161*da0073e9SAndroid Build Coastguard Worker        opt = torch.compile(foo, fullgraph=True, backend="eager")
5162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt(x, y), foo(x, y))
5163*da0073e9SAndroid Build Coastguard Worker
5164*da0073e9SAndroid Build Coastguard Worker        # Change global_fn
5165*da0073e9SAndroid Build Coastguard Worker        global global_fn
5166*da0073e9SAndroid Build Coastguard Worker
5167*da0073e9SAndroid Build Coastguard Worker        def new_fn(x):
5168*da0073e9SAndroid Build Coastguard Worker            return torch.cos(x)
5169*da0073e9SAndroid Build Coastguard Worker
5170*da0073e9SAndroid Build Coastguard Worker        global_fn = new_fn
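        # The guard on global_fn should now fail, triggering a recompile that
        # traces new_fn instead of the original function.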
        self.assertEqual(opt(x, y), foo(x, y))

    # ref https://github.com/pytorch/pytorch/issues/123974
    def test_list_reverse(self):
        def ladder(x):
            trail = x.size(-1)
            assert trail > 2
            weights = []
            for s in [trail, trail - 1, trail - 2]:
                weights.append(torch.ones(s, s - 1))

            for w in weights:
                x = x @ w

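            # reverse() mutates the list in place; Dynamo must model the
            # mutation rather than replaying the original iteration order.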
            weights.reverse()

            for w in weights:
                x = x @ w.t()

            return x

        data = torch.randn(3, 4)
        opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
        self.assertEqual(opt_ladder(data), ladder(data))

    def test_trace_functional_tensor_with(self):
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import (
            FunctionalTensor,
            FunctionalTensorMode,
        )

        def f(a, tmp):
            a_view = a.view(-1)
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        fake_mode = FakeTensorMode()
        with FunctionalTensorMode():
            inp = torch.ones(3, 3, requires_grad=True)
            inp = fake_mode.from_tensor(inp, static_shapes=True)
            inp = FunctionalTensor.to_functional(inp)

            tmp = torch.ones(3, 3, requires_grad=True)
            tmp = fake_mode.from_tensor(tmp, static_shapes=True)
            tmp = FunctionalTensor.to_functional(tmp)

            opt_f = torch.compile(f, backend="eager")
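            # a.set_(tmp) followed by mutating the stale view is expected to
            # fail under functionalization with the frozen-storage error.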
            with self.assertRaisesRegex(
                RuntimeError, "cannot mutate tensors with frozen storage"
            ):
                opt_f(inp, tmp)

    def test_const_dict_keyerror(self):
        d = {}

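        # d is empty, so d[0] raises KeyError; Dynamo has to trace the
        # exception handler and fall back to y = 1.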
        def fn(x):
            try:
                y = d[0]
            except KeyError:
                y = 1
            return x + y

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(inp), opt_fn(inp))

    def test_dict_tag_guard(self):
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Check that recompilation happens
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

    def test_nonconst_issubclass(self):
        def fn(x):
            if issubclass(x.__class__, np.ndarray):
                return 1
            return 0

        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(np.ones([3, 3]))

    def test_issue126128(self):
        def fn():
            x = torch.randn(1, 10)
            y = torch.randn(10, 1)
            return torch.mm(x, y).sum()

        def fn2():
            x = torch.randn(10, 100)
            y = torch.randn(100, 10)
            return torch.mm(x, y).sum()

        with fresh_inductor_cache():
            torch.compile(fn)()

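        # compiling again after the fresh_inductor_cache context has exited
        # should still work (see issue 126128)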
        torch.compile(fn2)()

    def test_jit_script_defaults(self):
        @torch.jit.script
        def fast_cos(x, c: float = 2.0):
            return torch.cos(x) * c

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fast_cos = fast_cos

            def forward(self, x):
                return self.fast_cos(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(mod(x), opt_mod(x))

    def test_enum(self):
        class ExplicitEnum(str, Enum):
            @classmethod
            def _missing_(cls, value):
                raise ValueError(
                    f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
                )

        class PaddingStrategy(ExplicitEnum):
            LONGEST = "longest"
            MAX_LENGTH = "max_length"
            DO_NOT_PAD = "do_not_pad"

        def fn(x):
            a = PaddingStrategy("longest")
            if a == PaddingStrategy.LONGEST:
                return torch.sin(x)
            return torch.cos(x)

        x = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_hasattr_builtin(self):
        class MyClass:
            foo: int = 1

        def func(x, m):
            if getattr(type(m), "foo", 0):
                return x + MyClass.foo
            return x

        opt_func = torch.compile(func, backend="eager", fullgraph=True)
        m = MyClass()
        x = torch.zeros(())
        self.assertEqual(func(x, m), opt_func(x, m))
        self.assertEqual(func(x, 0), opt_func(x, 0))

    def test_grad(self):
        def fn(x, y):
            x._grad = y
            return x.grad.data

        x = torch.randn(4, requires_grad=True)
        y = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager")
        self.assertEqual(fn(x, y), opt_fn(x, y))

    def test_nn_module_stack_bc(self):
        from torch._dynamo.mutation_guard import GenerationTracker

        def compiler(gm, *args):
            module_stacks = [
                node.meta.get("nn_module_stack", None) for node in gm.graph.nodes
            ]
            module_stacks, _ = pytree.tree_flatten(module_stacks)
            module_stacks = [x for x in module_stacks if isinstance(x, str)]
            for stack in module_stacks:
                self.assertTrue("_module" not in stack)
            return gm.forward

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod1 = SubMod()
                self.submod2 = SubMod()

            def forward(self, x):
                return self.submod1(x) + self.submod2(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend=compiler)
        opt_mod(torch.randn(2, 2))

        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
            mod = Mod()
            opt_mod = torch.compile(mod, backend=compiler)
            opt_mod(torch.randn(2, 2))

        # an example similar to the PiPPy use case
        mod = Mod()
        GenerationTracker.tag(mod.submod1)
        GenerationTracker.mark_class_dynamic(type(mod.submod1))
        mod = Mod()
        opt_mod = torch.compile(mod, backend=compiler)
        opt_mod(torch.randn(2, 2))

    def test_is_make_fx_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            torch.nn.modules.activation._is_make_fx_tracing()
            return torch.sin(x)

        fn(torch.rand(4))

    def test_negative_floor_div_solve(self):
        class CompiledClass(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
                self.t = 5

            def forward(self):
                self.num = self.nums[self.t // 12]
                self.t += 1
                return self.num

        m = CompiledClass()
        m = torch.compile(m, backend="eager")

        # the first call works
        m()
        # the second call used to fail before the negative floor division fix
        m()

    # https://github.com/pytorch/pytorch/issues/121621
    def test_tensor_random(self):
        def random_op(tensor, params):
            res = tensor.random_(**params)
            return res

        random_op = torch.compile(random_op)
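        # `from` is a Python keyword, so random_'s from/to arguments can only
        # be passed via ** unpacking.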
        params = {"from": -10, "to": 10}
        tensor = torch.randn([2, 3])
        res = random_op(tensor, params)

    # https://github.com/pytorch/pytorch/issues/131019
    def test_tensor_uniform(self):
        def uniform_op(tensor, params):
            res = tensor.uniform_(**params)
            return res

        uniform_op = torch.compile(uniform_op)
        params = {"from": -10, "to": 10}
        tensor = torch.randn([2, 3])
        res = uniform_op(tensor, params)

    def test_data_attr_mutation_after_saved_for_bw(self):
        def f(x):
            out = x.sin()
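            # x was saved for backward by sin(); mutating x.data afterwards
            # must not corrupt the saved value or the resulting gradients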
            x.data.mul_(2)
            return out

        x = torch.randn(4, requires_grad=True)
        x_test = x.clone().detach().requires_grad_(True)

        out = f(x)
        out_test = torch.compile(f, backend="aot_eager")(x_test)
        self.assertEqual(out, out_test)

        out.sum().backward()
        out_test.sum().backward()
        self.assertEqual(x.grad, x_test.grad)

    # https://github.com/pytorch/pytorch/issues/128072
    def test_map_with_multiple_args(self):
        def f(a, b):
            return a[0] * b[0] + a[1] * b[1]

        def gen_inps(len_x, len_y):
            x = [torch.randn(5) for _ in range(len_x)]
            y = [torch.randn(5) for _ in range(len_y)]
            return x, y

        def g(x, y):
            return map(f, x, y)

        opt_g = torch.compile(g, fullgraph=True, backend="eager")

        inps = gen_inps(3, 3)
        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))

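        # map() stops at the shortest input, so mismatched lengths (3, 5)
        # exercise the truncation behavior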
        inps = gen_inps(3, 5)
        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))

    def test_staticmethod_allow_in_graph(self):
        class MyClass:
            i = 3

            @staticmethod
            def foo_inner(x):
                return torch.mul(x, MyClass.i)

            # if Dynamo inlined this with fullgraph=True, the graph_break()
            # below would error; verify that Dynamo does not inline it
            @staticmethod
            @torch._dynamo.allow_in_graph
            def foo1(x):
                torch._dynamo.graph_break()
                return MyClass.foo_inner(x)

        @torch.compile(backend="eager", fullgraph=True)
        def f_bad(x):
            return MyClass.foo1(x)

        f_bad(torch.ones(2, 2))

    def test_guard_with_tuple_mutation(self):
        class Foo:
            def __init__(self) -> None:
                self.x = 10

        foo = Foo()
        d = {
            "a": 2,
            "b": (foo,),
        }

        def fn(x, d):
            return x * d["a"] * d["b"][0].x

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(inp, d), opt_fn(inp, d))
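        # Mutating an attribute of the object stored inside the tuple should
        # be caught by the guards and reflected in the recomputed result.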
        d["b"][0].x = 12
        self.assertEqual(fn(inp, d), opt_fn(inp, d))

    def test_compile_complex_conj(self):
        def f(x):
            return torch.mul(x, 2j)

        x_ref = torch.randn(4, 2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_(True)

        out_ref = f(torch.view_as_complex(x_ref))
        out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
        self.assertEqual(out_ref, out_test)

        torch.view_as_real(out_ref).sum().backward()
        torch.view_as_real(out_test).sum().backward()
        self.assertEqual(x_ref.grad, x_test.grad)

    # https://github.com/pytorch/pytorch/issues/132200
    def test_partitioner_cse_respects_mutation_boundaries(self):
        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
        if not set_available:
            return

        @torch.compile(backend="aot_eager_decomp_partition")
        def f(x, l):
            # z0 and z1 can be CSEd
            z0 = x.sin()
            z1 = x.sin()
            y = x + 1
            torch.ops.fsdp.set_.default(x, y)
            # z2 and z3 can be CSEd with each other,
            # but *not* with z0/z1 (they cross a mutation boundary)
            z2 = x.sin()
            z3 = x.sin()
            return z0, z1, z2, z3, l**2

        x = torch.randn(3)
        x_clone = x.clone()
        l = torch.randn(3, requires_grad=True)
        z0, z1, z2, z3, _ = f(x, l)

        # the partitioner runs CSE. We expect that of the 4 sin() ops above:
        # - the first 2 are CSE'd
        # - the last 2 are CSE'd
        # - the set_() op in the middle is a mutation barrier, preventing CSE
        self.assertEqual(z0, (x_clone).sin())
        self.assertEqual(z1, (x_clone).sin())
        self.assertEqual(z2, (x_clone + 1).sin())
        self.assertEqual(z3, (x_clone + 1).sin())

    # https://github.com/pytorch/pytorch/issues/132197
    def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
        if not set_available:
            return

        def f(x, l):
            z = x.sin()
            y = x + 1
            # graph input has its storage mutated
            torch.ops.fsdp.set_.default(x, y)
            z2 = x.sin()
            return z2, l**2

        x = torch.randn(3)
        x_test = x.clone()
        l = torch.randn(3, requires_grad=True)
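        # Run eagerly first, then compiled on a clone; the storage swap from
        # set_() must be applied to both x and x_test.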
        result, _ = f(x, l)
        result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
            x_test, l
        )

        self.assertEqual(result, result_test)
        self.assertEqual(x, x_test)

    def test_changing_stride(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x, y):
            return x * y

        for i in range(1, 4):
            x = torch.randn(4, i)

            # create a view for i > 1
            if i == 1:
                x1 = x
            else:
                x1 = x[:, 0:1]

            y = torch.randn(4, 1)
            print(x1.shape, y.shape)
            fn(x1, y)

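        # At most one recompile is expected: after the first stride change,
        # the stride should become dynamic instead of recompiling per shape.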
        self.assertTrue(cnt.frame_count <= 2)

    @torch._dynamo.config.patch(guard_nn_modules=False)
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_inlining_cornercase(self):
        """
        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable,
        the tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.

        But there is a corner case. Suppose you have an NNModuleVariable with a submodule that is an
        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
        guard.source().is_nn_module()). In retrospect, this is a mistake, but there are dependencies in export and
        also cudagraphs that make the corner case hard to fix right away. The long-term solution is
        inline_inbuilt_nn_modules anyway, so we may have to live with this corner case in the short term.

        We are starting to annotate the source of each nn module more precisely: an NNModuleVariable attribute is
        marked as NNModuleSource, and an UnspecializedNNModuleVariable attribute is marked as
        UnspecializedNNModuleSource. But this changes the behavior for the corner case and fails some tests that
        have unfortunately relied on it.

        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.

        In this test, we purposely turn the flag off, testing that the tagging is disabled.
        """

        class SubMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)
                self.a = torch.randn(1, 1)
                self.counter = 0
                self.multipliers = [2.2, 3.3]

            def forward(self, x):
                self.counter += 1
                return (
                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
                )

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return self.submod(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager")

        x = torch.randn(1, 1)
        ref = mod(x)
        res = opt_mod(x)

        mod.submod.multipliers = [3.3, 4.4]
        # Since guard_nn_modules is False, this will not recompile
        with torch._dynamo.config.patch(error_on_recompile=True):
            ref = mod(x)
            res = opt_mod(x)

    def test_optimized_module_training(self):
        mod = torch.nn.Linear(3, 3)
        mod.eval()

        opt_mod = torch.compile(mod, backend="eager")
        self.assertFalse(opt_mod.training)

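        # train()/eval() calls should propagate in both directions between
        # the OptimizedModule wrapper and the original module.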
        opt_mod.train()
        self.assertTrue(opt_mod.training)
        self.assertTrue(mod.training)

        mod.eval()
        self.assertFalse(opt_mod.training)

    @requires_cuda
    def test_memleak_when_graph_input_has_tensor_attr(self):
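        # check that a tensor attribute on a graph input does not keep memory
        # alive once the attribute and the input are deleted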
        @torch.compile(backend="eager")
        def f(x):
            x.add_(1)

        mem_before = torch.cuda.memory_allocated()

        x = torch.ones(2, device="cuda")
        x.foo = torch.zeros(2, device="cuda")
        f(x)
        del x.foo
        del x
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

        # check when non-tensor data structure attribute contains a tensor
        @torch.compile(backend="eager")
        def f(x):
            x.add_(1)

        mem_before = torch.cuda.memory_allocated()
        x = torch.ones(2, device="cuda")
        x.foo = [torch.zeros(2, device="cuda") for _ in range(5)]
        f(x)
        del x.foo
        del x
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

        # check with tensor refcycle
        @torch.compile(backend="eager")
        def g(x, y):
            return x + y

        mem_before = torch.cuda.memory_allocated()
        x = torch.ones(2, device="cuda")
        y = torch.zeros(2, device="cuda")
        x.foo = [y]
        y.foo = [x]
        g(x, y)
        del x.foo
        del y.foo
        del x
        del y
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

    def test_os_fspath(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            os.fspath(".")
            return torch.sin(x)

        fn(torch.randn(4))

    @requires_cuda
    # This test fails because flip, in combination with particular input
    # lengths, produces incorrect results.
    # This is under investigation in
    # https://github.com/pytorch/pytorch/issues/131805
    @unittest.skip("Skip this flip test for the moment. It is under investigation")
    def test_flip_bad_accuracy(self):
        import torch
        import torch._dynamo.config
        import torch._functorch.config
        import torch._inductor.config
        import torch._inductor.inductor_prims
        import torch.fx.experimental._config

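        # The aten graph below appears to be a minifier-generated repro that
        # computes a reversed cumulative sum; the final assertion checks it
        # against torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]).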
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg0_1):
                rev = torch.ops.prims.rev.default(arg0_1, [0])
                arg0_1 = None
                slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2)
                slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2)
                add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2)
                slice_1 = slice_2 = None
                slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2)
                slice_4 = torch.ops.aten.slice.Tensor(
                    add_1, 0, 1, 9223372036854775807, 2
                )
                add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4)
                slice_3 = slice_4 = None
                slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2)
                slice_6 = torch.ops.aten.slice.Tensor(
                    add_2, 0, 1, 9223372036854775807, 2
                )
                add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6)
                slice_5 = slice_6 = None
                slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1)
                add_2 = None
                unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1)
                slice_9 = None
                unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1)
                add_3 = None
                cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
                unsqueeze = unsqueeze_1 = None
                view = torch.ops.aten.view.default(cat, [2])
                cat = None
                slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1)
                slice_11 = torch.ops.aten.slice.Tensor(
                    add_1, 0, 2, 9223372036854775807, 2
                )
                add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11)
                slice_10 = slice_11 = None
                slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1)
                add_1 = None
                cat_1 = torch.ops.aten.cat.default([slice_12, add_5])
                slice_12 = add_5 = None
                unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1)
                cat_1 = None
                unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1)
                view = None
                cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1)
                unsqueeze_2 = unsqueeze_3 = None
                view_1 = torch.ops.aten.view.default(cat_2, [4])
                cat_2 = None
                slice_13 = torch.ops.aten.slice.Tensor(
                    rev, 0, 2, 9223372036854775807, 2
                )
                add_6 = torch.ops.aten.add.Tensor(view_1, slice_13)
                slice_13 = None
                slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1)
                rev = None
                cat_3 = torch.ops.aten.cat.default([slice_14, add_6])
                slice_14 = add_6 = None
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    view_1, [0, 1], 0.0
                )
                view_1 = None
                unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1)
                cat_3 = None
                unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1)
                constant_pad_nd = None
                cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1)
                unsqueeze_4 = unsqueeze_5 = None
                view_2 = torch.ops.aten.view.default(cat_4, [10])
                cat_4 = None
                slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9)
                view_2 = None
                rev_1 = torch.ops.prims.rev.default(slice_15, [0])
                slice_15 = None
                return (rev_1,)

        mod = Repro()
        x = torch.arange(9, device=torch.device("cuda"))

        @torch.compile
        def f(x):
            return mod(x)

        out = f(x)
        self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0])

    # https://github.com/pytorch/pytorch/issues/88813
    def test_return_value_duplication_tensor(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return val * 2, val * 2

        x = torch.randn(2, requires_grad=True)

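        # Although the two outputs are computed identically, they must not be
        # deduplicated into aliases of the same storage.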
5866*da0073e9SAndroid Build Coastguard Worker        expect = fn(x)
5867*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(
5868*da0073e9SAndroid Build Coastguard Worker            expect[0].untyped_storage().data_ptr(),
5869*da0073e9SAndroid Build Coastguard Worker            expect[1].untyped_storage().data_ptr(),
5870*da0073e9SAndroid Build Coastguard Worker        )
5871*da0073e9SAndroid Build Coastguard Worker
5872*da0073e9SAndroid Build Coastguard Worker        actual = torch.compile(fn, backend="aot_eager")(x)
5873*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(
5874*da0073e9SAndroid Build Coastguard Worker            actual[0].untyped_storage().data_ptr(),
5875*da0073e9SAndroid Build Coastguard Worker            actual[1].untyped_storage().data_ptr(),
5876*da0073e9SAndroid Build Coastguard Worker        )
5877*da0073e9SAndroid Build Coastguard Worker
5878*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/114344
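    # Outputs that differ only in grad mode (one computed under no_grad) must
    # keep their distinct requires_grad flags after compilation.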
    def test_return_value_duplication_mixed_grad(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            with torch.no_grad():
                out0 = val + 1
            out1 = val + 1
            return out0, out1

        x = torch.randn(2, requires_grad=True)

        with torch.enable_grad():
            expect = fn(x)
            actual = torch.compile(fn, backend="aot_eager")(x)

            self.assertEqual(expect[0].requires_grad, actual[0].requires_grad)
            self.assertEqual(expect[1].requires_grad, actual[1].requires_grad)

    # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
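    # Variant of the duplication test above for 0-dim outputs produced by
    # indexing; the scalar results must likewise not share storage.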
    def test_return_value_duplication_scalar(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            x, y = val * 2, val * 2
            return x[0], y[0]

        x = torch.randn(2, requires_grad=True)

        expect = fn(x)
        self.assertNotEqual(
            expect[0].untyped_storage().data_ptr(),
            expect[1].untyped_storage().data_ptr(),
        )

        actual = torch.compile(fn, backend="aot_eager")(x)
        self.assertNotEqual(
            actual[0].untyped_storage().data_ptr(),
            actual[1].untyped_storage().data_ptr(),
        )


instantiate_parametrized_tests(ReproTests)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()