xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/opinfo/definitions/_masked.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import unittest
4from collections.abc import Sequence
5from functools import partial
6from typing import List
7
8import numpy as np
9
10import torch
11from torch.testing import make_tensor
12from torch.testing._internal.common_device_type import tol, toleranceOverride
13from torch.testing._internal.common_dtype import (
14    all_types_and,
15    all_types_and_complex_and,
16    complex_types,
17    floating_and_complex_types_and,
18    floating_types_and,
19    integral_types,
20)
21from torch.testing._internal.opinfo.core import (
22    DecorateInfo,
23    gradcheck_wrapper_masked_operation,
24    gradcheck_wrapper_masked_pointwise_operation,
25    M,
26    OpInfo,
27    ReductionOpInfo,
28    S,
29    sample_inputs_reduction,
30    SampleInput,
31)
32from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy
33
34
35# Used for log_softmax, softmax, softmin
36def sample_inputs_softmax_variant(
37    op_info,
38    device,
39    dtype,
40    requires_grad,
41    with_dtype=False,
42    use_zero_dimensions=True,
43    **kwargs,
44):
45    make_arg = partial(
46        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
47    )
48    cases = [
49        ((S,), (0,)),
50        ((S, S), (0,)),
51        ((S, S), (1,)),
52        ((S, S), (-1,)),
53        ((S, M, S), (2,)),
54        *([((S, 0, 0), (-1,))] if use_zero_dimensions else []),
55    ]
56    kwargs = dict(dtype=torch.float64) if with_dtype else None
57
58    # PyTorch on XLA throws an error when a dim argument is passed for a 0d tensor.
59    # See https://github.com/pytorch/xla/issues/3061 for more details.
60    if torch.device(device).type != "xla":
61        cases.append(((), (0,)))
62
63    return (
64        SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
65    )
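# For example, the case ``((S, S), (1,))`` above becomes
# ``SampleInput(make_arg((S, S)), args=(1,), kwargs=None)`` (or
# ``kwargs=dict(dtype=torch.float64)`` when ``with_dtype=True``), i.e. the
# operator is applied along dim 1.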
66
67
68def _generate_masked_op_mask(input_shape, device, **kwargs):
69    make_arg = partial(
70        make_tensor, dtype=torch.bool, device=device, requires_grad=False
71    )
72    yield None
73    yield make_arg(input_shape)
74    if len(input_shape) > 2:
75        # broadcast last mask dimension:
76        yield make_arg(input_shape[:-1] + (1,))
77        # broadcast middle mask dimension:
78        yield make_arg(input_shape[:1] + (1,) + input_shape[2:])
79        # broadcast first mask dimension:
80        yield make_arg((1,) + input_shape[1:])
81        # mask.ndim < input.ndim
82        yield make_arg(input_shape[1:])
83        # mask.ndim == 1
84        yield make_arg(input_shape[-1:])
85        # masks that require broadcasting of inputs (mask.ndim >
86        # input.ndim) are not supported; however, we may
87        # reconsider this if there is demand for this kind of
88        # degenerate case.
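# For reference, for a 3-d input_shape == (S, M, S) the generator above yields
# None followed by masks of shapes (S, M, S), (S, M, 1), (S, 1, S), (1, M, S),
# (M, S), and (S,).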
89
90
91def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
92    """Sample inputs for masked reduction operators.
93
94    A masked reduction operator is a reduction operator with a trailing
95    optional mask argument. A mask is a bool tensor with the same
96    shape as the input or a shape that is broadcastable to the input shape.
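
    For illustration, a minimal masked reduction call (a sketch assuming
    ``torch.masked.sum`` semantics, where masked-out elements are ignored)::

        t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        mask = torch.tensor([[True, False], [True, True]])
        torch.masked.sum(t, 1, mask=mask)  # -> tensor([1., 7.])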
97    """
98    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
99
100    for sample_input in sample_inputs_reduction(
101        op_info, device, dtype, requires_grad, **kwargs
102    ):
103        for mask in _generate_masked_op_mask(
104            sample_input.input.shape, device, **kwargs
105        ):
106            sample_input_args, sample_input_kwargs = sample_input.args, dict(
107                mask=mask, **sample_input.kwargs
108            )
109            yield SampleInput(
110                sample_input.input.detach().requires_grad_(requires_grad),
111                args=sample_input_args,
112                kwargs=sample_input_kwargs,
113            )
114            if (
115                not requires_grad
116                and dtype.is_floating_point
117                and sample_input.input.ndim == 2
118                and mask is not None
119                and mask.shape == sample_input.input.shape
120            ):
121                for v in [torch.inf, -torch.inf, torch.nan]:
122                    t = sample_input.input.detach()
123                    t.diagonal(0, -2, -1).fill_(v)
124                    yield SampleInput(
125                        t.requires_grad_(requires_grad),
126                        args=sample_input_args,
127                        kwargs=sample_input_kwargs,
128                    )
129
130
131def sample_inputs_sparse_coo_masked_reduction(
132    op_info, device, dtype, requires_grad, **kwargs
133):
134    """Sample inputs for masked reduction operators that support inputs
135    with sparse COO layouts.
136    """
137    if op_info.supports_sparse:
138        op_name = op_info.name.replace("masked.", "")
139        for sample_input in sample_inputs_masked_reduction(
140            op_info, device, dtype, requires_grad, **kwargs
141        ):
142            mask = sample_input.kwargs.get("mask")
143            if mask is not None:
144                sample_input_kwargs = sample_input.kwargs.copy()
145                sample_input_kwargs.update(mask=mask.to_sparse())
146                yield SampleInput(
147                    sample_input.input.to_sparse(),
148                    args=sample_input.args,
149                    kwargs=sample_input_kwargs,
150                )
151            else:
152                if op_name in {"prod", "amax", "amin"}:
153                    # FIXME: for now reductions with non-zero reduction identity and
154                    # unspecified mask are not supported for sparse COO
155                    # tensors, see torch.masked.prod implementation
156                    # for details.
157                    continue
158                yield SampleInput(
159                    sample_input.input.to_sparse(),
160                    args=sample_input.args,
161                    kwargs=sample_input.kwargs,
162                )
163
164
165def sample_inputs_sparse_csr_masked_reduction(
166    op_info, device, dtype, requires_grad, **kwargs
167):
168    """Sample inputs for masked reduction operators that support inputs
169    with sparse CSR layouts.
170    """
171    if op_info.supports_sparse_csr:
172        op_name = op_info.name.replace("masked.", "")
173        for sample_input in sample_inputs_masked_reduction(
174            op_info, device, dtype, requires_grad, **kwargs
175        ):
176            if not (
177                sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim")
178            ):
179                # - sparse CSR tensors are always 2-D tensors
180                # - masked reductions on CSR tensors are defined only if keepdim is True.
181                continue
182            mask = sample_input.kwargs.get("mask")
183            if mask is not None:
184                sample_input_kwargs = sample_input.kwargs.copy()
185                sample_input_kwargs.update(mask=mask.to_sparse_csr())
186                new_sample = SampleInput(
187                    sample_input.input.to_sparse_csr(),
188                    args=sample_input.args,
189                    kwargs=sample_input_kwargs,
190                )
191            else:
192                if op_name in ["prod", "amax", "amin", "mean"]:
193                    # reductions with non-zero reduction identity and
194                    # unspecified mask are not supported for sparse CSR
195                    # tensors, see torch.masked.prod implementation
196                    # for details.
197                    continue
198                new_sample = SampleInput(
199                    sample_input.input.to_sparse_csr(),
200                    args=sample_input.args,
201                    kwargs=sample_input.kwargs,
202                )
203            yield new_sample
204            if sample_input.kwargs["dim"] == 0:
205                # Reductions of CSR tensors use different implementations for
206                # inner and/or outer dimensions. So, as a minimum for testing the
207                # CSR implementations, the following kwargs must be generated:
208                #   dict(dim=0, keepdim=True)
209                #   dict(dim=1, keepdim=True)
210                #   dict(dim=(0, 1), keepdim=True)
211                # Here we generate the dim=1 case from the dim=0 case.
212                sample_input_kwargs = new_sample.kwargs.copy()
213                sample_input_kwargs.update(dim=1)
214                yield SampleInput(
215                    new_sample.input.clone(),
216                    args=sample_input.args,
217                    kwargs=sample_input_kwargs,
218                )
219
220
221def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
222    """Sample inputs for masked norm."""
223    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
224        for sample_input in sample_inputs_masked_reduction(
225            op_info, device, dtype, requires_grad, **kwargs
226        ):
227            sample_input_args, sample_input_kwargs = (
228                ord,
229            ) + sample_input.args, sample_input.kwargs.copy()
230            yield SampleInput(
231                sample_input.input.clone().requires_grad_(requires_grad),
232                args=sample_input_args,
233                kwargs=sample_input_kwargs,
234            )
235
236
237def reference_masked_std_var(
238    numpy_fn,
239):
240    ref = reference_reduction_numpy(numpy_fn)
241
242    # Translate unbiased or correction arguments into ddof
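    #   unbiased=True (or unspecified) -> ddof=1, unbiased=False -> ddof=0,
    #   correction=c -> ddof=c (correction takes precedence over unbiased).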
243    def func(
244        input,
245        dim=None,
246        unbiased=None,
247        *,
248        correction=None,
249        **kwargs,
250    ):
251        ddof = 1
252        if unbiased is not None:
253            ddof = 1 if unbiased else 0
254        if correction is not None:
255            ddof = correction
256
257        if isinstance(dim, Sequence):
258            dim = tuple(dim)
259
260        return ref(input, dim, ddof=ddof, **kwargs)
261
262    return func
263
264
265def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
266    """Sample inputs for masked std/var."""
267    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
268    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var
269
270    def masked_samples():
271        for sample_input in sample_inputs_std_var(
272            op_info, device, dtype, requires_grad, **kwargs
273        ):
274            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
275                continue  # masked.{std, var} doesn't support `.var(unbiased)`
276
277            for mask in _generate_masked_op_mask(
278                sample_input.input.shape, device, **kwargs
279            ):
280                sample_input_args, sample_input_kwargs = sample_input.args, dict(
281                    mask=mask, **sample_input.kwargs
282                )
283                yield SampleInput(
284                    sample_input.input.detach().requires_grad_(requires_grad),
285                    args=sample_input_args,
286                    kwargs=sample_input_kwargs,
287                )
288                if (
289                    not requires_grad
290                    and dtype.is_floating_point
291                    and sample_input.input.ndim == 2
292                    and mask is not None
293                    and mask.shape == sample_input.input.shape
294                ):
295                    for v in [torch.inf, -torch.inf, torch.nan]:
296                        t = sample_input.input.detach()
297                        t.diagonal(0, -2, -1).fill_(v)
298                        yield SampleInput(
299                            t.requires_grad_(requires_grad),
300                            args=sample_input_args,
301                            kwargs=sample_input_kwargs,
302                        )
303
304    for sample_input in masked_samples():
305        correction = sample_input.kwargs.get("correction")
306        if correction is None:
307            correction = int(sample_input.kwargs.get("unbiased", True))
308
309        dim = sample_input.kwargs.get("dim", None)
310
311        if sample_input.kwargs.get("mask") is None:
312            orig_count = torch.masked.sum(
313                torch.ones(sample_input.input.shape, dtype=torch.int64),
314                dim,
315                keepdim=True,
316            )
317        else:
318            inmask = torch.masked._input_mask(
319                sample_input.input, *sample_input.args, **sample_input.kwargs
320            )
321            orig_count = torch.masked.sum(
322                inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
323                dim,
324                keepdim=True,
325                mask=inmask,
326            )
327        if orig_count.min() <= correction + 1:
328            # Skip samples that lead to nans in var computation
329            continue
330
331        yield sample_input
332
333
334def sample_inputs_masked_softmax(
335    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
336):
337    """Sample inputs for masked softmax, log_softmax, and softmin.
338
339    A masked normalization operator is a reduction operator with a
340    trailing optional mask argument. A mask is a bool tensor with the
341    same shape as the input or a shape that is broadcastable to the
342    input shape.
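
    For illustration (a sketch assuming ``torch.masked.softmax`` semantics,
    where masked-out elements behave as if filled with ``-inf``)::

        t = torch.zeros(3)
        mask = torch.tensor([True, False, True])
        torch.masked.softmax(t, 0, mask=mask)  # ~ tensor([0.5, 0.0, 0.5])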
343    """
344    for sample_input in sample_inputs_softmax_variant(
345        op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
346    ):
347        for mask in _generate_masked_op_mask(
348            sample_input.input.shape, device, **kwargs
349        ):
350            yield SampleInput(
351                sample_input.input.clone().requires_grad_(requires_grad),
352                *sample_input.args,
353                mask=mask,
354                **sample_input.kwargs,
355            )
356
357
358def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
359    """Sample inputs for masked cumsum and cumprod."""
361    for sample_input in sample_inputs_softmax_variant(
362        op_info, device, dtype, requires_grad, **kwargs
363    ):
364        for mask in _generate_masked_op_mask(
365            sample_input.input.shape, device, **kwargs
366        ):
367            if type(mask) != torch.Tensor:
368                continue
369            sample_input_args, sample_input_kwargs = sample_input.args, dict(
370                mask=mask, **sample_input.kwargs
371            )
372            if "keepdim" in sample_input_kwargs:
373                sample_input_kwargs.pop("keepdim")
374            # a dim argument is required; take it from args or promote it from kwargs
375            if sample_input_args:
376                dim = sample_input.args[0]
377            else:
378                if "dim" not in sample_input_kwargs:
379                    continue
380                dim = sample_input_kwargs.pop("dim")
381                sample_input_args = (dim,)
382            yield SampleInput(
383                sample_input.input.clone().requires_grad_(requires_grad),
384                *sample_input_args,
385                **sample_input_kwargs,
386            )
387
388
389def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
390    """Sample inputs for masked logaddexp."""
391    shapes = [(S,), (S, S), (S, M, S)]
392    input_mask_lists = [
393        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
394    ]
395    other_mask_lists = [
396        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
397    ]
398
399    make_arg = partial(
400        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
401    )
402    for shape, input_masks, other_masks in zip(
403        shapes, input_mask_lists, other_mask_lists
404    ):
405        for input_mask, other_mask in zip(input_masks, other_masks):
406            yield SampleInput(
407                make_arg(shape),
408                make_arg(shape),
409                input_mask=input_mask,
410                other_mask=other_mask,
411            )
412
413
414def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
415    """Sample inputs for masked normalize."""
416    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
417        for sample_input in sample_inputs_softmax_variant(
418            op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs
419        ):
420            yield SampleInput(
421                sample_input.input.clone().requires_grad_(requires_grad),
422                ord,
423                *sample_input.args,
424                **sample_input.kwargs,
425            )
426
427
428op_db: List[OpInfo] = [
429    ReductionOpInfo(
430        "masked.sum",
431        ref=reference_reduction_numpy(np.sum),
432        method_variant=None,
433        identity=0,
434        nan_policy="propagate",
435        supports_out=False,
436        supports_forward_ad=True,
437        supports_fwgrad_bwgrad=True,
438        supports_sparse=True,
439        supports_sparse_csr=True,
440        promotes_int_to_int64=True,
441        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
442        skips=(
443            DecorateInfo(
444                unittest.skip("Failing on some jobs"),
445                "TestReductions",
446                "test_reference_masked",
447                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
448            ),
449            DecorateInfo(
450                unittest.expectedFailure,
451                "TestNormalizeOperators",
452                "test_normalize_operator_exhaustive",
453            ),
454            # FIXME: sum reduces all dimensions when dim=[]
455            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
456            DecorateInfo(
457                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
458            ),
459            # RuntimeError: undefined value tensor
460            DecorateInfo(
461                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
462            ),
463        ),
464        decorators=[
465            DecorateInfo(
466                toleranceOverride(
467                    {
468                        torch.bfloat16: tol(atol=1e-03, rtol=5e-2),
469                        torch.float16: tol(atol=1e-03, rtol=5e-3),
470                    }
471                ),
472                "TestReductions",
473                "test_reference_masked",
474            ),
475            DecorateInfo(
476                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
477                "TestReductions",
478                "test_ref_small_input",
479            ),
480            DecorateInfo(
481                toleranceOverride(
482                    {
483                        torch.bfloat16: tol(atol=0.1, rtol=0.1),
484                        torch.float16: tol(atol=5e-3, rtol=5e-3),
485                    }
486                ),
487                "TestMasked",
488                "test_mask_layout",
489            ),
490        ],
491        sample_inputs_func=sample_inputs_masked_reduction,
492        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
493        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
494    ),
495    ReductionOpInfo(
496        "masked.prod",
497        ref=prod_numpy,
498        method_variant=None,
499        identity=1,
500        nan_policy="propagate",
501        # https://github.com/pytorch/pytorch/issues/80411
502        gradcheck_fast_mode=True,
503        supports_out=False,
504        supports_forward_ad=True,
505        supports_fwgrad_bwgrad=True,
506        supports_sparse=True,
507        supports_sparse_csr=True,
508        promotes_int_to_int64=True,
509        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
510        skips=(
511            DecorateInfo(
512                unittest.expectedFailure,
513                "TestNormalizeOperators",
514                "test_normalize_operator_exhaustive",
515            ),
516            DecorateInfo(
517                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
518            ),
519            DecorateInfo(
520                unittest.skip("Failing on some jobs"),
521                "TestReductions",
522                "test_reference_masked",
523                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
524            ),
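            # FIXME: the following DecorateInfo appears to be missing its
            # decorator (first positional argument), so it matches no test
            # class and currently has no effect.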
525            DecorateInfo(
526                "TestReductions",
527                "test_ref_small_input",
528                dtypes=(torch.int8, torch.int16, torch.int32),
529            ),
530            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
531            DecorateInfo(
532                unittest.skip("Skipped!"),
533                "TestMasked",
534                "test_mask_layout",
535                device_type="cuda",
536                dtypes=(torch.bool, *integral_types(), *complex_types()),
537            ),
538        ),
539        decorators=[
540            DecorateInfo(
541                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
542                "TestReductions",
543                "test_reference_masked",
544            ),
545            DecorateInfo(
546                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
547                "TestReductions",
548                "test_ref_duplicate_values",
549            ),
550            DecorateInfo(
551                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
552                "TestReductions",
553                "test_ref_small_input",
554            ),
555            DecorateInfo(
556                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
557                "TestMasked",
558                "test_mask_layout",
559                device_type="cpu",
560            ),
561            DecorateInfo(
562                toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
563                "TestOperators",
564                "test_jvp",
565                device_type="cuda",
566            ),
567        ],
568        sample_inputs_func=sample_inputs_masked_reduction,
569        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
570        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
571    ),
572    OpInfo(
573        "masked.cumsum",
574        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
575        method_variant=None,
576        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
577        gradcheck_fast_mode=True,
578        supports_out=False,
579        supports_forward_ad=True,
580        supports_fwgrad_bwgrad=True,
581        skips=(
582            DecorateInfo(
583                unittest.expectedFailure,
584                "TestNormalizeOperators",
585                "test_normalize_operator_exhaustive",
586            ),
587            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
588            DecorateInfo(
589                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
590            ),
591        ),
592        # Can reuse the same inputs; dim is required in both
593        sample_inputs_func=sample_inputs_masked_cumops,
594        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
595    ),
596    OpInfo(
597        "masked.cumprod",
598        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
599        method_variant=None,
600        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
601        gradcheck_fast_mode=True,
602        supports_out=False,
603        supports_forward_ad=True,
604        supports_fwgrad_bwgrad=True,
605        skips=(
606            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
607            DecorateInfo(
608                unittest.expectedFailure,
609                "TestNormalizeOperators",
610                "test_normalize_operator_exhaustive",
611            ),
612            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
613            DecorateInfo(
614                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
615            ),
616            DecorateInfo(
617                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
618                "TestCompositeCompliance",
619                "test_backward",
620                device_type="cuda",
621            ),
622            DecorateInfo(
623                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
624                "TestInductorOpInfo",
625                "test_comprehensive",
626                device_type="cuda",
627            ),
628        ),
629        # Can reuse the same inputs; dim is required in both
630        sample_inputs_func=sample_inputs_masked_cumops,
631        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
632    ),
633    ReductionOpInfo(
634        "masked.amax",
635        nan_policy="propagate",
636        supports_out=False,
637        dtypes=all_types_and(torch.float16, torch.bfloat16),
638        supports_sparse=True,
639        supports_forward_ad=True,
640        supports_fwgrad_bwgrad=True,
641        supports_sparse_csr=True,
642        ref=reference_reduction_numpy(np.amax),
643        skips=(
644            DecorateInfo(
645                unittest.expectedFailure,
646                "TestNormalizeOperators",
647                "test_normalize_operator_exhaustive",
648            ),
649            # FIXME: amax reduces all dimensions when dim=[]
650            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
651            DecorateInfo(
652                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
653            ),
654            # RuntimeError: Unknown builtin op: aten::iinfo
655            DecorateInfo(
656                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
657            ),
658            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
659            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
660            DecorateInfo(
661                unittest.skip("Skipped!"),
662                "TestMasked",
663                "test_mask_layout",
664                dtypes=(torch.bool, *integral_types(), *complex_types()),
665            ),
666        ),
667        sample_inputs_func=sample_inputs_masked_reduction,
668        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
669        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
670        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
671    ),
672    ReductionOpInfo(
673        "masked.amin",
674        nan_policy="propagate",
675        supports_out=False,
676        supports_forward_ad=True,
677        supports_fwgrad_bwgrad=True,
678        dtypes=all_types_and(torch.float16, torch.bfloat16),
679        supports_sparse=True,
680        supports_sparse_csr=True,
681        ref=reference_reduction_numpy(np.amin),
682        skips=(
683            DecorateInfo(
684                unittest.expectedFailure,
685                "TestNormalizeOperators",
686                "test_normalize_operator_exhaustive",
687            ),
688            # FIXME: amin reduces all dimensions when dim=[]
689            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
690            DecorateInfo(
691                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
692            ),
693            # RuntimeError: Unknown builtin op: aten::iinfo
694            DecorateInfo(
695                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
696            ),
697            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
698            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
699            DecorateInfo(
700                unittest.skip("Skipped!"),
701                "TestMasked",
702                "test_mask_layout",
703                dtypes=(torch.bool, *integral_types(), *complex_types()),
704            ),
705        ),
706        sample_inputs_func=sample_inputs_masked_reduction,
707        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
708        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
709        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
710    ),
711    ReductionOpInfo(
712        "masked.argmax",
713        supports_out=False,
714        supports_multiple_dims=False,
715        supports_autograd=False,
716        dtypes=all_types_and(torch.float16, torch.bfloat16),
717        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
718        skips=(
719            DecorateInfo(
720                unittest.expectedFailure,
721                "TestNormalizeOperators",
722                "test_normalize_operator_exhaustive",
723            ),
724            # initial is not a keyword for argmax
725            DecorateInfo(
726                unittest.expectedFailure, "TestReductions", "test_reference_masked"
727            ),
728            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
729            DecorateInfo(
730                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
731            ),
732        ),
733        sample_inputs_func=sample_inputs_masked_reduction,
734        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
735    ),
736    ReductionOpInfo(
737        "masked.argmin",
738        supports_out=False,
739        supports_multiple_dims=False,
740        supports_autograd=False,
741        dtypes=all_types_and(torch.float16, torch.bfloat16),
742        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
743        skips=(
744            DecorateInfo(
745                unittest.expectedFailure,
746                "TestNormalizeOperators",
747                "test_normalize_operator_exhaustive",
748            ),
749            # initial is not a keyword for argmin
750            DecorateInfo(
751                unittest.expectedFailure, "TestReductions", "test_reference_masked"
752            ),
753            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
754            DecorateInfo(
755                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
756            ),
757        ),
758        sample_inputs_func=sample_inputs_masked_reduction,
759        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
760    ),
761    ReductionOpInfo(
762        "masked.mean",
763        ref=reference_reduction_numpy(np.mean)
764        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
765        else None,
766        method_variant=None,
767        nan_policy="propagate",
768        supports_out=False,
769        supports_sparse_csr=True,
770        supports_forward_ad=True,
771        supports_fwgrad_bwgrad=True,
772        promotes_int_to_float=True,
773        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
774        skips=(
775            DecorateInfo(
776                unittest.expectedFailure,
777                "TestReductions",
778                "test_ref_duplicate_values",
779                dtypes=(torch.bool,),
780            ),
781            DecorateInfo(
782                unittest.expectedFailure,
783                "TestReductions",
784                "test_reference_masked",
785                dtypes=(torch.bool,),
786            ),
787            DecorateInfo(
788                unittest.expectedFailure,
789                "TestReductions",
790                "test_ref_small_input",
791                dtypes=(torch.bool,),
792            ),
793            DecorateInfo(
794                unittest.expectedFailure,
795                "TestNormalizeOperators",
796                "test_normalize_operator_exhaustive",
797            ),
798            # FIXME: sum reduces all dimensions when dim=[]
799            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
800            DecorateInfo(
801                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
802            ),
803            # RuntimeError: undefined value tensor
804            DecorateInfo(
805                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
806            ),
807            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
808            DecorateInfo(
809                unittest.skip("Skipped!"),
810                "TestMasked",
811                "test_mask_layout",
812                dtypes=(torch.bool, *integral_types(), *complex_types()),
813            ),
814        ),
815        decorators=[
816            DecorateInfo(
817                toleranceOverride(
818                    {
819                        torch.bfloat16: tol(atol=1e-03, rtol=0.05),
820                        torch.float16: tol(atol=1e-03, rtol=1e-03),
821                    }
822                ),
823                "TestReductions",
824                "test_reference_masked",
825            ),
826            DecorateInfo(
827                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
828                "TestReductions",
829                "test_ref_small_input",
830            ),
831            DecorateInfo(
832                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}),
833                "TestSparseCompressed",
834                "test_consistency",
835                device_type="cuda",
836            ),
837        ],
838        sample_inputs_func=sample_inputs_masked_reduction,
839        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
840        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
841    ),
842    OpInfo(
843        "masked.median",
844        dtypes=floating_types_and(torch.bfloat16, torch.float16),
845        method_variant=None,
846        supports_out=False,
847        supports_forward_ad=True,
848        supports_fwgrad_bwgrad=True,
849        skips=(
850            DecorateInfo(
851                unittest.expectedFailure,
852                "TestNormalizeOperators",
853                "test_normalize_operator_exhaustive",
854            ),
855            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
856            DecorateInfo(
857                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
858            ),
859        ),
860        sample_inputs_func=partial(
861            sample_inputs_masked_softmax, use_zero_dimensions=False
862        ),
863        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
864    ),
865    ReductionOpInfo(
866        "masked.norm",
867        identity=0,
868        method_variant=None,
869        nan_policy="propagate",
870        supports_out=False,
871        promotes_int_to_float=True,
872        dtypes=floating_types_and(torch.float16, torch.bfloat16),
873        skips=(
874            DecorateInfo(
875                unittest.expectedFailure,
876                "TestNormalizeOperators",
877                "test_normalize_operator_exhaustive",
878            ),
879            # FIXME: sum reduces all dimensions when dim=[]
880            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
881            DecorateInfo(
882                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
883            ),
884            # torch.jit.frontend.NotSupportedError: Compiled functions
885            # can't take variable number of arguments or use
886            # keyword-only arguments with defaults
887            DecorateInfo(
888                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
889            ),
890        ),
891        supports_forward_ad=True,
892        supports_fwgrad_bwgrad=True,
893        sample_inputs_func=sample_inputs_masked_norm,
894        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
895    ),
896    ReductionOpInfo(
897        "masked.var",
898        ref=reference_masked_std_var(np.var)
899        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
900        else None,
901        method_variant=None,
902        nan_policy="propagate",
903        supports_out=False,
904        supports_forward_ad=True,
905        supports_fwgrad_bwgrad=True,
906        # See https://github.com/pytorch/pytorch/pull/78358
907        check_batched_forward_grad=False,
908        promotes_int_to_float=True,
909        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
910        skips=(
911            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
912            DecorateInfo(
913                unittest.skip("Skipped!"),
914                "TestSchemaCheckModeOpInfo",
915                "test_schema_correctness",
916                dtypes=(torch.complex64, torch.complex128),
917            ),
918            DecorateInfo(
919                unittest.expectedFailure,
920                "TestNormalizeOperators",
921                "test_normalize_operator_exhaustive",
922            ),
923            # FIXME: sum reduces all dimensions when dim=[]
924            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
925            DecorateInfo(
926                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
927            ),
928            # RuntimeError: undefined value tensor
929            DecorateInfo(
930                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
931            ),
932        ),
933        decorators=[
934            DecorateInfo(
935                toleranceOverride(
936                    {
937                        torch.float16: tol(atol=1e-02, rtol=1e-02),
938                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
939                    }
940                ),
941                "TestReductions",
942                "test_reference_masked",
943            ),
944            DecorateInfo(
945                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
946                "TestReductions",
947                "test_ref_small_input",
948            ),
949            DecorateInfo(
950                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
951                "TestMasked",
952                "test_reference_masked",
953            ),
954            DecorateInfo(
955                toleranceOverride(
956                    {
957                        torch.float16: tol(atol=1e-02, rtol=1e-02),
958                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
959                    }
960                ),
961                "TestMasked",
962                "test_reference_masked",
963            ),
964            DecorateInfo(
965                toleranceOverride(
966                    {
967                        torch.float16: tol(atol=4e-5, rtol=2e-2),
968                    }
969                ),
970                "TestInductorOpInfo",
971                "test_comprehensive",
972                device_type="cuda",
973            ),
974        ],
975        sample_inputs_func=sample_inputs_masked_std_var,
976        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
977        check_batched_grad=True,
978    ),
979    ReductionOpInfo(
980        "masked.std",
981        ref=reference_masked_std_var(np.std)
982        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
983        else None,
984        method_variant=None,
985        nan_policy="propagate",
986        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
987        gradcheck_fast_mode=True,
988        supports_out=False,
989        supports_forward_ad=True,
990        supports_fwgrad_bwgrad=True,
991        # See https://github.com/pytorch/pytorch/pull/78358
992        check_batched_forward_grad=False,
993        promotes_int_to_float=True,
994        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
995        skips=(
996            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
997            DecorateInfo(
998                unittest.skip("Skipped!"),
999                "TestSchemaCheckModeOpInfo",
1000                "test_schema_correctness",
1001                dtypes=(torch.complex64, torch.complex128),
1002            ),
1003            DecorateInfo(
1004                unittest.expectedFailure,
1005                "TestNormalizeOperators",
1006                "test_normalize_operator_exhaustive",
1007            ),
1008            # FIXME: sum reduces all dimensions when dim=[]
1009            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
1010            DecorateInfo(
1011                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
1012            ),
1013            # RuntimeError: undefined value tensor
1014            DecorateInfo(
1015                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1016            ),
1017        ),
1018        decorators=[
1019            DecorateInfo(
1020                toleranceOverride(
1021                    {
1022                        torch.bfloat16: tol(atol=1e-02, rtol=1e-02),
1023                        torch.float16: tol(atol=1e-02, rtol=1e-02),
1024                    }
1025                ),
1026                "TestReductions",
1027                "test_reference_masked",
1028            ),
1029            DecorateInfo(
1030                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
1031                "TestReductions",
1032                "test_ref_small_input",
1033            ),
1034            DecorateInfo(
1035                toleranceOverride(
1036                    {
1037                        torch.float16: tol(atol=1e-02, rtol=1e-02),
1038                        torch.bfloat16: tol(atol=5e-03, rtol=5e-04),
1039                    }
1040                ),
1041                "TestMasked",
1042                "test_reference_masked",
1043            ),
1044        ],
1045        sample_inputs_func=sample_inputs_masked_std_var,
1046        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1047        check_batched_grad=True,
1048    ),
1049    OpInfo(
1050        "masked.softmax",
1051        method_variant=None,
1052        dtypes=floating_types_and(torch.half, torch.bfloat16),
1053        sample_inputs_func=sample_inputs_masked_softmax,
1054        skips=(
1055            DecorateInfo(
1056                unittest.expectedFailure,
1057                "TestNormalizeOperators",
1058                "test_normalize_operator_exhaustive",
1059            ),
1060            DecorateInfo(
1061                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1062            ),
1063        ),
1064        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1065        supports_forward_ad=True,
1066        supports_fwgrad_bwgrad=True,
1067        supports_out=False,
1068    ),
1069    OpInfo(
1070        "masked.log_softmax",
1071        method_variant=None,
1072        dtypes=floating_types_and(torch.half, torch.bfloat16),
1073        sample_inputs_func=sample_inputs_masked_softmax,
1074        skips=(
1075            DecorateInfo(
1076                unittest.expectedFailure,
1077                "TestNormalizeOperators",
1078                "test_normalize_operator_exhaustive",
1079            ),
1080            DecorateInfo(
1081                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1082            ),
1083        ),
1084        decorators=[
1085            DecorateInfo(
1086                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
1087                "TestMasked",
1088                "test_reference_masked",
1089            ),
1090        ],
1091        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1092        supports_forward_ad=True,
1093        supports_fwgrad_bwgrad=True,
1094        supports_out=False,
1095    ),
1096    OpInfo(
1097        "masked.softmin",
1098        method_variant=None,
1099        dtypes=floating_types_and(torch.half, torch.bfloat16),
1100        sample_inputs_func=sample_inputs_masked_softmax,
1101        skips=(
1102            DecorateInfo(
1103                unittest.expectedFailure,
1104                "TestNormalizeOperators",
1105                "test_normalize_operator_exhaustive",
1106            ),
1107            DecorateInfo(
1108                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1109            ),
1110            # FIXME:
1111            # Mismatched elements: 2 / 2 (100.0%)
1112            # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
1113            # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed)
1114            DecorateInfo(
1115                unittest.skip("Skipped!"),
1116                "TestOperators",
1117                "test_vmapvjpvjp",
1118                device_type="cpu",
1119            ),
1120        ),
1121        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1122        supports_forward_ad=True,
1123        supports_fwgrad_bwgrad=True,
1124        supports_out=False,
1125    ),
1126    OpInfo(
1127        "masked.normalize",
1128        method_variant=None,
1129        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
1130        sample_inputs_func=sample_inputs_masked_normalize,
1131        decorators=[
1132            DecorateInfo(
1133                toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
1134                "TestInductorOpInfo",
1135                "test_comprehensive",
1136                device_type="cuda",
1137            ),
1138        ],
1139        skips=(
1140            DecorateInfo(
1141                unittest.expectedFailure,
1142                "TestNormalizeOperators",
1143                "test_normalize_operator_exhaustive",
1144            ),
1145            DecorateInfo(
1146                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
1147            ),
1148        ),
1149        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1150        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
1151        gradcheck_fast_mode=True,
1152        supports_forward_ad=True,
1153        supports_fwgrad_bwgrad=True,
1154        supports_out=False,
1155    ),
1156    OpInfo(
1157        "masked.logaddexp",
1158        dtypes=floating_types_and(torch.float16, torch.bfloat16),
1159        supports_out=False,
1160        supports_forward_ad=True,
1161        supports_fwgrad_bwgrad=True,
1162        check_batched_forward_grad=False,
1163        skips=(
1164            DecorateInfo(
1165                unittest.expectedFailure,
1166                "TestNormalizeOperators",
1167                "test_normalize_operator_exhaustive",
1168            ),
1169            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
1170            DecorateInfo(
1171                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
1172            ),
1173            DecorateInfo(
1174                unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad"
1175            ),
1176            DecorateInfo(
1177                unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad"
1178            ),
1179        ),
1180        sample_inputs_func=sample_inputs_masked_logaddexp,
1181        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
1182    ),
1183    ReductionOpInfo(
1184        "masked.logsumexp",
1185        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
1186        method_variant=None,
1187        nan_policy="propagate",
1188        supports_out=False,
1189        supports_forward_ad=True,
1190        supports_fwgrad_bwgrad=True,
1191        skips=(
1192            DecorateInfo(
1193                unittest.skip("Skipped!"),
1194                "TestNormalizeOperators",
1195                "test_normalize_operator_exhaustive",
1196            ),
1197            # FIXME: reduces all dimensions when dim=[]
1198            DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
1199            DecorateInfo(
1200                unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
1201            ),
1202            # Identity can't be -torch.inf without overflow
1203            DecorateInfo(
1204                unittest.skip("Skipped!"),
1205                "TestReductions",
1206                "test_empty_tensor_empty_slice",
1207            ),
1208            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
1209            DecorateInfo(
1210                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
1211            ),
1212            # all the values are the same except for -inf vs nan
1213            DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
1214            # FIXME:
1215            # Mismatched elements: 2 / 12 (16.7%)
1216            # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
1217            # Greatest relative difference: 0.0 at index (0, 0, 0, 1)
1218            DecorateInfo(
1219                unittest.skip("Skipped!"),
1220                "TestInductorOpInfo",
1221                "test_comprehensive",
1222                device_type="cpu",
1223            ),
1224        ),
1225        sample_inputs_func=sample_inputs_masked_reduction,
1226        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
1227    ),
1228]
1229