# Owner(s): ["module: decompositions"]

import functools
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch._inductor.decomposition
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# TODO: this isn't going to work with non-aten namespaces
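# e.g. this maps torch.ops.aten.add.Tensor to "add" (the schema name without the
# "aten::" namespace prefix)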
def overload_to_aten_name(op):
    return op._schema.name.split("::")[1]


# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but all leaves of the pytree have to
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)

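# Runs f(*primals) eagerly and returns (result, vjp_fn), where vjp_fn(cotangents)
# backpropagates to the primals via _autograd_grad without building a differentiable
# graph for the backward pass (create_graph=False).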
def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

    return result, wrapped

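# Mapping from dtype to its default (rtol, atol), consumed by _getDefaultRtolAndAtol below.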
dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
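# e.g. _getDefaultRtolAndAtol(torch.float16, torch.float32) returns (0.001, 1e-5),
# the elementwise maximum of the two dtypes' entries above.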


def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol

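# Compares one decomposed output against the original output using a float64 reference:
# the decomposition must not be farther from the reference than the original kernel is,
# beyond the allowed atol slack.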
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation:  {op}"
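    # Extra atol slack for (dtype, op) pairs whose decompositions deviate a bit more
    # from the float64 reference than the original kernels do.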
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        # see https://github.com/pytorch/pytorch/pull/96264
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )

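# Checks that the decomposed output matches the original output directly, using the
# per-(dtype, op) (rtol, atol) overrides below where the default dtype tolerances are
# too tight.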
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU; on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
        # there's an off-by-one error. See
        # https://github.com/pytorch/pytorch/issues/81996
        # https://github.com/pytorch/pytorch/issues/82230
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO We should check that the integer outputs also agree
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals


# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    else:
        return x


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
    (None, None, "item"),
    # It's the only in-place op without an out-of-place equivalent in the Python API
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),
    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    # exp_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.exponential"),
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.cosine"),
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    # This decomp runs before autograd.
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
    # Decomposition registered as Autograd
    (None, None, "nn.functional.hardshrink"),
    (None, None, "nn.functional.softshrink"),
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
    (None, None, "diag"),
    # _softmax_backward_data's CPU kernel for bfloat16 always returns the grad_input as float32
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {
    # Decomposed backward formula is not as precise
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""

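# Returns True if any tensor argument is of a kind the decomposition cross-ref tests
# don't handle (sparse, sparse CSR, MKL-DNN, quantized, nested, functionalized, or a
# tensor-like subclass).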
def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly.  Maybe they should.
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
            return True
        else:
            return False

    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
    return any(test_unsupported(x) for x in flat_args)


core_backward_failures = {
    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"),  # slow: fails with --timeout=360 secs
    skip("deg2rad"),  # slow: fails with --timeout=360 secs
    skip("diag_embed"),  # slow: fails with --timeout=360 secs
    skip("frac"),  # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"),  # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"),  # slow: fails with --timeout=360 secs
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("rad2deg"),  # slow: fails with --timeout=360 secs
    skip("renorm"),  # slow: fails with --timeout=360 secs
    skip("rot90"),  # slow: fails with --timeout=360 secs
    skip("rsub"),  # slow: fails with --timeout=360 secs
    skip("sgn"),  # slow: fails with --timeout=360 secs
    skip("special.xlog1py"),  # slow: fails with --timeout=360 secs
    xfail("stack"),
    skip("tril"),  # slow: fails with --timeout=360 secs
    skip("triu"),  # slow: fails with --timeout=360 secs
    skip("unfold_copy"),  # slow: fails with --timeout=360 secs
    skip("xlogy"),  # slow: fails with --timeout=360 secs
    xfail("zero_"),
}
if not TEST_WITH_SLOW:
    core_backward_failures.update(
        {
            skip("addr"),  # slow: takes 46 sec on A100
            skip("baddbmm"),  # slow: takes 800+ sec on A100
            skip("clamp_min"),  # slow: takes 800 sec on A100
            skip("clamp_max"),  # slow: takes 800 sec on A100
            skip("logit"),  # slow: takes 44 sec on A100
            skip("nn.functional.hardswish"),  # slow: takes 60 sec on A100
            skip("std_mean"),  # slow: takes 170 sec on A100
            skip("split", variant_name="list_args"),  # slow: takes 118 sec on A100
            skip("transpose"),  # slow: takes 50 sec on A100
            skip("unbind"),  # slow: takes 70 sec on A100
            skip("unsafe_split"),  # slow: takes 49 sec on A100
        }
    )

comprehensive_failures = {
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ),  # off by one error
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
    def test_quick_core_backward(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            aten_name = op.decomp_aten_name or op.aten_name
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def test_uniform(self, device):
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    def test_broadcasting_index_copy(self, device):
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)

        xs_two = torch.ones([2, 10], device=device)
        xs_two[0] = x

        self.assertEqual(xs, xs_two)

    def test_cat_single_input(self, device):
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        inps = [inp for _ in range(10)]

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))

    def test_rrelu_with_noise(self, device):
        # rrelu_with_noise behavior depends on a) whether elements in the input
        # are <= 0, and b) whether we're in training mode. Cover all cases:
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        lower = 1.0
        upper = 4.0
        training = False

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

        # Now with training=True:
        training = True

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @tf32_off()
    # only tests RNNs since we have py dispatcher decomps for them
669*da0073e9SAndroid Build Coastguard Worker    @modules(
670*da0073e9SAndroid Build Coastguard Worker        filter(
671*da0073e9SAndroid Build Coastguard Worker            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
672*da0073e9SAndroid Build Coastguard Worker            module_db,
673*da0073e9SAndroid Build Coastguard Worker        )
674*da0073e9SAndroid Build Coastguard Worker    )
675*da0073e9SAndroid Build Coastguard Worker    def test_rnn_decomp_module(self, device, dtype, module_info, training):
676*da0073e9SAndroid Build Coastguard Worker        module_cls = module_info.module_cls
677*da0073e9SAndroid Build Coastguard Worker        module_inputs = module_info.module_inputs_func(
678*da0073e9SAndroid Build Coastguard Worker            module_info,
679*da0073e9SAndroid Build Coastguard Worker            device=device,
680*da0073e9SAndroid Build Coastguard Worker            dtype=dtype,
681*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
682*da0073e9SAndroid Build Coastguard Worker            training=training,
683*da0073e9SAndroid Build Coastguard Worker        )
684*da0073e9SAndroid Build Coastguard Worker        for module_input in module_inputs:
685*da0073e9SAndroid Build Coastguard Worker            if module_input.forward_input is None:
686*da0073e9SAndroid Build Coastguard Worker                continue
687*da0073e9SAndroid Build Coastguard Worker            args, kwargs = (
688*da0073e9SAndroid Build Coastguard Worker                module_input.constructor_input.args,
689*da0073e9SAndroid Build Coastguard Worker                module_input.constructor_input.kwargs,
690*da0073e9SAndroid Build Coastguard Worker            )
691*da0073e9SAndroid Build Coastguard Worker            m = module_cls(*args, **kwargs)
692*da0073e9SAndroid Build Coastguard Worker            m.to(device).to(dtype)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker            args, kwargs = (
695*da0073e9SAndroid Build Coastguard Worker                module_input.forward_input.args,
696*da0073e9SAndroid Build Coastguard Worker                module_input.forward_input.kwargs,
697*da0073e9SAndroid Build Coastguard Worker            )
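            # Run the forward under the cross-ref mode with the python dispatcher enabled
            # so the RNN decomps registered there are hit and each aten decomp they emit
            # is cross-checked against the eager kernel.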
698*da0073e9SAndroid Build Coastguard Worker            with self.DecompCrossRefMode(
699*da0073e9SAndroid Build Coastguard Worker                self, self.precision, self.rel_tol, dtype, run_all=True
700*da0073e9SAndroid Build Coastguard Worker            ), enable_python_dispatcher():
701*da0073e9SAndroid Build Coastguard Worker                decomp_out = m(*args, **kwargs)
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker            non_decomp_out = m(*args, **kwargs)
            # without this comparison against a plain run, an incorrect decomp registered
            # at the python dispatcher level could still pass, since the cross-ref mode
            # only checks aten decomps at the torch_dispatch level
706*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(decomp_out, non_decomp_out)
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_unflatten_weight_bias(self, device):
709*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/100970
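        # weight is passed with trailing singleton dims, (3, 1, 1, 1) instead of (3,);
        # the decomposition should still handle it and preserve the input's shape.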
710*da0073e9SAndroid Build Coastguard Worker        shape = (1, 3, 2, 2)
711*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(shape, device=device)
712*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn((3, 1, 1, 1), device=device)
713*da0073e9SAndroid Build Coastguard Worker        bias = torch.randn(3, device=device)
714*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn(3, device=device)
715*da0073e9SAndroid Build Coastguard Worker        var = torch.randn(3, device=device)
716*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions.native_batch_norm(
717*da0073e9SAndroid Build Coastguard Worker            input, weight, bias, mean, var, False, 1, 1e-05
718*da0073e9SAndroid Build Coastguard Worker        )
719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(shape, res[0].shape)
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker    def test_arange_graph(self, device):
722*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.proxy_tensor import make_fx
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        def func(x, start):
725*da0073e9SAndroid Build Coastguard Worker            le = x.shape[-1]
726*da0073e9SAndroid Build Coastguard Worker            if start is None:
727*da0073e9SAndroid Build Coastguard Worker                a = torch.arange(le, dtype=torch.float32, device=x.device)
728*da0073e9SAndroid Build Coastguard Worker            else:
729*da0073e9SAndroid Build Coastguard Worker                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
730*da0073e9SAndroid Build Coastguard Worker            return a
731*da0073e9SAndroid Build Coastguard Worker
732*da0073e9SAndroid Build Coastguard Worker        pattern = r", device = device\(.+\), requires_grad = False"
733*da0073e9SAndroid Build Coastguard Worker
734*da0073e9SAndroid Build Coastguard Worker        cfunc = make_fx(func, decomposition_table=decomposition_table)
735*da0073e9SAndroid Build Coastguard Worker        fx_g = cfunc(torch.rand(10, device=device), None)
736*da0073e9SAndroid Build Coastguard Worker        fx_g_code = fx_g.code.strip()
737*da0073e9SAndroid Build Coastguard Worker        # Remove device and requires_grad
738*da0073e9SAndroid Build Coastguard Worker        fx_g_code = re.sub(pattern, "", fx_g_code)
739*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
740*da0073e9SAndroid Build Coastguard Worker            fx_g_code,
741*da0073e9SAndroid Build Coastguard Worker            """\
742*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1, start_1):
743*da0073e9SAndroid Build Coastguard Worker    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
744*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
745*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.prims.add.default(mul, 0);  mul = None
746*da0073e9SAndroid Build Coastguard Worker    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
747*da0073e9SAndroid Build Coastguard Worker    return convert_element_type""",
748*da0073e9SAndroid Build Coastguard Worker        )
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker        fx_g = cfunc(torch.rand(10, device=device), 1)
751*da0073e9SAndroid Build Coastguard Worker        fx_g_code = fx_g.code.strip()
752*da0073e9SAndroid Build Coastguard Worker        # Remove device and requires_grad
753*da0073e9SAndroid Build Coastguard Worker        fx_g_code = re.sub(pattern, "", fx_g_code)
754*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
755*da0073e9SAndroid Build Coastguard Worker            fx_g_code,
756*da0073e9SAndroid Build Coastguard Worker            """\
757*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1, start_1):
758*da0073e9SAndroid Build Coastguard Worker    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
759*da0073e9SAndroid Build Coastguard Worker    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
760*da0073e9SAndroid Build Coastguard Worker    add = torch.ops.prims.add.default(mul, 1);  mul = None
761*da0073e9SAndroid Build Coastguard Worker    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
762*da0073e9SAndroid Build Coastguard Worker    return convert_element_type""",
763*da0073e9SAndroid Build Coastguard Worker        )
764*da0073e9SAndroid Build Coastguard Worker
765*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill(self, device):
766*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.proxy_tensor import make_fx
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker        if torch.device(device).type not in [
769*da0073e9SAndroid Build Coastguard Worker            "xpu",
770*da0073e9SAndroid Build Coastguard Worker            "cuda",
771*da0073e9SAndroid Build Coastguard Worker            torch._C._get_privateuse1_backend_name(),
772*da0073e9SAndroid Build Coastguard Worker        ]:
            self.skipTest("only runs on XPU, CUDA, and PrivateUse1 devices.")
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker        def func(scores, mask, value):
776*da0073e9SAndroid Build Coastguard Worker            return scores.masked_fill(mask, value)
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        scores_t = torch.tensor([1, 2, 3, 4], device=device)
779*da0073e9SAndroid Build Coastguard Worker        mask_t = torch.tensor([True, True, True, True], device=device)
780*da0073e9SAndroid Build Coastguard Worker        value_t = torch.tensor(0, dtype=scores_t.dtype)
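        # value is a 0-d CPU tensor while scores/mask live on `device`; the decomposition
        # should still lower to a single prims.where, as the expected graph below shows.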
781*da0073e9SAndroid Build Coastguard Worker        cfunc = make_fx(func, decomposition_table=decomposition_table)
782*da0073e9SAndroid Build Coastguard Worker        fx_g = cfunc(scores_t, mask_t, value_t)
783*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
784*da0073e9SAndroid Build Coastguard Worker            fx_g.code.strip(),
785*da0073e9SAndroid Build Coastguard Worker            """\
786*da0073e9SAndroid Build Coastguard Workerdef forward(self, scores_1, mask_1, value_1):
787*da0073e9SAndroid Build Coastguard Worker    where = torch.ops.prims.where.default(mask_1, value_1, scores_1);  mask_1 = value_1 = scores_1 = None
788*da0073e9SAndroid Build Coastguard Worker    return where""",
789*da0073e9SAndroid Build Coastguard Worker        )
790*da0073e9SAndroid Build Coastguard Worker
791*da0073e9SAndroid Build Coastguard Worker    class DecompCrossRefMode(TorchDispatchMode):
792*da0073e9SAndroid Build Coastguard Worker        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
793*da0073e9SAndroid Build Coastguard Worker            self.test_case = test_case
794*da0073e9SAndroid Build Coastguard Worker            self.saved_precision = saved_precision
795*da0073e9SAndroid Build Coastguard Worker            self.saved_rel_tol = saved_rel_tol
796*da0073e9SAndroid Build Coastguard Worker            self.test_dtype = dtype
797*da0073e9SAndroid Build Coastguard Worker            self.run_all = run_all
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker            # We check the correctness of each decomposition right after running it.
800*da0073e9SAndroid Build Coastguard Worker            # So, when we encounter a decomposition, we run the function normally, and
801*da0073e9SAndroid Build Coastguard Worker            # then run the decomposition, and ensure they're identical.
802*da0073e9SAndroid Build Coastguard Worker            self.called = set()
803*da0073e9SAndroid Build Coastguard Worker            self.decomposed = set()
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
806*da0073e9SAndroid Build Coastguard Worker            self.test_case.precision = self.saved_precision
807*da0073e9SAndroid Build Coastguard Worker            self.test_case.rel_tol = self.saved_rel_tol
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker            self.called.add(func)
810*da0073e9SAndroid Build Coastguard Worker            all_called[func] += 1
811*da0073e9SAndroid Build Coastguard Worker
812*da0073e9SAndroid Build Coastguard Worker            # Stuff we shouldn't bother testing
813*da0073e9SAndroid Build Coastguard Worker            # (TODO: remove detach from the decomp table?)
814*da0073e9SAndroid Build Coastguard Worker            # N.b. Testing in-place ops would need dedicated logic
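            # (aten convention: in-place overloads end in "_", e.g. aten.add_.Tensor)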
815*da0073e9SAndroid Build Coastguard Worker            in_place = func.name()[-1] == "_"
816*da0073e9SAndroid Build Coastguard Worker            ignored_ops = [
817*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.detach.default,
818*da0073e9SAndroid Build Coastguard Worker                # non-deterministic ops
819*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.empty.memory_format,
820*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.empty_like.default,
821*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.new_empty.default,
822*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.empty_strided.default,
823*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.new_empty_strided.default,
824*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.randn.default,
825*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten.native_dropout.default,
826*da0073e9SAndroid Build Coastguard Worker            ]
827*da0073e9SAndroid Build Coastguard Worker            if (
828*da0073e9SAndroid Build Coastguard Worker                func not in decomposition_table
829*da0073e9SAndroid Build Coastguard Worker                or func in ignored_ops
830*da0073e9SAndroid Build Coastguard Worker                or torch.Tag.nondeterministic_seeded in func.tags
831*da0073e9SAndroid Build Coastguard Worker                or any_unsupported(args, kwargs)
832*da0073e9SAndroid Build Coastguard Worker                or in_place
833*da0073e9SAndroid Build Coastguard Worker            ):
834*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker            self.decomposed.add(func)
837*da0073e9SAndroid Build Coastguard Worker            all_decomposed.add(func)
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker            # We take 2 main strategies for verifying correctness/numerical stability of decompositions
840*da0073e9SAndroid Build Coastguard Worker            # The first one is simply tolerance checking between decomp_out and pytorch_out
841*da0073e9SAndroid Build Coastguard Worker            # However, for fp16/bf16 and reductions, this becomes very
842*da0073e9SAndroid Build Coastguard Worker            # finicky, as there are not many guarantees we can make.
843*da0073e9SAndroid Build Coastguard Worker            # So, for fp16/bf16, we instead compare the difference of
844*da0073e9SAndroid Build Coastguard Worker            # {decomp_out, pytorch_out_64} and {pytorch_out,
845*da0073e9SAndroid Build Coastguard Worker            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition introduces noticeably more error than
            # eager pytorch does, we raise an error.
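            # Roughly, the relative check amounts to (a sketch, not literal code):
            #   ref = func(*args_fp64)                      # fp64 "ground truth"
            #   assert (decomp_out - ref).abs() <= k * (pytorch_out - ref).abs() + tol
            # with k and tol picked per op/dtype inside op_assert_ref.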
848*da0073e9SAndroid Build Coastguard Worker
            # We also decompose the decomposition recursively for
            # further coverage, as some paths may not be exercised directly by
            # OpInfos (sadly) but only via other ops
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker            decomposition = decomposition_table[func]
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
856*da0073e9SAndroid Build Coastguard Worker            if self.run_all:
857*da0073e9SAndroid Build Coastguard Worker                # Execute recursively via DFS, to find the root of a possible error first
858*da0073e9SAndroid Build Coastguard Worker                with self:
859*da0073e9SAndroid Build Coastguard Worker                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
860*da0073e9SAndroid Build Coastguard Worker            else:
861*da0073e9SAndroid Build Coastguard Worker                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker            # At this stage we should not be decomposing an in-place op
864*da0073e9SAndroid Build Coastguard Worker            # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops
865*da0073e9SAndroid Build Coastguard Worker            #  because decompositions are run after functionalisation and we would not like them to
866*da0073e9SAndroid Build Coastguard Worker            #  de-functionalise the graph, as that would break AoTAutograd
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place. If it does,
            # real_out will differ from decomp_out and we will catch the mismatch below
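            # E.g. a buggy decomposition that did args[0].zero_() would corrupt the inputs
            # before the eager call below, so the comparison further down would flag it.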
870*da0073e9SAndroid Build Coastguard Worker            real_out_unflat = func(*args, **kwargs)
871*da0073e9SAndroid Build Coastguard Worker            real_out = pytree.tree_leaves(real_out_unflat)
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker            assert len(real_out) == len(decomp_out)
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker            if do_relative_check:
876*da0073e9SAndroid Build Coastguard Worker                upcast = partial(upcast_tensor, dtype=torch.float64)
877*da0073e9SAndroid Build Coastguard Worker                real_out_double, _ = tree_flatten(
878*da0073e9SAndroid Build Coastguard Worker                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
879*da0073e9SAndroid Build Coastguard Worker                )
880*da0073e9SAndroid Build Coastguard Worker                for i, (orig, decomp, ref) in enumerate(
881*da0073e9SAndroid Build Coastguard Worker                    zip(real_out, decomp_out, real_out_double)
882*da0073e9SAndroid Build Coastguard Worker                ):
883*da0073e9SAndroid Build Coastguard Worker                    if not isinstance(orig, torch.Tensor):
884*da0073e9SAndroid Build Coastguard Worker                        assert type(orig) == type(decomp)
885*da0073e9SAndroid Build Coastguard Worker                        assert orig == decomp
886*da0073e9SAndroid Build Coastguard Worker                        continue
887*da0073e9SAndroid Build Coastguard Worker                    op_assert_ref(
888*da0073e9SAndroid Build Coastguard Worker                        self.test_case,
889*da0073e9SAndroid Build Coastguard Worker                        func,
890*da0073e9SAndroid Build Coastguard Worker                        self.test_dtype,
891*da0073e9SAndroid Build Coastguard Worker                        i,
892*da0073e9SAndroid Build Coastguard Worker                        orig,
893*da0073e9SAndroid Build Coastguard Worker                        decomp,
894*da0073e9SAndroid Build Coastguard Worker                        ref,
895*da0073e9SAndroid Build Coastguard Worker                        args,
896*da0073e9SAndroid Build Coastguard Worker                        kwargs,
897*da0073e9SAndroid Build Coastguard Worker                    )
898*da0073e9SAndroid Build Coastguard Worker            else:
899*da0073e9SAndroid Build Coastguard Worker                for orig, decomp in zip(real_out, decomp_out):
900*da0073e9SAndroid Build Coastguard Worker                    if not isinstance(orig, torch.Tensor):
901*da0073e9SAndroid Build Coastguard Worker                        assert type(orig) == type(decomp)
902*da0073e9SAndroid Build Coastguard Worker                        assert orig == decomp
903*da0073e9SAndroid Build Coastguard Worker                        continue
904*da0073e9SAndroid Build Coastguard Worker                    op_assert_equal(
905*da0073e9SAndroid Build Coastguard Worker                        self.test_case,
906*da0073e9SAndroid Build Coastguard Worker                        func,
907*da0073e9SAndroid Build Coastguard Worker                        self.test_dtype,
908*da0073e9SAndroid Build Coastguard Worker                        orig,
909*da0073e9SAndroid Build Coastguard Worker                        decomp,
910*da0073e9SAndroid Build Coastguard Worker                        args,
911*da0073e9SAndroid Build Coastguard Worker                        kwargs,
912*da0073e9SAndroid Build Coastguard Worker                    )
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker            return real_out_unflat
915*da0073e9SAndroid Build Coastguard Worker
916*da0073e9SAndroid Build Coastguard Worker    def check_decomposed(self, aten_name, mode):
917*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
918*da0073e9SAndroid Build Coastguard Worker            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
919*da0073e9SAndroid Build Coastguard Worker            msg=(
920*da0073e9SAndroid Build Coastguard Worker                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is "
922*da0073e9SAndroid Build Coastguard Worker                f"CompositeImplicitAutograd you should skip this test "
923*da0073e9SAndroid Build Coastguard Worker                f"by updating CROSS_REF_EXCLUDE_SET."
924*da0073e9SAndroid Build Coastguard Worker            ),
925*da0073e9SAndroid Build Coastguard Worker        )
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Test does not work with TorchDynamo")
928*da0073e9SAndroid Build Coastguard Worker    def do_cross_ref(self, device, dtype, op, *, run_all):
929*da0073e9SAndroid Build Coastguard Worker        test_keys = [
930*da0073e9SAndroid Build Coastguard Worker            (torch.device(device).type, dtype, op.name),
931*da0073e9SAndroid Build Coastguard Worker            (None, dtype, op.name),
932*da0073e9SAndroid Build Coastguard Worker            (None, None, op.name),
933*da0073e9SAndroid Build Coastguard Worker        ]
934*da0073e9SAndroid Build Coastguard Worker        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
935*da0073e9SAndroid Build Coastguard Worker            self.skipTest(f"{op.name} in {dtype} not supported")
936*da0073e9SAndroid Build Coastguard Worker
937*da0073e9SAndroid Build Coastguard Worker        skip_decomp_vjp = any(
938*da0073e9SAndroid Build Coastguard Worker            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
939*da0073e9SAndroid Build Coastguard Worker        )
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        requires_grad = (
942*da0073e9SAndroid Build Coastguard Worker            op.supports_autograd
943*da0073e9SAndroid Build Coastguard Worker            and dtype in op.supported_backward_dtypes(torch.device(device).type)
944*da0073e9SAndroid Build Coastguard Worker            # TODO: OpInfo really ought to error out for this case, but it's
945*da0073e9SAndroid Build Coastguard Worker            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per se (which is supported by data-movement-only ops)
            # but that when we run backward we expect other ops like add to work
948*da0073e9SAndroid Build Coastguard Worker            and not dtype == torch.complex32
949*da0073e9SAndroid Build Coastguard Worker        )
950*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker        aten_name = op.decomp_aten_name or op.aten_name
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker        func = op.get_op()
955*da0073e9SAndroid Build Coastguard Worker
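        # Returns True if the op under test, or any op we decomposed, has a
        # CompositeImplicitAutograd kernel; such samples are re-run below without the
        # python dispatcher so the aten decomps are also checked at the torch_dispatch level.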
956*da0073e9SAndroid Build Coastguard Worker        def run_without_python_dispatcher(mode):
957*da0073e9SAndroid Build Coastguard Worker            return any(
958*da0073e9SAndroid Build Coastguard Worker                isinstance(op, torch._ops.OpOverload)
959*da0073e9SAndroid Build Coastguard Worker                and op.has_kernel_for_dispatch_key(
960*da0073e9SAndroid Build Coastguard Worker                    DispatchKey.CompositeImplicitAutograd
961*da0073e9SAndroid Build Coastguard Worker                )
962*da0073e9SAndroid Build Coastguard Worker                for op in mode.decomposed.union([func])
963*da0073e9SAndroid Build Coastguard Worker            )
964*da0073e9SAndroid Build Coastguard Worker
965*da0073e9SAndroid Build Coastguard Worker        for sample_input in samples:
966*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
967*da0073e9SAndroid Build Coastguard Worker                fn, primals = normalize_op_input_output(func, sample_input)
968*da0073e9SAndroid Build Coastguard Worker                primals = tree_map(
969*da0073e9SAndroid Build Coastguard Worker                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
970*da0073e9SAndroid Build Coastguard Worker                )
971*da0073e9SAndroid Build Coastguard Worker
                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, I can
973*da0073e9SAndroid Build Coastguard Worker                # store the called list on the mode object instance and no
974*da0073e9SAndroid Build Coastguard Worker                # explicit clearing is necessary as I will create a fresh mode
975*da0073e9SAndroid Build Coastguard Worker                # for each region
976*da0073e9SAndroid Build Coastguard Worker                with self.DecompCrossRefMode(
977*da0073e9SAndroid Build Coastguard Worker                    self, self.precision, self.rel_tol, dtype, run_all
978*da0073e9SAndroid Build Coastguard Worker                ) as mode, enable_python_dispatcher():
979*da0073e9SAndroid Build Coastguard Worker                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
980*da0073e9SAndroid Build Coastguard Worker                if run_without_python_dispatcher(mode):
                    # without this re-run, an incorrect decomp registered at the python
                    # dispatcher level could still pass, since the cross-ref mode only
                    # checks aten decomps at the torch_dispatch level.
983*da0073e9SAndroid Build Coastguard Worker                    with self.DecompCrossRefMode(
984*da0073e9SAndroid Build Coastguard Worker                        self, self.precision, self.rel_tol, dtype, run_all
985*da0073e9SAndroid Build Coastguard Worker                    ) as mode:
986*da0073e9SAndroid Build Coastguard Worker                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
987*da0073e9SAndroid Build Coastguard Worker                if aten_name in decomposition_names:
988*da0073e9SAndroid Build Coastguard Worker                    self.check_decomposed(aten_name, mode)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker                if not skip_decomp_vjp and (
991*da0073e9SAndroid Build Coastguard Worker                    op.aten_backward_name in decomposition_names or run_all
992*da0073e9SAndroid Build Coastguard Worker                ):
993*da0073e9SAndroid Build Coastguard Worker                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker                    with self.DecompCrossRefMode(
996*da0073e9SAndroid Build Coastguard Worker                        self, self.precision, self.rel_tol, dtype, run_all
997*da0073e9SAndroid Build Coastguard Worker                    ) as mode, enable_python_dispatcher():
998*da0073e9SAndroid Build Coastguard Worker                        decomp_vjp_fn(cotangents)
999*da0073e9SAndroid Build Coastguard Worker                    if run_without_python_dispatcher(mode):
1000*da0073e9SAndroid Build Coastguard Worker                        # without this check, incorrect decomps at the python dispatcher level can still pass because
                        # without this re-run, an incorrect decomp registered at the python
                        # dispatcher level could still pass, since the cross-ref mode only
                        # checks aten decomps at the torch_dispatch level.
1003*da0073e9SAndroid Build Coastguard Worker                            self, self.precision, self.rel_tol, dtype, run_all
1004*da0073e9SAndroid Build Coastguard Worker                        ) as mode:
1005*da0073e9SAndroid Build Coastguard Worker                            decomp_vjp_fn(cotangents)
1006*da0073e9SAndroid Build Coastguard Worker                    if not run_all:
1007*da0073e9SAndroid Build Coastguard Worker                        self.check_decomposed(op.aten_backward_name, mode)
1008*da0073e9SAndroid Build Coastguard Worker
1009*da0073e9SAndroid Build Coastguard Worker            elif aten_name in decomposition_names or run_all:
1010*da0073e9SAndroid Build Coastguard Worker                args = [sample_input.input] + list(sample_input.args)
1011*da0073e9SAndroid Build Coastguard Worker                kwargs = sample_input.kwargs
1012*da0073e9SAndroid Build Coastguard Worker                # A failure here might be because the decomposition for the op is wrong or because a
1013*da0073e9SAndroid Build Coastguard Worker                # decomposition used by the particular op is wrong.
1014*da0073e9SAndroid Build Coastguard Worker                with self.DecompCrossRefMode(
1015*da0073e9SAndroid Build Coastguard Worker                    self, self.precision, self.rel_tol, dtype, run_all
1016*da0073e9SAndroid Build Coastguard Worker                ) as mode, enable_python_dispatcher():
1017*da0073e9SAndroid Build Coastguard Worker                    func(*args, **kwargs)
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker                if run_without_python_dispatcher(mode):
                    # without this re-run, an incorrect decomp registered at the python
                    # dispatcher level could still pass, since the cross-ref mode only
                    # checks aten decomps at the torch_dispatch level.
1022*da0073e9SAndroid Build Coastguard Worker                    with self.DecompCrossRefMode(
1023*da0073e9SAndroid Build Coastguard Worker                        self, self.precision, self.rel_tol, dtype, run_all
1024*da0073e9SAndroid Build Coastguard Worker                    ) as mode:
1025*da0073e9SAndroid Build Coastguard Worker                        func(*args, **kwargs)
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker                if not run_all:
1028*da0073e9SAndroid Build Coastguard Worker                    self.check_decomposed(aten_name, mode)
1029*da0073e9SAndroid Build Coastguard Worker            else:
1030*da0073e9SAndroid Build Coastguard Worker                assert op.supports_autograd
1031*da0073e9SAndroid Build Coastguard Worker                self.skipTest(
1032*da0073e9SAndroid Build Coastguard Worker                    "only backwards is decomposed, but dtype doesn't support AD"
1033*da0073e9SAndroid Build Coastguard Worker                )
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker
1036*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDecomp, globals())
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Workerclass DecompOneOffTests(TestCase):
1040*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1041*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1042*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1043*da0073e9SAndroid Build Coastguard Worker    def test_contiguous_softmax(self, device):
1044*da0073e9SAndroid Build Coastguard Worker        size = (2, 4, 3, 3)
1045*da0073e9SAndroid Build Coastguard Worker        stride = (9, 18, 3, 1)
1046*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size, dtype=dtype, device=device)
1049*da0073e9SAndroid Build Coastguard Worker        x = torch.as_strided(x, size, stride)
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten._softmax(x, -1, False)
1052*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions._softmax(x, -1, False)
1053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref.stride(), res.stride())
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1056*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1057*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1058*da0073e9SAndroid Build Coastguard Worker    def test_contiguous_log_softmax(self, device):
1059*da0073e9SAndroid Build Coastguard Worker        size = (2, 4, 3, 3)
1060*da0073e9SAndroid Build Coastguard Worker        stride = (9, 18, 3, 1)
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1063*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size, dtype=dtype, device=device)
1064*da0073e9SAndroid Build Coastguard Worker        x = torch.as_strided(x, size, stride)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten._log_softmax(x, -1, False)
1067*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions._log_softmax(x, -1, False)
1068*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref.stride(), res.stride())
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1071*da0073e9SAndroid Build Coastguard Worker    def test_exponential_non_inf(self, device):
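        # A naive inverse-CDF sampler like -log(1 - u) can hit inf when u == 1; this
        # checks the torch._refs.exponential decomposition never produces inf and still
        # matches eager exponential_ under the same RNG state.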
1072*da0073e9SAndroid Build Coastguard Worker        inp = torch.empty((4, 400, 256), device=device)
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.utils.preserve_rng_state():
1075*da0073e9SAndroid Build Coastguard Worker            exp_ref = inp.exponential_()
1076*da0073e9SAndroid Build Coastguard Worker        exp = torch._refs.exponential(inp)
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(exp, exp_ref)
1079*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(exp.isinf().any())
1080*da0073e9SAndroid Build Coastguard Worker
1081*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1082*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1083*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1084*da0073e9SAndroid Build Coastguard Worker    def test_amp_batch_norm_backward(self):
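        # Mixed-precision (AMP-style) case: fp16 grads/activations with fp32 stats; the
        # decomposition should reproduce the eager kernel's output strides and dtypes.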
1085*da0073e9SAndroid Build Coastguard Worker        device = "cuda"
1086*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
1087*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
1088*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn((2,), dtype=torch.float32, device=device)
1089*da0073e9SAndroid Build Coastguard Worker        rmean = torch.randn((2,), dtype=torch.float32, device=device)
1090*da0073e9SAndroid Build Coastguard Worker        rvar = torch.randn((2,), dtype=torch.float32, device=device)
1091*da0073e9SAndroid Build Coastguard Worker        mean = torch.randn((0,), dtype=torch.float32, device=device)
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten.native_batch_norm_backward(
1094*da0073e9SAndroid Build Coastguard Worker            grad_out,
1095*da0073e9SAndroid Build Coastguard Worker            x,
1096*da0073e9SAndroid Build Coastguard Worker            weight,
1097*da0073e9SAndroid Build Coastguard Worker            rmean,
1098*da0073e9SAndroid Build Coastguard Worker            rvar,
1099*da0073e9SAndroid Build Coastguard Worker            mean,
1100*da0073e9SAndroid Build Coastguard Worker            mean,
1101*da0073e9SAndroid Build Coastguard Worker            False,
1102*da0073e9SAndroid Build Coastguard Worker            1e-05,
1103*da0073e9SAndroid Build Coastguard Worker            [True, True, True],
1104*da0073e9SAndroid Build Coastguard Worker        )
1105*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions.native_batch_norm_backward(
1106*da0073e9SAndroid Build Coastguard Worker            grad_out,
1107*da0073e9SAndroid Build Coastguard Worker            x,
1108*da0073e9SAndroid Build Coastguard Worker            weight,
1109*da0073e9SAndroid Build Coastguard Worker            rmean,
1110*da0073e9SAndroid Build Coastguard Worker            rvar,
1111*da0073e9SAndroid Build Coastguard Worker            mean,
1112*da0073e9SAndroid Build Coastguard Worker            mean,
1113*da0073e9SAndroid Build Coastguard Worker            False,
1114*da0073e9SAndroid Build Coastguard Worker            1e-05,
1115*da0073e9SAndroid Build Coastguard Worker            [True, True, True],
1116*da0073e9SAndroid Build Coastguard Worker        )
1117*da0073e9SAndroid Build Coastguard Worker        for a, b in zip(ref, res):
1118*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.stride(), b.stride())
1119*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.dtype, b.dtype)
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1122*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1123*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1124*da0073e9SAndroid Build Coastguard Worker    def test_elu_backward(self, device):
1125*da0073e9SAndroid Build Coastguard Worker        size = (2, 4, 3, 3)
1126*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
1127*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.randn(size, dtype=dtype, device=device)
1128*da0073e9SAndroid Build Coastguard Worker        out = torch.randn(size, dtype=dtype, device=device)
1129*da0073e9SAndroid Build Coastguard Worker
1130*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
1131*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
1132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1135*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1136*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1137*da0073e9SAndroid Build Coastguard Worker    def test_threshold_backward_dtype(self, device):
1138*da0073e9SAndroid Build Coastguard Worker        grad = torch.randint(10, (4,), device=device)
1139*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.randint(10, (4,), device=device)
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
1142*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
1143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref.dtype, res.dtype)
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1146*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1147*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1148*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm_interface(self, device):
1149*da0073e9SAndroid Build Coastguard Worker        g = torch.randn((3, 10, 10), device=device)
1150*da0073e9SAndroid Build Coastguard Worker        v = torch.randn((1, 1, 10), device=device)
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
1153*da0073e9SAndroid Build Coastguard Worker        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
1154*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref[0], res[0]))
1155*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ref[1], res[1]))
1156*da0073e9SAndroid Build Coastguard Worker
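        # Also exercise 2-D inputs with the default dim argument.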
1157*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand([30, 10], device=device)
1158*da0073e9SAndroid Build Coastguard Worker        inp2 = torch.rand([30, 1], device=device)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1161*da0073e9SAndroid Build Coastguard Worker            torch.ops.aten._weight_norm_interface(inp, inp2),
1162*da0073e9SAndroid Build Coastguard Worker            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
1163*da0073e9SAndroid Build Coastguard Worker        )
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1166*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1167*da0073e9SAndroid Build Coastguard Worker    @skipIfCrossRef
1168*da0073e9SAndroid Build Coastguard Worker    @skipOps(
1169*da0073e9SAndroid Build Coastguard Worker        "DecompOneOffTests",
1170*da0073e9SAndroid Build Coastguard Worker        "test_sdpa",
1171*da0073e9SAndroid Build Coastguard Worker        [
1172*da0073e9SAndroid Build Coastguard Worker            xfail(
1173*da0073e9SAndroid Build Coastguard Worker                "nn.functional.scaled_dot_product_attention",
1174*da0073e9SAndroid Build Coastguard Worker                dtypes=[torch.half],
1175*da0073e9SAndroid Build Coastguard Worker            ),
1176*da0073e9SAndroid Build Coastguard Worker        ],
1177*da0073e9SAndroid Build Coastguard Worker    )
1178*da0073e9SAndroid Build Coastguard Worker    @ops(_sdpa_op_info)
1179*da0073e9SAndroid Build Coastguard Worker    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16; this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
        # add support for float16 over there, we should update this test as well.
1182*da0073e9SAndroid Build Coastguard Worker
1183*da0073e9SAndroid Build Coastguard Worker        class ScaledDotProductAttention(torch.nn.Module):
1184*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1185*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker            def forward(
1188*da0073e9SAndroid Build Coastguard Worker                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
1189*da0073e9SAndroid Build Coastguard Worker            ):
1190*da0073e9SAndroid Build Coastguard Worker                attn_output = op(
1191*da0073e9SAndroid Build Coastguard Worker                    query_layer,
1192*da0073e9SAndroid Build Coastguard Worker                    key_layer,
1193*da0073e9SAndroid Build Coastguard Worker                    value_layer,
1194*da0073e9SAndroid Build Coastguard Worker                    attn_mask=mask,
1195*da0073e9SAndroid Build Coastguard Worker                    dropout_p=0.0,
1196*da0073e9SAndroid Build Coastguard Worker                    is_causal=is_causal,
1197*da0073e9SAndroid Build Coastguard Worker                )
1198*da0073e9SAndroid Build Coastguard Worker                return attn_output
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1201*da0073e9SAndroid Build Coastguard Worker        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1202*da0073e9SAndroid Build Coastguard Worker        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1203*da0073e9SAndroid Build Coastguard Worker        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker        atol, rtol = dtype_precisions[dtype]
1206*da0073e9SAndroid Build Coastguard Worker
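        # is_causal and an explicit attn_mask are mutually exclusive here, so causal
        # masking is only requested when no mask is provided.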
1207*da0073e9SAndroid Build Coastguard Worker        for mask in masks:
1208*da0073e9SAndroid Build Coastguard Worker            is_causal = mask is None
1209*da0073e9SAndroid Build Coastguard Worker            attention = ScaledDotProductAttention()
1210*da0073e9SAndroid Build Coastguard Worker            decomposed_res = (
1211*da0073e9SAndroid Build Coastguard Worker                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
1212*da0073e9SAndroid Build Coastguard Worker                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
1213*da0073e9SAndroid Build Coastguard Worker                )
1214*da0073e9SAndroid Build Coastguard Worker            )
1215*da0073e9SAndroid Build Coastguard Worker            eager_res = op(
1216*da0073e9SAndroid Build Coastguard Worker                query_layer,
1217*da0073e9SAndroid Build Coastguard Worker                key_layer,
1218*da0073e9SAndroid Build Coastguard Worker                value_layer,
1219*da0073e9SAndroid Build Coastguard Worker                attn_mask=mask,
1220*da0073e9SAndroid Build Coastguard Worker                dropout_p=0.0,
1221*da0073e9SAndroid Build Coastguard Worker                is_causal=is_causal,
1222*da0073e9SAndroid Build Coastguard Worker            )
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
1225*da0073e9SAndroid Build Coastguard Worker                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
1226*da0073e9SAndroid Build Coastguard Worker            )
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(DecompOneOffTests, globals())
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Workerclass HasDecompTest(TestCase):
1233*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
1234*da0073e9SAndroid Build Coastguard Worker        super().setUp()
1235*da0073e9SAndroid Build Coastguard Worker        self.maxDiff = None
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker    @staticmethod
1238*da0073e9SAndroid Build Coastguard Worker    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
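        # Ops that take and return no Tensors never show up in an ATen graph trace,
        # so they don't need decompositions.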
1239*da0073e9SAndroid Build Coastguard Worker        has_tensor_arg = any(
1240*da0073e9SAndroid Build Coastguard Worker            "Tensor" in str(a.type)
1241*da0073e9SAndroid Build Coastguard Worker            for a in itertools.chain(op._schema.arguments, op._schema.returns)
1242*da0073e9SAndroid Build Coastguard Worker        )
1243*da0073e9SAndroid Build Coastguard Worker        if not has_tensor_arg:
1244*da0073e9SAndroid Build Coastguard Worker            return False
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker        try:
1247*da0073e9SAndroid Build Coastguard Worker            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
1248*da0073e9SAndroid Build Coastguard Worker            return not op.has_kernel_for_dispatch_key(
1249*da0073e9SAndroid Build Coastguard Worker                DispatchKey.CompositeImplicitAutograd
1250*da0073e9SAndroid Build Coastguard Worker            )
1251*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as e:
            # has_kernel_for_dispatch_key fails for some jit-registered ops, which
            # shouldn't be relevant here anyway
1254*da0073e9SAndroid Build Coastguard Worker            if "does not exist" in str(e):
1255*da0073e9SAndroid Build Coastguard Worker                return False
1256*da0073e9SAndroid Build Coastguard Worker            raise
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker    def test_has_decomposition(self):
1259*da0073e9SAndroid Build Coastguard Worker        def all_aten_overloads():
1260*da0073e9SAndroid Build Coastguard Worker            for name in torch._C._dispatch_get_all_op_names():
1261*da0073e9SAndroid Build Coastguard Worker                if not name.startswith("aten::"):
1262*da0073e9SAndroid Build Coastguard Worker                    continue
1263*da0073e9SAndroid Build Coastguard Worker
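                # e.g. "aten::add.Tensor" -> ("add", "Tensor"), "aten::relu" -> ("relu", "default")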
1264*da0073e9SAndroid Build Coastguard Worker                name = name[6:]
1265*da0073e9SAndroid Build Coastguard Worker                if "." in name:
1266*da0073e9SAndroid Build Coastguard Worker                    packet_name, overload_name = name.split(".")
1267*da0073e9SAndroid Build Coastguard Worker                else:
1268*da0073e9SAndroid Build Coastguard Worker                    packet_name, overload_name = name, "default"
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker                packet = getattr(aten, packet_name)
1271*da0073e9SAndroid Build Coastguard Worker                assert isinstance(packet, torch._ops.OpOverloadPacket)
1272*da0073e9SAndroid Build Coastguard Worker                op = getattr(packet, overload_name)
1273*da0073e9SAndroid Build Coastguard Worker                yield op
1274*da0073e9SAndroid Build Coastguard Worker
        # This is for operators that are only registered in some CI
        # configurations, and so would otherwise cause the test to fail there
1277*da0073e9SAndroid Build Coastguard Worker        allow_list = {aten.get_gradients.default}
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker        overloads_wanting_decomp = {
1280*da0073e9SAndroid Build Coastguard Worker            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
1281*da0073e9SAndroid Build Coastguard Worker        }
1282*da0073e9SAndroid Build Coastguard Worker        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
1283*da0073e9SAndroid Build Coastguard Worker        ops_missing_decomp -= allow_list
1284*da0073e9SAndroid Build Coastguard Worker        self.assertExpected(
1285*da0073e9SAndroid Build Coastguard Worker            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
1286*da0073e9SAndroid Build Coastguard Worker        )
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker    def test_aten_core_operators(self):
1289*da0073e9SAndroid Build Coastguard Worker        # If a decomposition isn't included in the core decompositions,
1290*da0073e9SAndroid Build Coastguard Worker        # then it must decompose a core ATen operator.
1291*da0073e9SAndroid Build Coastguard Worker        #
1292*da0073e9SAndroid Build Coastguard Worker        # See NOTE [Core ATen Ops]
1293*da0073e9SAndroid Build Coastguard Worker        #
1294*da0073e9SAndroid Build Coastguard Worker        # If this test fails then either:
1295*da0073e9SAndroid Build Coastguard Worker        # - Add the decomposition to torch._decomp.core_aten_decompositions,
        #   if the decomposition should be used by inductor (not a core operator).
1297*da0073e9SAndroid Build Coastguard Worker        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
1298*da0073e9SAndroid Build Coastguard Worker        #   core ATen operators (and inductor will not use the decomposition).
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        # Some decompositions are registered for CompositeImplicitAutograd
1301*da0073e9SAndroid Build Coastguard Worker        # operators, which never appear in AOTAutograd's graph so are never used.
1302*da0073e9SAndroid Build Coastguard Worker        useful_decomps = {
1303*da0073e9SAndroid Build Coastguard Worker            op
1304*da0073e9SAndroid Build Coastguard Worker            for op in decomposition_table.keys()
1305*da0073e9SAndroid Build Coastguard Worker            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
1306*da0073e9SAndroid Build Coastguard Worker        }
1307*da0073e9SAndroid Build Coastguard Worker        core_decomps = torch._decomp.core_aten_decompositions().keys()
1308*da0073e9SAndroid Build Coastguard Worker        core_aten_ops = useful_decomps - core_decomps
1309*da0073e9SAndroid Build Coastguard Worker        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
1310*da0073e9SAndroid Build Coastguard Worker
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
1313*da0073e9SAndroid Build Coastguard Worker    run_tests()