xref: /aosp_15_r20/external/pytorch/test/test_decomp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: decompositions"]
2
3import functools
4import itertools
5import re
6import unittest
7from collections import defaultdict
8from functools import partial
9
10import torch._inductor.decomposition
11import torch.autograd
12from torch import Tensor
13from torch._decomp import core_aten_decompositions, decomposition_table
14from torch._dispatch.python import enable_python_dispatcher
15from torch._ops import DispatchKey
16from torch.testing import make_tensor
17from torch.testing._internal.common_cuda import tf32_off
18from torch.testing._internal.common_device_type import (
19    instantiate_device_type_tests,
20    onlyCPU,
21    onlyCUDA,
22    onlyNativeDeviceTypes,
23    ops,
24)
25from torch.testing._internal.common_methods_invocations import (
26    op_db,
27    skip,
28    skipOps,
29    xfail,
30)
31from torch.testing._internal.common_modules import module_db, modules
32from torch.testing._internal.common_utils import (
33    is_iterable_of_tensors,
34    run_tests,
35    skipIfCrossRef,
36    skipIfTorchDynamo,
37    suppress_warnings,
38    TEST_WITH_ASAN,
39    TEST_WITH_SLOW,
40    TestCase,
41    unMarkDynamoStrictTest,
42)
43from torch.utils import _pytree as pytree
44from torch.utils._python_dispatch import TorchDispatchMode
45from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
46
47
48aten = torch.ops.aten
49
50
51# TODO: this isn't going to work with non-aten namespaces
52def overload_to_aten_name(op):
53    return op._schema.name.split("::")[1]
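# Illustrative: overload_to_aten_name(torch.ops.aten.add.Tensor) returns "add",
# since that overload's schema name is "aten::add".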
54
55
56# All operators that can have decomp tests
57decomposition_names = {
58    overload_to_aten_name(k)
59    for k in decomposition_table
60    if isinstance(k, torch._ops.OpOverload)
61}
62core_decomposition_names = {
63    overload_to_aten_name(k)
64    for k in core_aten_decompositions()
65    if isinstance(k, torch._ops.OpOverload)
66}
67_decomp_test_ops = [
68    op
69    for op in op_db
70    if op.aten_name in decomposition_names
71    or op.aten_backward_name in decomposition_names
72]
73_decomp_test_ops_core_autograd = [
74    op
75    for op in op_db
76    if op.aten_name in core_decomposition_names and op.supports_autograd
77]
78_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]
79
80
81def diff_arg(arg, requires_grad=True):
82    def is_differentiable_arg(arg):
83        if requires_grad:
84            return arg.requires_grad
85        else:
86            return arg.is_floating_point() or arg.is_complex()
87
88    if is_iterable_of_tensors(arg):
89        if all(is_differentiable_arg(a) for a in arg):
90            return True
91        if all(not is_differentiable_arg(a) for a in arg):
92            return False
93        raise RuntimeError("NYI: The test runner can't handle this")
94    return isinstance(arg, Tensor) and is_differentiable_arg(arg)
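# Illustrative: diff_arg(torch.randn(3, requires_grad=True)) is True, while
# diff_arg(torch.ones(3)) is False because that tensor does not require grad.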
95
96
97# Version of autograd.grad with some differences:
98#   - pytree inputs are allowed (but the leaves of the pytree all have to
99#     be tensors)
100#   - if an input is not used as part of derivatives, we will return a
101#     zero-filled tensor for the result
102def _autograd_grad(
103    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
104):
105    inputs, inputs_spec = tree_flatten(inputs)
106    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
107    if grad_outputs is None:
108        diff_outputs = tuple(out for out in outputs if out.requires_grad)
109    else:
110        diff_grad_outputs = [
111            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
112        ]
113        if len(diff_grad_outputs) == 0:
114            diff_outputs, grad_outputs = (), ()
115        else:
116            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
117    grad_inputs = torch.autograd.grad(
118        diff_outputs,
119        diff_inputs,
120        grad_outputs,
121        retain_graph=retain_graph,
122        create_graph=create_graph,
123        allow_unused=True,
124    )
125    result = []
126    grad_inputs_iter = iter(grad_inputs)
127    for inp in inputs:
128        if inp.requires_grad:
129            grad_input = next(grad_inputs_iter)
130            if grad_input is None:
131                result.append(torch.zeros_like(inp))
132            else:
133                result.append(grad_input)
134        else:
135            result.append(torch.zeros_like(inp))
136    return tree_unflatten(result, inputs_spec)
137
138
139def _as_tuple(val):
140    if isinstance(val, tuple):
141        return val
142    return (val,)
143
144
145def ref_vjp_no_create(f, *primals):
146    result = f(*primals)
147
148    def wrapped(cotangents):
149        return _autograd_grad(
150            _as_tuple(result),
151            primals,
152            _as_tuple(cotangents),
153            create_graph=False,
154            retain_graph=True,
155        )
156
157    return result, wrapped
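# Illustrative usage (x is a hypothetical tensor with requires_grad=True):
#   out, vjp_fn = ref_vjp_no_create(torch.sin, x)
#   (grad_x,) = vjp_fn(torch.ones_like(out))  # equals torch.cos(x)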
158
159
160dtype_precisions = {
161    torch.float16: (0.001, 1e-5),
162    torch.bfloat16: (0.016, 1e-4),
163    torch.float32: (1.3e-6, 1e-5),
164    torch.float64: (1e-7, 1e-7),
165    torch.complex32: (0.001, 1e-5),
166    torch.complex64: (1.3e-6, 1e-5),
167    torch.complex128: (1e-7, 1e-7),
168}
169# Returns the "default" rtol and atol for comparing scalars or
170# tensors of the given dtypes.
171
172
173def _getDefaultRtolAndAtol(dtype0, dtype1):
174    rtol = max(
175        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
176    )
177    atol = max(
178        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
179    )
180    return rtol, atol
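# Illustrative: _getDefaultRtolAndAtol(torch.float16, torch.float32) gives
# (1e-3, 1e-5), i.e. the loosest rtol and atol across the two dtypes.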
181
182
183def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
184    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
185    if orig.numel() == 0 or decomp.numel() == 0:
186        assert orig.numel() == decomp.numel()
187        return
188    assert orig.shape == decomp.shape, f"{i} Operation:  {op}"
189    tol_table = {
190        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
191        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
192        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
193        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
194        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
195        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
196        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
197        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
198        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
199        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
200        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
201        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
202        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
203        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
204        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
205        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
206        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
207        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
208        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
209        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
210        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
211        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
212        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
213        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
214        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
215        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
216        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
217        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
218        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
219        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
220        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
221        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
222        # see https://github.com/pytorch/pytorch/pull/96264
223        (torch.float16, torch.ops.aten.mv.default): 1e-5,
224        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
225        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
226        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
227    }
228    if ref.is_floating_point():
229        orig_diff = (orig - ref).abs().max()
230        decomp_diff = (decomp - ref).abs().max()
231        atol = tol_table.get((test_dtype, op), 1e-7)
232        if decomp_diff > orig_diff + atol:
233            raise RuntimeError(
234                f"Difference from float64 is larger with decomposition {op.__name__}"
235                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
236                f"atol = {atol}\n"
237                f"args = {args}\n"
238                f"kwargs = {kwargs}"
239            )
240    else:
241        test_case.assertEqual(
242            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
243        )
244
245
246def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
247    test_case.assertEqual(
248        orig.dtype,
249        decomp.dtype,
250        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
251    )
252    # Before adding an entry to this table, make sure your decomposition is right :)
253    tol_table = {
254        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
255        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
256        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
257            1e-3,
258            1e-3,
259        ),
260        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
261        # This exceeds default tolerances only on CPU; on CUDA it's fine
262        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
263        # Exceeds tolerances on CUDA, likely due to fma
264        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
265        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
266        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
267        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
268        # The decomposition is TOO correct. It computes everything in int64, so sometimes
269        # there's an off-by-one error. See
270        # https://github.com/pytorch/pytorch/issues/81996
271        # https://github.com/pytorch/pytorch/issues/82230
272        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
273        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
274        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
275        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
276        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
277        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
278        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
279        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
280        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
281        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
282        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
283        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
284        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
285        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
286        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
287        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
288        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
289        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
290        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
291        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
292    }
293    if (decomp.dtype, op) in tol_table:
294        rtol, atol = tol_table[(decomp.dtype, op)]
295    else:
296        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
297    test_case.assertEqual(
298        orig,
299        decomp,
300        rtol=rtol,
301        atol=atol,
302        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
303    )
304
305
306# Given f, returns an f' such that:
307# - f' takes only positional arguments
308# - All arguments to f' are floating-point Tensors
309# - All outputs of f' are floating-point Tensors
310def normalize_op_input_output2(
311    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
312):
313    flat_args, args_spec = tree_flatten(args)
314    diff_argnums = tuple(
315        i
316        for i, arg in enumerate(flat_args)
317        if diff_arg(arg, requires_grad=requires_grad)
318    )
319    assert len(diff_argnums) > 0
320    primals = tuple(flat_args[i] for i in diff_argnums)
321
322    @functools.wraps(f)
323    def wrapped(*primals):
324        _args = list(flat_args)
325        for num, arg in zip(diff_argnums, primals):
326            _args[num] = arg
327        _args = tree_unflatten(_args, args_spec)
328        result = f(*_args, **kwargs)
329        if output_process_fn_grad is not None:
330            result = output_process_fn_grad(result)
331        if isinstance(result, tuple):
332            # TODO We should check that the integer outputs also agree
333            result = tuple(
334                r
335                for r in result
336                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
337            )
338            assert len(result) > 0
339        return result
340
341    return wrapped, primals
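# The returned `wrapped` accepts only the differentiable leaves (`primals`),
# splices them back into the original args before calling f, and filters the
# outputs down to floating-point/complex tensors.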
342
343
344# NB: This also upcasts dtype arguments
345# TODO: handle complex correctly
346def upcast_tensor(x, dtype=torch.float32):
347    if isinstance(x, Tensor) and x.dtype.is_floating_point:
348        return x.to(dtype=dtype)
349    elif isinstance(x, torch.dtype) and x in [
350        torch.float16,
351        torch.bfloat16,
352        torch.float,
353    ]:
354        return dtype
355    else:
356        return x
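# Illustrative: upcast_tensor(torch.float16) returns torch.float32 (the default
# target dtype); non-floating-point tensors and other values pass through as-is.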
357
358
359def normalize_op_input_output(f, sample, requires_grad=True):
360    args = tuple([sample.input] + list(sample.args))
361    return normalize_op_input_output2(
362        f,
363        args,
364        sample.kwargs,
365        sample.output_process_fn_grad,
366        requires_grad=requires_grad,
367    )
368
369
370CROSS_REF_EXCLUDE_SET = {
371    # CUBLAS_STATUS_NOT_SUPPORTED when calling
372    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
373    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
374    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
375    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
376    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
377    # randomness
378    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
379    (None, None, "new_empty"),
380    (None, None, "empty_like"),
381    (None, None, "empty"),
382    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
383    (None, None, "item"),
384    # It's the only in-place op without an out-of-place equivalent in the Python API
385    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
386    (None, None, "zero_"),
387    # No idea what's going on here
388    # In the recursive test, logsumexp.default fails with args = (torch.tensor(-math.inf), []),
389    # but it seems to pass when tested locally and in the logsumexp test
390    (None, torch.float32, "masked.logsumexp"),
391    (None, torch.float64, "masked.logsumexp"),
392    # exp_vml_cpu not implemented for Half
393    ("cpu", torch.float16, "signal.windows.exponential"),
394    ("cpu", torch.float16, "signal.windows.gaussian"),
395    # sin_vml_cpu not implemented for Half
396    ("cpu", torch.float16, "signal.windows.cosine"),
397    # CompositeImplicitAutograd
398    # See https://github.com/pytorch/pytorch/issues/81669
399    (None, None, "nn.functional.relu6"),
400    # This decomp runs before autograd.
401    (None, None, "nn.functional.rrelu"),
402    (None, None, "meshgrid"),
403    # Decomposition registered as Autograd
404    (None, None, "nn.functional.hardshrink"),
405    (None, None, "nn.functional.softshrink"),
406    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
407    (None, None, "diag"),
408    # _softmax_backward_data's CPU kernel for bfloat16 always returns the grad_input as float32
409    ("cpu", torch.bfloat16, "_softmax_backward_data"),
410    (None, None, "norm"),
411    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
412    (None, None, "native_batch_norm"),
413    (None, None, "_upsample_bilinear2d_aa"),
414    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
415}
416
417CROSS_REF_BACKWARD_EXCLUDE_SET = {
418    # Decomposed backward formula is not as precise
419    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
420    ("cuda", torch.float16, "nn.functional.cross_entropy"),
421}
422
423all_decomposed = set()
424all_called = defaultdict(int)
425
426# Helpful snippet for testing coverage
427"""
428import atexit
429def check_coverage():
430    print("missing coverage:")
431    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
432atexit.register(check_coverage)
433"""
434
435# Helpful snippet for Horace to create his google sheet :)
436"""
437import atexit
438def dump_ops():
439    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
440        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
441            f.write(f'{op.__name__}\n')
442            g.write(f'{count}\n')
443    with open('run_decompositions.txt', 'w') as f:
444        for op in sorted([i.__name__ for i in all_decomposed]):
445            f.write(f'{op}\n')
446
447atexit.register(dump_ops)
448"""
449
450
451def any_unsupported(args, kwargs):
452    def test_unsupported(t):
453        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
454            # These are all things that we haven't coded decompositions
455            # to handle correctly.  Maybe they should.
456            return any(
457                [
458                    t.is_sparse_csr,
459                    t.is_sparse,
460                    t.is_mkldnn,
461                    t.is_quantized,
462                    t.is_nested,
463                    torch._is_functional_tensor(t),
464                ]
465            )
466        elif torch.overrides.is_tensor_like(t):
467            # Decompositions will generally change the behavior of Tensor-like
468            # subclasses, so bypass tests in this case too
469            return True
470        else:
471            return False
472
473    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
474    return any(test_unsupported(x) for x in flat_args)
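# Illustrative: any_unsupported((torch.randn(2).to_sparse(),), {}) is True,
# since sparse tensors are among the inputs these tests bypass.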
475
476
477core_backward_failures = {
478    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
479    xfail("addcdiv"),
480    skip("addcmul"),  # slow: fails with --timeout=360 secs
481    skip("deg2rad"),  # slow: fails with --timeout=360 secs
482    skip("diag_embed"),  # slow: fails with --timeout=360 secs
483    skip("frac"),  # slow: fails with --timeout=360 secs
484    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
485    xfail("lerp"),
486    skip("logaddexp"),  # slow: fails with --timeout=360 secs
487    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
488    xfail("nn.functional.binary_cross_entropy_with_logits"),
489    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
490    xfail("nn.functional.hardshrink"),
491    xfail("nn.functional.softshrink"),
492    skip("nn.functional.unfold"),  # slow: fails with --timeout=360 secs
493    xfail("norm"),
494    xfail("norm", "fro"),
495    xfail("norm", "inf"),
496    xfail("norm", "nuc"),
497    skip("rad2deg"),  # slow: fails with --timeout=360 secs
498    skip("renorm"),  # slow: fails with --timeout=360 secs
499    skip("rot90"),  # slow: fails with --timeout=360 secs
500    skip("rsub"),  # slow: fails with --timeout=360 secs
501    skip("sgn"),  # slow: fails with --timeout=360 secs
502    skip("special.xlog1py"),  # slow: fails with --timeout=360 secs
503    xfail("stack"),
504    skip("tril"),  # slow: fails with --timeout=360 secs
505    skip("triu"),  # slow: fails with --timeout=360 secs
506    skip("unfold_copy"),  # slow: fails with --timeout=360 secs
507    skip("xlogy"),  # slow: fails with --timeout=360 secs
508    xfail("zero_"),
509}
510if not TEST_WITH_SLOW:
511    core_backward_failures.update(
512        {
513            skip("addr"),  # slow: takes 46 sec on A100
514            skip("baddbmm"),  # slow: takes 800+ sec on A100
515            skip("clamp_min"),  # slow: takes 800 sec on A100
516            skip("clamp_max"),  # slow: takes 800 sec on A100
517            skip("logit"),  # slow: takes 44 sec on A100
518            skip("nn.functional.hardswish"),  # slow: takes 60 sec on A100
519            skip("std_mean"),  # slow: takes 170 sec on A100
520            skip("split", variant_name="list_args"),  # slow: takes 118 sec on A100
521            skip("transpose"),  # slow: takes 50 sec on A100
522            skip("unbind"),  # slow: takes 70 sec on A100
523            skip("unsafe_split"),  # slow: takes 49 sec on A100
524        }
525    )
526
527comprehensive_failures = {
528    xfail(
529        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
530    ),  # off by one error
531    xfail(
532        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
533    ),  # off by one error
534    xfail(
535        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
536    ),  # off by one error
537}
538
539
540@unMarkDynamoStrictTest
541class TestDecomp(TestCase):
542    longMessage = True
543
544    # NB: This actually overlaps with test_comprehensive, but it only
545    # runs on things that are definitely decomposed so it's a lot faster
546    # to run
547    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
548    @onlyNativeDeviceTypes
549    @skipIfCrossRef
550    @suppress_warnings
551    @ops(_decomp_test_ops)
552    def test_quick(self, device, dtype, op):
553        self.do_cross_ref(device, dtype, op, run_all=False)
554
555    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
556    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
557    @onlyNativeDeviceTypes
558    @skipIfCrossRef
559    @suppress_warnings
560    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
561    def test_quick_core_backward(self, device, dtype, op):
562        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
563            aten_name = op.decomp_aten_name or op.aten_name
564            args = [sample_input.input] + list(sample_input.args)
565            kwargs = sample_input.kwargs
566            func = partial(op.get_op(), **kwargs)
567            with self.DecompCrossRefMode(
568                self, self.precision, self.rel_tol, dtype, run_all=False
569            ) as mode, enable_python_dispatcher():
570                torch.autograd.gradcheck(func, args)
571            self.check_decomposed(aten_name, mode)
572
573    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
574    @onlyNativeDeviceTypes
575    @skipIfCrossRef
576    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
577    @suppress_warnings
578    @ops(op_db)
579    def test_comprehensive(self, device, dtype, op):
580        self.do_cross_ref(device, dtype, op, run_all=True)
581
582    def test_uniform(self, device):
583        size = (2, 3, 4, 5)
584        dtype = torch.float32
585        x = make_tensor(size, dtype=dtype, device=device)
586        low = 0.3
587        high = 0.9
588
589        torch.manual_seed(123)
590        ref = torch.ops.aten.uniform(x, low, high)
591        torch.manual_seed(123)
592        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
593        self.assertEqual(ref, res)
594
595    def test_broadcasting_index_copy(self, device):
596        x = torch.zeros([1, 10], device=device)
597        xs = torch.ones([2, 10], device=device)
598
599        def index_copy(xs, x):
600            torch._decomp.decompositions.index_copy_(
601                xs, 0, torch.tensor(0).to(device), x
602            )
603
604        index_copy(xs, x)
605
606        xs_two = torch.ones([2, 10], device=device)
607        xs_two[0] = x
608
609        self.assertEqual(xs, xs_two)
610
611    def test_cat_single_input(self, device):
612        decomp_table = torch._inductor.decomposition.select_decomp_table()
613        cat_inductor = decomp_table[torch.ops.aten.cat.default]
614
615        inp = torch.rand([2048, 2048], device=device)
616        inps = [inp for _ in range(10)]
617
618        for dim in (-1, 0, 1):
619            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))
620
621    def test_rrelu_with_noise(self, device):
622        # rrelu_with_noise behavior depends on a) whether elements in the input
623        # are <= 0, and b) whether we're in training mode. Cover all cases:
624        dtype = torch.float64
625        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
626        lower = 1.0
627        upper = 4.0
628        training = False
629
630        torch.manual_seed(123)
631        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
632        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
633
634        torch.manual_seed(123)
635        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
636        res = torch._decomp.decompositions.rrelu_with_noise(
637            x,
638            noise_res,
639            lower,
640            upper,
641            training,
642        )
643        self.assertEqual(ref, res)
644        self.assertEqual(noise_ref, noise_res)
645
646        # Now with training=True:
647        training = True
648
649        torch.manual_seed(123)
650        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
651        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)
652
653        torch.manual_seed(123)
654        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
655        res = torch._decomp.decompositions.rrelu_with_noise(
656            x,
657            noise_res,
658            lower,
659            upper,
660            training,
661        )
662        self.assertEqual(ref, res)
663        self.assertEqual(noise_ref, noise_res)
664
665    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
666    @suppress_warnings
667    @tf32_off()
668    # only tests RNNs since we have py dispatcher decomps for them
669    @modules(
670        filter(
671            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
672            module_db,
673        )
674    )
675    def test_rnn_decomp_module(self, device, dtype, module_info, training):
676        module_cls = module_info.module_cls
677        module_inputs = module_info.module_inputs_func(
678            module_info,
679            device=device,
680            dtype=dtype,
681            requires_grad=True,
682            training=training,
683        )
684        for module_input in module_inputs:
685            if module_input.forward_input is None:
686                continue
687            args, kwargs = (
688                module_input.constructor_input.args,
689                module_input.constructor_input.kwargs,
690            )
691            m = module_cls(*args, **kwargs)
692            m.to(device).to(dtype)
693
694            args, kwargs = (
695                module_input.forward_input.args,
696                module_input.forward_input.kwargs,
697            )
698            with self.DecompCrossRefMode(
699                self, self.precision, self.rel_tol, dtype, run_all=True
700            ), enable_python_dispatcher():
701                decomp_out = m(*args, **kwargs)
702
703            non_decomp_out = m(*args, **kwargs)
704            # without this check, incorrect decomps at the python dispatcher level can still pass because
705            # they're checking aten decomps at the torch_dispatch level
706            self.assertEqual(decomp_out, non_decomp_out)
707
708    def test_batch_norm_unflatten_weight_bias(self, device):
709        # https://github.com/pytorch/pytorch/issues/100970
710        shape = (1, 3, 2, 2)
711        input = torch.randn(shape, device=device)
712        weight = torch.randn((3, 1, 1, 1), device=device)
713        bias = torch.randn(3, device=device)
714        mean = torch.randn(3, device=device)
715        var = torch.randn(3, device=device)
716        res = torch._decomp.decompositions.native_batch_norm(
717            input, weight, bias, mean, var, False, 1, 1e-05
718        )
719        self.assertEqual(shape, res[0].shape)
720
721    def test_arange_graph(self, device):
722        from torch.fx.experimental.proxy_tensor import make_fx
723
724        def func(x, start):
725            le = x.shape[-1]
726            if start is None:
727                a = torch.arange(le, dtype=torch.float32, device=x.device)
728            else:
729                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
730            return a
731
732        pattern = r", device = device\(.+\), requires_grad = False"
733
734        cfunc = make_fx(func, decomposition_table=decomposition_table)
735        fx_g = cfunc(torch.rand(10, device=device), None)
736        fx_g_code = fx_g.code.strip()
737        # Remove device and requires_grad
738        fx_g_code = re.sub(pattern, "", fx_g_code)
739        self.assertExpectedInline(
740            fx_g_code,
741            """\
742def forward(self, x_1, start_1):
743    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
744    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
745    add = torch.ops.prims.add.default(mul, 0);  mul = None
746    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
747    return convert_element_type""",
748        )
749
750        fx_g = cfunc(torch.rand(10, device=device), 1)
751        fx_g_code = fx_g.code.strip()
752        # Remove device and requires_grad
753        fx_g_code = re.sub(pattern, "", fx_g_code)
754        self.assertExpectedInline(
755            fx_g_code,
756            """\
757def forward(self, x_1, start_1):
758    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
759    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
760    add = torch.ops.prims.add.default(mul, 1);  mul = None
761    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
762    return convert_element_type""",
763        )
764
765    def test_masked_fill(self, device):
766        from torch.fx.experimental.proxy_tensor import make_fx
767
768        if torch.device(device).type not in [
769            "xpu",
770            "cuda",
771            torch._C._get_privateuse1_backend_name(),
772        ]:
773            self.skipTest("only runs on XPU, CUDA, and PrivateUse1.")
774
775        def func(scores, mask, value):
776            return scores.masked_fill(mask, value)
777
778        scores_t = torch.tensor([1, 2, 3, 4], device=device)
779        mask_t = torch.tensor([True, True, True, True], device=device)
780        value_t = torch.tensor(0, dtype=scores_t.dtype)
781        cfunc = make_fx(func, decomposition_table=decomposition_table)
782        fx_g = cfunc(scores_t, mask_t, value_t)
783        self.assertExpectedInline(
784            fx_g.code.strip(),
785            """\
786def forward(self, scores_1, mask_1, value_1):
787    where = torch.ops.prims.where.default(mask_1, value_1, scores_1);  mask_1 = value_1 = scores_1 = None
788    return where""",
789        )
790
791    class DecompCrossRefMode(TorchDispatchMode):
792        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
793            self.test_case = test_case
794            self.saved_precision = saved_precision
795            self.saved_rel_tol = saved_rel_tol
796            self.test_dtype = dtype
797            self.run_all = run_all
798
799            # We check the correctness of each decomposition right after running it.
800            # So, when we encounter a decomposition, we run the function normally, and
801            # then run the decomposition, and ensure they're identical.
802            self.called = set()
803            self.decomposed = set()
804
805        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
806            self.test_case.precision = self.saved_precision
807            self.test_case.rel_tol = self.saved_rel_tol
808
809            self.called.add(func)
810            all_called[func] += 1
811
812            # Stuff we shouldn't bother testing
813            # (TODO: remove detach from the decomp table?)
814            # N.b. Testing in-place ops would need dedicated logic
815            in_place = func.name()[-1] == "_"
816            ignored_ops = [
817                torch.ops.aten.detach.default,
818                # non-deterministic ops
819                torch.ops.aten.empty.memory_format,
820                torch.ops.aten.empty_like.default,
821                torch.ops.aten.new_empty.default,
822                torch.ops.aten.empty_strided.default,
823                torch.ops.aten.new_empty_strided.default,
824                torch.ops.aten.randn.default,
825                torch.ops.aten.native_dropout.default,
826            ]
827            if (
828                func not in decomposition_table
829                or func in ignored_ops
830                or torch.Tag.nondeterministic_seeded in func.tags
831                or any_unsupported(args, kwargs)
832                or in_place
833            ):
834                return func(*args, **kwargs)
835
836            self.decomposed.add(func)
837            all_decomposed.add(func)
838
839            # We take two main strategies for verifying correctness/numerical stability of decompositions.
840            # The first one is simply tolerance checking between decomp_out and pytorch_out.
841            # However, for fp16/bf16 and reductions, this becomes very
842            # finicky, as there are not many guarantees we can make.
843            # So, for fp16/bf16, we instead compare the difference of
844            # {decomp_out, pytorch_out_64} and {pytorch_out,
845            # pytorch_out_64}. In other words, we compare how far the
846            # decomposition and pytorch are from the "ground truth" (i.e.
847            # fp64). If the decomposition results in more error, we error
848
849            # We also decompose the decomposition recursively for
850            # further coverage, as some paths may not be exercised directly by
851            # OpInfos (sadly) but only by other ops
852
853            decomposition = decomposition_table[func]
854
855            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
856            if self.run_all:
857                # Execute recursively via DFS, to find the root of a possible error first
858                with self:
859                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
860            else:
861                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
862
863            # At this stage we should not be decomposing an in-place op
864            # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops
865            #  because decompositions are run after functionalisation and we would not like them to
866            #  de-functionalise the graph, as that would break AoTAutograd
867            # We run the real function *after* the decomposition to make sure that the
868            # decomposition does not modify any of the inputs in-place. If it does,
869            # real_out should be different from decomp_out, so we should catch this
870            real_out_unflat = func(*args, **kwargs)
871            real_out = pytree.tree_leaves(real_out_unflat)
872
873            assert len(real_out) == len(decomp_out)
874
875            if do_relative_check:
876                upcast = partial(upcast_tensor, dtype=torch.float64)
877                real_out_double, _ = tree_flatten(
878                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
879                )
880                for i, (orig, decomp, ref) in enumerate(
881                    zip(real_out, decomp_out, real_out_double)
882                ):
883                    if not isinstance(orig, torch.Tensor):
884                        assert type(orig) == type(decomp)
885                        assert orig == decomp
886                        continue
887                    op_assert_ref(
888                        self.test_case,
889                        func,
890                        self.test_dtype,
891                        i,
892                        orig,
893                        decomp,
894                        ref,
895                        args,
896                        kwargs,
897                    )
898            else:
899                for orig, decomp in zip(real_out, decomp_out):
900                    if not isinstance(orig, torch.Tensor):
901                        assert type(orig) == type(decomp)
902                        assert orig == decomp
903                        continue
904                    op_assert_equal(
905                        self.test_case,
906                        func,
907                        self.test_dtype,
908                        orig,
909                        decomp,
910                        args,
911                        kwargs,
912                    )
913
914            return real_out_unflat
915
916    def check_decomposed(self, aten_name, mode):
917        self.assertTrue(
918            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
919            msg=(
920                f"aten.{aten_name} was not decomposed, saw calls for: "
921                f"{', '.join(map(str, list(mode.called)))}. If your op is  "
922                f"CompositeImplicitAutograd you should skip this test "
923                f"by updating CROSS_REF_EXCLUDE_SET."
924            ),
925        )
926
927    @skipIfTorchDynamo("Test does not work with TorchDynamo")
928    def do_cross_ref(self, device, dtype, op, *, run_all):
929        test_keys = [
930            (torch.device(device).type, dtype, op.name),
931            (None, dtype, op.name),
932            (None, None, op.name),
933        ]
934        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
935            self.skipTest(f"{op.name} in {dtype} not supported")
936
937        skip_decomp_vjp = any(
938            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
939        )
940
941        requires_grad = (
942            op.supports_autograd
943            and dtype in op.supported_backward_dtypes(torch.device(device).type)
944            # TODO: OpInfo really ought to error out for this case, but it's
945            # not exercised in test_ops_gradients atm.  The problem is not
946            # complex32 per se (which is supported by data-movement-only ops)
947            # but that when we do backwards we expect other ops like add to work
948            and not dtype == torch.complex32
949        )
950        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
951
952        aten_name = op.decomp_aten_name or op.aten_name
953
954        func = op.get_op()
955
956        def run_without_python_dispatcher(mode):
957            return any(
958                isinstance(op, torch._ops.OpOverload)
959                and op.has_kernel_for_dispatch_key(
960                    DispatchKey.CompositeImplicitAutograd
961                )
962                for op in mode.decomposed.union([func])
963            )
964
965        for sample_input in samples:
966            if requires_grad:
967                fn, primals = normalize_op_input_output(func, sample_input)
968                primals = tree_map(
969                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
970                )
971
972                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, I can
973                # store the called list on the mode object instance and no
974                # explicit clearing is necessary as I will create a fresh mode
975                # for each region
976                with self.DecompCrossRefMode(
977                    self, self.precision, self.rel_tol, dtype, run_all
978                ) as mode, enable_python_dispatcher():
979                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
980                if run_without_python_dispatcher(mode):
981                    # without this check, incorrect decomps at the python dispatcher level can still pass because
982                    # they're checking aten decomps at the torch_dispatch level.
983                    with self.DecompCrossRefMode(
984                        self, self.precision, self.rel_tol, dtype, run_all
985                    ) as mode:
986                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
987                if aten_name in decomposition_names:
988                    self.check_decomposed(aten_name, mode)
989
990                if not skip_decomp_vjp and (
991                    op.aten_backward_name in decomposition_names or run_all
992                ):
993                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
994
995                    with self.DecompCrossRefMode(
996                        self, self.precision, self.rel_tol, dtype, run_all
997                    ) as mode, enable_python_dispatcher():
998                        decomp_vjp_fn(cotangents)
999                    if run_without_python_dispatcher(mode):
1000                        # without this check, incorrect decomps at the python dispatcher level can still pass because
1001                        # they're checking aten decomps at the torch_dispatch level.
1002                        with self.DecompCrossRefMode(
1003                            self, self.precision, self.rel_tol, dtype, run_all
1004                        ) as mode:
1005                            decomp_vjp_fn(cotangents)
1006                    if not run_all:
1007                        self.check_decomposed(op.aten_backward_name, mode)
1008
1009            elif aten_name in decomposition_names or run_all:
1010                args = [sample_input.input] + list(sample_input.args)
1011                kwargs = sample_input.kwargs
1012                # A failure here might be because the decomposition for the op is wrong or because a
1013                # decomposition used by the particular op is wrong.
1014                with self.DecompCrossRefMode(
1015                    self, self.precision, self.rel_tol, dtype, run_all
1016                ) as mode, enable_python_dispatcher():
1017                    func(*args, **kwargs)
1018
1019                if run_without_python_dispatcher(mode):
1020                    # without this check, incorrect decomps at the python dispatcher level can still pass because
1021                    # they're checking aten decomps at the torch_dispatch level.
1022                    with self.DecompCrossRefMode(
1023                        self, self.precision, self.rel_tol, dtype, run_all
1024                    ) as mode:
1025                        func(*args, **kwargs)
1026
1027                if not run_all:
1028                    self.check_decomposed(aten_name, mode)
1029            else:
1030                assert op.supports_autograd
1031                self.skipTest(
1032                    "only backwards is decomposed, but dtype doesn't support AD"
1033                )
1034
1035
1036instantiate_device_type_tests(TestDecomp, globals())
1037
1038
1039class DecompOneOffTests(TestCase):
1040    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1041    @onlyNativeDeviceTypes
1042    @skipIfCrossRef
1043    def test_contiguous_softmax(self, device):
1044        size = (2, 4, 3, 3)
1045        stride = (9, 18, 3, 1)
1046        dtype = torch.float32
1047
1048        x = torch.randn(size, dtype=dtype, device=device)
1049        x = torch.as_strided(x, size, stride)
1050
1051        ref = torch.ops.aten._softmax(x, -1, False)
1052        res = torch._decomp.decompositions._softmax(x, -1, False)
1053        self.assertEqual(ref.stride(), res.stride())
1054
1055    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1056    @onlyNativeDeviceTypes
1057    @skipIfCrossRef
1058    def test_contiguous_log_softmax(self, device):
1059        size = (2, 4, 3, 3)
1060        stride = (9, 18, 3, 1)
1061
1062        dtype = torch.float32
1063        x = torch.randn(size, dtype=dtype, device=device)
1064        x = torch.as_strided(x, size, stride)
1065
1066        ref = torch.ops.aten._log_softmax(x, -1, False)
1067        res = torch._decomp.decompositions._log_softmax(x, -1, False)
1068        self.assertEqual(ref.stride(), res.stride())
1069
1070    @onlyCUDA
1071    def test_exponential_non_inf(self, device):
1072        inp = torch.empty((4, 400, 256), device=device)
1073
1074        with torch._dynamo.utils.preserve_rng_state():
1075            exp_ref = inp.exponential_()
1076        exp = torch._refs.exponential(inp)
1077
1078        self.assertEqual(exp, exp_ref)
1079        self.assertFalse(exp.isinf().any())
1080
1081    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1082    @skipIfCrossRef
1083    @onlyCUDA
1084    def test_amp_batch_norm_backward(self):
1085        device = "cuda"
1086        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
1087        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
1088        weight = torch.randn((2,), dtype=torch.float32, device=device)
1089        rmean = torch.randn((2,), dtype=torch.float32, device=device)
1090        rvar = torch.randn((2,), dtype=torch.float32, device=device)
1091        mean = torch.randn((0,), dtype=torch.float32, device=device)
1092
1093        ref = torch.ops.aten.native_batch_norm_backward(
1094            grad_out,
1095            x,
1096            weight,
1097            rmean,
1098            rvar,
1099            mean,
1100            mean,
1101            False,
1102            1e-05,
1103            [True, True, True],
1104        )
1105        res = torch._decomp.decompositions.native_batch_norm_backward(
1106            grad_out,
1107            x,
1108            weight,
1109            rmean,
1110            rvar,
1111            mean,
1112            mean,
1113            False,
1114            1e-05,
1115            [True, True, True],
1116        )
1117        for a, b in zip(ref, res):
1118            self.assertEqual(a.stride(), b.stride())
1119            self.assertEqual(a.dtype, b.dtype)
1120
1121    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1122    @onlyNativeDeviceTypes
1123    @skipIfCrossRef
1124    def test_elu_backward(self, device):
1125        size = (2, 4, 3, 3)
1126        dtype = torch.float32
1127        grad_out = torch.randn(size, dtype=dtype, device=device)
1128        out = torch.randn(size, dtype=dtype, device=device)
1129
1130        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
1131        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
1132        self.assertEqual(ref, res)
1133
1134    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1135    @onlyNativeDeviceTypes
1136    @skipIfCrossRef
1137    def test_threshold_backward_dtype(self, device):
1138        grad = torch.randint(10, (4,), device=device)
1139        input_tensor = torch.randint(10, (4,), device=device)
1140
1141        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
1142        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
1143        self.assertEqual(ref.dtype, res.dtype)
1144
1145    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1146    @onlyNativeDeviceTypes
1147    @skipIfCrossRef
1148    def test_weight_norm_interface(self, device):
1149        g = torch.randn((3, 10, 10), device=device)
1150        v = torch.randn((1, 1, 10), device=device)
1151
1152        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
1153        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
1154        self.assertTrue(torch.allclose(ref[0], res[0]))
1155        self.assertTrue(torch.allclose(ref[1], res[1]))
1156
1157        inp = torch.rand([30, 10], device=device)
1158        inp2 = torch.rand([30, 1], device=device)
1159
1160        self.assertEqual(
1161            torch.ops.aten._weight_norm_interface(inp, inp2),
1162            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
1163        )
1164
1165    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
1166    @onlyCPU
1167    @skipIfCrossRef
1168    @skipOps(
1169        "DecompOneOffTests",
1170        "test_sdpa",
1171        [
1172            xfail(
1173                "nn.functional.scaled_dot_product_attention",
1174                dtypes=[torch.half],
1175            ),
1176        ],
1177    )
1178    @ops(_sdpa_op_info)
1179    def test_sdpa(self, device, dtype, op):
1180        # SDPA doesn't support float16; this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
1181        # add support for float16 over there, we should update this test as well.
1182
1183        class ScaledDotProductAttention(torch.nn.Module):
1184            def __init__(self) -> None:
1185                super().__init__()
1186
1187            def forward(
1188                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
1189            ):
1190                attn_output = op(
1191                    query_layer,
1192                    key_layer,
1193                    value_layer,
1194                    attn_mask=mask,
1195                    dropout_p=0.0,
1196                    is_causal=is_causal,
1197                )
1198                return attn_output
1199
1200        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1201        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1202        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
1203        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]
1204
1205        rtol, atol = dtype_precisions[dtype]
1206
1207        for mask in masks:
1208            is_causal = mask is None
1209            attention = ScaledDotProductAttention()
1210            decomposed_res = (
1211                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
1212                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
1213                )
1214            )
1215            eager_res = op(
1216                query_layer,
1217                key_layer,
1218                value_layer,
1219                attn_mask=mask,
1220                dropout_p=0.0,
1221                is_causal=is_causal,
1222            )
1223
1224            self.assertTrue(
1225                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
1226            )
1227
1228
1229instantiate_device_type_tests(DecompOneOffTests, globals())
1230
1231
1232class HasDecompTest(TestCase):
1233    def setUp(self):
1234        super().setUp()
1235        self.maxDiff = None
1236
1237    @staticmethod
1238    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
1239        has_tensor_arg = any(
1240            "Tensor" in str(a.type)
1241            for a in itertools.chain(op._schema.arguments, op._schema.returns)
1242        )
1243        if not has_tensor_arg:
1244            return False
1245
1246        try:
1247            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
1248            return not op.has_kernel_for_dispatch_key(
1249                DispatchKey.CompositeImplicitAutograd
1250            )
1251        except RuntimeError as e:
1252            # has_key fails for some jit-registered ops, which shouldn't be
1253            # relevant here anyway
1254            if "does not exist" in str(e):
1255                return False
1256            raise
1257
1258    def test_has_decomposition(self):
1259        def all_aten_overloads():
1260            for name in torch._C._dispatch_get_all_op_names():
1261                if not name.startswith("aten::"):
1262                    continue
1263
1264                name = name[6:]
1265                if "." in name:
1266                    packet_name, overload_name = name.split(".")
1267                else:
1268                    packet_name, overload_name = name, "default"
1269
1270                packet = getattr(aten, packet_name)
1271                assert isinstance(packet, torch._ops.OpOverloadPacket)
1272                op = getattr(packet, overload_name)
1273                yield op
1274
1275        # This is for operators that are only registered in some CI
1276        # configurations, so they would otherwise cause the test to fail
1277        allow_list = {aten.get_gradients.default}
1278
1279        overloads_wanting_decomp = {
1280            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
1281        }
1282        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
1283        ops_missing_decomp -= allow_list
1284        self.assertExpected(
1285            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
1286        )
1287
1288    def test_aten_core_operators(self):
1289        # If a decomposition isn't included in the core decompositions,
1290        # then it must decompose a core ATen operator.
1291        #
1292        # See NOTE [Core ATen Ops]
1293        #
1294        # If this test fails then either:
1295        # - Add the decomposition to torch._decomp.core_aten_decompositions,
1296        #   if the decomposition should be used by inductor (not a core operator).
1297        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
1298        #   core ATen operators (and inductor will not use the decomposition).
1299
1300        # Some decompositions are registered for CompositeImplicitAutograd
1301        # operators, which never appear in AOTAutograd's graph so are never used.
1302        useful_decomps = {
1303            op
1304            for op in decomposition_table.keys()
1305            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
1306        }
1307        core_decomps = torch._decomp.core_aten_decompositions().keys()
1308        core_aten_ops = useful_decomps - core_decomps
1309        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
1310
1311
1312if __name__ == "__main__":
1313    run_tests()
1314