xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/opinfo/core.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4import collections.abc
5import math
6import operator
7import unittest
8from dataclasses import asdict, dataclass
9from enum import Enum
10from functools import partial
11from itertools import product
12from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
13
14import torch
15from torch.testing import make_tensor
16from torch.testing._internal.common_device_type import (
17    skipCPUIfNoFFT,
18    tol,
19    toleranceOverride,
20)
21from torch.testing._internal.common_dtype import (
22    _dispatch_dtypes,
23    floating_and_complex_types,
24    floating_and_complex_types_and,
25    floating_types,
26    get_all_dtypes,
27)
28from torch.testing._internal.common_utils import (
29    is_iterable_of_tensors,
30    noncontiguous_like,
31    OPINFO_SAMPLE_INPUT_INDEX,
32    TEST_WITH_ROCM,
33    torch_to_numpy_dtype_dict,
34    TrackedInputIter,
35)
36from torch.testing._internal.opinfo import utils
37from torchgen.utils import dataclass_repr
38
39
# Reasonable testing sizes for dimensions (largest to smallest)
L, M, S, XS = 20, 10, 5, 3

# Unique sentinel used as a default-argument value; it is compared by identity
# (``is``), so it can be told apart from anything a caller passes, None included
_NOTHING = object()
48
49
def _getattr_qual(obj, name, default=_NOTHING):
    """Resolve a dotted attribute path on ``obj``, like a qualified ``getattr``.

    Example: ``_getattr_qual(torch, "linalg.norm")`` returns ``torch.linalg.norm``.

    Args:
        obj: the object the lookup starts from.
        name: attribute path whose components are separated by ``"."``.
        default: value returned if any step of the lookup raises
            ``AttributeError``. When omitted, that ``AttributeError`` propagates.
    """
    current = obj
    try:
        for attr in name.split("."):
            current = getattr(current, attr)
    except AttributeError:
        if default is _NOTHING:
            raise
        return default
    return current
62
63
class DecorateInfo:
    """Describes which tests should be wrapped in the given decorators when
    testing an operator.

    A test is decorated only if it matches every provided filter
    (``cls_name``, ``test_name``, ``device_type``, ``dtypes``). A filter left as
    ``None`` matches anything. ``active_if`` gates the decorators further. It is
    either a plain value whose truthiness decides, or a callable that is invoked
    with the test's ``param_kwargs`` and whose result decides.
    """

    __slots__ = [
        "decorators",
        "cls_name",
        "test_name",
        "device_type",
        "dtypes",
        "active_if",
    ]

    def __init__(
        self,
        decorators,
        cls_name=None,
        test_name=None,
        *,
        device_type=None,
        dtypes=None,
        active_if=True,
    ):
        # Accept a single decorator or a sequence of them; always store a list.
        if isinstance(decorators, collections.abc.Sequence):
            self.decorators = list(decorators)
        else:
            self.decorators = [decorators]
        self.cls_name = cls_name
        self.test_name = test_name
        self.device_type = device_type
        self.dtypes = dtypes
        self.active_if = active_if

        # Every dtype filter entry must be a torch.dtype.
        if dtypes is not None:
            for candidate in dtypes:
                assert isinstance(candidate, torch.dtype)

    def is_active(self, cls_name, test_name, device_type, dtype, param_kwargs):
        """Return whether these decorators apply to the described test.

        ``param_kwargs`` is forwarded to ``active_if`` when it is callable.
        """

        def _matches(expected, actual):
            # A filter of None acts as a wildcard.
            return expected is None or expected == actual

        gate = self.active_if
        return (
            gate
            and _matches(self.cls_name, cls_name)
            and _matches(self.test_name, test_name)
            and _matches(self.device_type, device_type)
            and (self.dtypes is None or dtype in self.dtypes)
            # A callable gate is only evaluated once every filter has matched.
            and (gate(param_kwargs) if isinstance(gate, Callable) else gate)
        )
119
120
# FIXME
# Note: historically the 'input' kwarg had to be a Tensor or TensorList, but we are trying
#   to support scalar inputs, too. Some tests still depend on 'input' being a Tensor
#   or TensorList, however.
class SampleInput:
    """Represents sample inputs to a function.

    A SampleInput bundles the following:

    * ``input``: the first argument of the operator.
    * ``args``: the remaining positional arguments, always a tuple.
    * ``kwargs``: the keyword arguments, always a dict.
    * ``output_process_fn_grad``: maps the operator's output to the portion
      gradient tests should check. Defaults to the identity.
    * ``broadcasts_input``: whether this sample makes the operator broadcast
      ``input``. Defaults to False.
    * ``name``: a free-form debugging label. Defaults to "".

    There are two ways to construct one, and they cannot be mixed:

    * "naturally", as ``SampleInput(input, *args, **kwargs)``;
    * explicitly, as ``SampleInput(input, args=..., kwargs=..., <metadata>)``.
    """

    __slots__ = [
        "input",
        "args",
        "kwargs",
        "output_process_fn_grad",
        "broadcasts_input",
        "name",
    ]

    def __init__(
        self,
        input,
        *var_args,
        args=None,
        kwargs=None,
        output_process_fn_grad=None,
        broadcasts_input=None,
        name=None,
        **var_kwargs,
    ):
        # input is the first input to the op and is typically either a Tensor or TensorList (Sequence[Tensor]).
        # This follows the typical pattern where for Tensor inputs op(t, ...) = t.op(...).
        self.input = input

        # Allow calling either as SampleInput(input, args=args, kwargs=kwargs), or as
        # SampleInput(input, *args, **kwargs) but not to mix the two forms
        if args is not None or kwargs is not None:
            assert (
                not var_args and not var_kwargs
            ), """
A SampleInput can be constructed "naturally" with *args and **kwargs or by
explicitly setting the "args" and "kwargs" parameters, but the two
methods of construction cannot be mixed!"""
        elif len(var_args) or len(var_kwargs):
            assert (
                output_process_fn_grad is None
                and broadcasts_input is None
                and name is None
            ), """
A SampleInput constructed "naturally" with *args and **kwargs
cannot specify additional metadata in keyword arguments"""

        self.args = args if args is not None else var_args
        assert isinstance(self.args, tuple)
        self.kwargs = kwargs if kwargs is not None else var_kwargs
        assert isinstance(self.kwargs, dict)

        self.output_process_fn_grad = (
            output_process_fn_grad
            if output_process_fn_grad is not None
            else lambda x: x
        )
        self.name = name if name is not None else ""

        # Specifies if `self.input` is broadcasted or not,
        # given that the operator supports broadcasting.
        # This field is used to verify the behavior for inplace variant.
        #
        # If a SampleInput is marked with `broadcasts_input=True`,
        # it is verified that we get a `RuntimeError` with this sample,
        # and inplace variant. Also inplace grad{grad} tests are skipped,
        # for such inputs (as they will error out otherwise).
        self.broadcasts_input = (
            broadcasts_input if broadcasts_input is not None else False
        )

    def with_metadata(
        self, *, output_process_fn_grad=None, broadcasts_input=None, name=None
    ):
        """Update the given metadata fields in place and return ``self``.

        Arguments left as None do not change the corresponding field.
        Returning ``self`` lets callers chain this onto a constructor call.
        """
        if output_process_fn_grad is not None:
            self.output_process_fn_grad = output_process_fn_grad
        if broadcasts_input is not None:
            self.broadcasts_input = broadcasts_input
        if name is not None:
            self.name = name
        return self

    def _repr_helper(self, formatter):
        """Render this sample as ``SampleInput(...)``.

        ``formatter`` is applied to ``input``, ``args`` and ``kwargs``, which
        lets callers customize how those fields are displayed. See ``summary``
        for an example.
        """
        arguments = [
            f"input={formatter(self.input)}",
            f"args={formatter(self.args)}",
            f"kwargs={formatter(self.kwargs)}",
            f"broadcasts_input={self.broadcasts_input}",
            f"name={repr(self.name)}",
        ]

        return f'SampleInput({", ".join(arguments)})'

    def __repr__(self):
        return self._repr_helper(lambda x: x)

    def summary(self):
        """Return a condensed, human-friendly representation of this sample.

        Tensors are rendered as ``Tensor[size=..., device=..., dtype=...]``
        instead of their values. A ``, contiguous=False`` suffix is added for
        non-contiguous non-sparse tensors. This applies to tensors found
        directly, in TensorLists, in lists/tuples and in dicts.
        """
        # Sparse compressed layouts report is_sparse=False but, like sparse
        # COO, do not support is_contiguous(). List them all here, not only CSR.
        compressed_layouts = (
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        )

        def formatter(arg):
            # Format any instance of `Tensor` (standalone, in list, or in dict)
            # by Tensor[TensorShape]
            # Eg. Tensor with shape (3, 4) is formatted as Tensor[3, 4]
            if isinstance(arg, torch.Tensor):
                shape = str(tuple(arg.shape))
                dtype = str(arg.dtype)
                device = str(arg.device)
                contiguity_suffix = ""
                is_sparse = arg.is_sparse or arg.layout in compressed_layouts
                if not is_sparse and not arg.is_contiguous():
                    contiguity_suffix = ", contiguous=False"
                return f'Tensor[size={shape}, device="{device}", dtype={dtype}{contiguity_suffix}]'
            elif isinstance(arg, dict):
                return {k: formatter(v) for k, v in arg.items()}
            elif is_iterable_of_tensors(arg):
                return "TensorList[" + ", ".join(map(formatter, arg)) + "]"
            elif isinstance(arg, (list, tuple)):  # Handle list, tuple
                # formatter() returns a dict (not a str) for dict elements, so
                # stringify each element before joining.
                return "(" + ",".join(str(formatter(a)) for a in arg) + ")"

            return repr(arg)

        return self._repr_helper(formatter)

    def transform(self, f):
        """Return a new SampleInput with ``f`` applied to every tensor and dtype.

        ``f`` is called under ``torch.no_grad()`` on each Tensor and
        torch.dtype found in ``input``, ``args`` and ``kwargs``. The traversal
        recurses into lists, tuples and dicts, and other values pass through
        unchanged. The copy reuses this sample's metadata and appends
        "_transformed" to its name.
        """

        def tt(t):
            def _tt(t):
                with torch.no_grad():
                    return f(t)

            if isinstance(t, (torch.Tensor, torch.dtype)):
                return _tt(t)
            elif isinstance(t, list):
                return list(map(tt, t))
            elif isinstance(t, tuple):
                return tuple(map(tt, t))
            elif isinstance(t, dict):
                return {k: tt(v) for k, v in t.items()}
            else:
                return t

        sample_tt_input, tt_args, tt_kwargs = (
            tt(self.input),
            tt(self.args),
            tt(self.kwargs),
        )

        # Note the transformed SampleInput assumes metadata like output_process_fn_grad is still valid!
        return SampleInput(
            sample_tt_input,
            args=tt_args,
            kwargs=tt_kwargs,
            output_process_fn_grad=self.output_process_fn_grad,
            broadcasts_input=self.broadcasts_input,
            name=self.name + "_transformed",
        )

    def numpy(self):
        """Return a new SampleInput (via ``transform``) converted to NumPy.

        Conversion rules:

        * Tensors are detached, moved to CPU and converted with ``.numpy()``.
          bfloat16 is first upcast to float32 and chalf to cfloat, because
          NumPy lacks those dtypes.
        * torch dtypes are remapped through ``torch_to_numpy_dtype_dict``.
        """

        def to_numpy(t):
            if isinstance(t, torch.Tensor):
                if t.dtype is torch.bfloat16:
                    return t.detach().cpu().to(torch.float32).numpy()
                if t.dtype is torch.chalf:
                    return t.detach().cpu().to(torch.cfloat).numpy()
                return t.detach().cpu().numpy()
            elif isinstance(t, torch.dtype):
                return torch_to_numpy_dtype_dict[t]

            return t

        return self.transform(to_numpy)

    def noncontiguous(self):
        """Return a new SampleInput (via ``transform``) whose tensors are
        replaced by ``noncontiguous_like`` copies; dtypes are left unchanged."""

        def to_noncontiguous(t):
            return noncontiguous_like(t) if isinstance(t, torch.Tensor) else t

        return self.transform(to_noncontiguous)
318
319
# A (condition, safe_val) record. NOTE(review): presumably `condition` selects
# values that should be replaced by `safe_val` -- confirm against its consumers.
NumericsFilter = collections.namedtuple("NumericsFilter", "condition safe_val")
321
322
class ErrorInput:
    """
    A SampleInput that is expected to make the operation throw an error,
    together with information about that error.

    Attributes:
        sample_input: the SampleInput to call the operation with.
        error_type: the expected exception type (RuntimeError by default).
        error_regex: required keyword-only pattern describing the expected
            error message. NOTE(review): presumably matched as a regex by the
            error-input tests -- confirm in test_errors.
    """

    __slots__ = ["sample_input", "error_type", "error_regex"]

    def __init__(self, sample_input, *, error_type=RuntimeError, error_regex):
        self.sample_input, self.error_type, self.error_regex = (
            sample_input,
            error_type,
            error_regex,
        )
335
336
class AliasInfo:
    """Holds the callables reachable through one alias of an operator.

    For example, torch.abs has the alias "absolute", which resolves to
    torch.absolute, torch.Tensor.absolute and torch.Tensor.absolute_.
    Calling an AliasInfo forwards to its function variant (``op``).
    """

    def __init__(self, alias_name):
        self.name = alias_name
        # Function variant, looked up on the torch namespace. Dotted names
        # such as "linalg.norm" are allowed. Raises AttributeError if missing.
        self.op = _getattr_qual(torch, alias_name)
        # Method and inplace-method variants are optional: None when absent.
        tensor_cls = torch.Tensor
        self.method_variant = getattr(tensor_cls, alias_name, None)
        self.inplace_variant = getattr(tensor_cls, f"{alias_name}_", None)

    def __call__(self, *args, **kwargs):
        return self.op(*args, **kwargs)
350
351
352# Note [OpInfos]
353# ~~~~~~~~~~~~~~
354#
355# The majority of this note was written shortly after the PyTorch 1.9 release.
356# If you notice it's out-of-date or think it could be improved then please
357# file an issue.
358#
359# See also: the OpInfo tracker (https://github.com/pytorch/pytorch/issues/54261)
360# See also: "Writing Test Templates" in common_device_type.py to learn how to
361#   parametrize a test template using OpInfos.
362# See also: PyTorch's GitHub wiki on running and writing tests
363#   https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
364# See also: ModuleInfos, OpInfo's sister class, defined in common_modules.py
365#
366# An OpInfo is a collection of metadata related to a PyTorch operator. This
367#   metadata is used to generate tests that validate properties of the operator,
368#   like if it implements the correct gradient formula.
369#
370# WHY OPINFOS?
371# ~~~~~~~~~~~~
372#
373# OpInfos are principally intended to do three things:
374#
375#   1) to allow systematic testing over all PyTorch's operators
#   2) to simplify operator testing by autogenerating many tests
377#   3) to allow systems (like autograd, torchscript, fx, nnc...) to test
378#        against every PyTorch operator
379#
380# All these goals are still a work in progress. Not every operator has an
381#   OpInfo, and some operator tests that could be automatically generated
382#   still have to be written manually.
383#
384# It's helpful to understand that OpInfos are both about test simplification and
385#   modularity. PyTorch is a complicated framework with many interrelated systems,
386#   too many for any one person to keep track of. An OpInfo can be thought of as the
#   interface between an operator implementer and those other systems. Instead of
#   requiring the implementer of torch.foo to understand how to test its forward
#   mode AD or NNC support, that testing is typically handled automatically just
#   by defining an OpInfo.
391#
392# It's often surprising to OpInfo writers that just implementing an OpInfo
393#   typically can't verify an operator is actually implemented correctly:
394#
395# "If an OpInfo doesn't validate my op works as expected, what's the point
396#     of it?"
397#
# But the point of OpInfos is the above. They are intended to let you focus on testing
399#   the operator logic you're familiar with instead of having to write tests for
400#   how the operator interacts with each of PyTorch's many systems.
401#
402# And, OK, it turns out that SOMETIMES just writing an OpInfo DOES
403#   validate your op works as expected, but that's only in special
404#   cases. See below for details.
405#
406# WHAT'S AN OPINFO?
407# ~~~~~~~~~~~~~~~~~
408#
409# So what is an OpInfo? It's a Python class that describes an operator's properties,
410#   like which dtypes it supports on the CPU and whether it has any aliases.
411#   These properties can be divided into three categories:
412#
413#   1) Metadata describing the operator, like the operator's name and if it
414#     "supports" the out kwarg.
415#   2) Test directives, like "skips" that tell the test suite to skip some
416#     tests.
417#   3) A "sample inputs" function that generates valid inputs for the operator.
418#
419# OpInfo attributes are described in more detail below.
420#
421# THE SAMPLE INPUTS FUNCTION
422# ~~~~~~~~~~~~~~~~~~~~~~~~~~
423#
424# The "sample inputs" function merits special elaboration. This function is
425#   crucial to testing with OpInfos. A typical OpInfo test has to treat the operator
426#   as a black box. There's no structure for the test to understand or exploit.
427#   Without "sample inputs" it wouldn't even know how to call the OpInfo's
428#   operator. The sample input function saves the day by providing different
429#   "SampleInputs" that can be used to call the operator. A sample input
430#   function should have the following signature:
431#
432#   def sample_inputs_foo(op_info, device, dtype, requires_grad, **kwargs):
433#
434#   And should return an iterable of SampleInputs (see the class description
435#   above). Each SampleInput defines an "input", "args", "kwargs", an
436#   "output_process_fn_grad" function, the "broadcasts_input" bool and a
437#   "name".
438#
439#   All the "sample_inputs" functions are invoked within a `torch.no_grad()`
440#   environment for efficiency and correctness. As such remember to set the
441#   "requires_grad" flag on the inputs **after** performing any transformations
442#   on them.
443#
444# The "input" is the first argument to the operator, or the tensor that
445#   the method or inplace variants of the operator should be called on, and
446#   should be on the requested device, of the requested dtype, and its
447#   requires_grad attribute should be set to the requires_grad argument.
448#
449# "args" should contain positional arguments, and "kwargs" keyword arguments.
450#
451# "output_process_fn_grad" has an interesting name. It's a function that maps
452#   the operator's output (when given the input, args, and kwargs) to the
453#   portion of the output to gradcheck. For example, consider an operator
454#   like torch.linalg.slogdet
455#   (https://pytorch.org/docs/main/generated/torch.linalg.slogdet.html).
456#   This operator returns a tuple of two tensors, but the first tensor
457#   cannot be backwarded through. Its "output_process_fn_grad" filters
#   this output tuple to just the second tensor, which we can call backward
459#   on. Functions that produce a single tensor can ignore this argument.
460#
# "broadcasts_input" is a bool indicating whether the SampleInput causes the operator
462#   to broadcast the "input" argument. This is important for tests to understand
463#   because inplace variants of operations throw a runtime error if they
464#   would broadcast their input arguments, so tests that work with inplace
465#   variants filter SampleInputs that broadcast their input.
466#
467# "name" is a string that's just used for debugging. It appears when printing
468#   the SampleInput.
469#
470# Sample inputs are designed to be used with many tests, some
471#   that are very time consuming, so they should be a small
472#   set with small tensors. An elaborated set of sample inputs
473#   can be specified using the "reference_inputs_func" attribute.
474#   The "reference inputs" for an operation are an extended
#   set of sample inputs that can more exhaustively test an
476#   operator. They are used by only a few tests that are careful
477#   not to take too long to run. Adding reference inputs
478#   is highly encouraged!
479#
480# THE (OPTIONAL) ERROR INPUTS FUNCTION
481# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
482#
483# OpInfos may optionally specify "error inputs" through an error function. If
#   specified, test_errors in test_ops.py will call the op with these inputs
485#   and validate that the desired error is thrown.
486#
487# Error inputs automate a common testing pattern where multiple inputs are
#   passed to an operation and the errors they throw are reviewed. Tests
489#   written in this style should be ported to the new OpInfo pattern.
490#
# Error inputs are specified using the ErrorInput class, which contains
492#   a SampleInput (see above) and data about the expected error.
493#
494# OPINFO FILE ORGANIZATION
495# ~~~~~~~~~~~~~~~~~~~~~~~~
496#
497# All OpInfos are currently defined in this file. Most OpInfo tests are defined
498#   in test_ops.py, but some system-specific tests are defined in those
499#   systems' test files, and subclass-specific tests are defined in the test
#   file that corresponds to that subclass (see below).
501#   Expect a reorganization in the future.
502#
503# WHAT'S TESTED?
504# ~~~~~~~~~~~~~~
505#
506# Every OpInfo in the op_db sequence has the following properties validated in
507# test_ops.py:
508#
509#   - that its supported dtypes are specified correctly
510#   - that the operation produces the same results when called with noncontiguous inputs
511#   - that it supports the out= argument properly (if it allows out=),
512#       see https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
513#   - that it works with the conjugate view bit properly
514#   - that its function, method, and inplace variants perform the same operation
515#       (that is, that torch.add, torch.Tensor.add, and torch.Tensor.add_ all
516#       do the same thing).
517#   - that its inplace variant preserves the input's storage
518#   - that its gradient formula is implemented correctly, and that it supports
#       gradgrad, complex grad and gradgrad, and forward mode AD properly for
520#       the op's function and inplace variants (method variants are skipped
521#       to reduce test time).
522#   - that the operation performs the same operation when traced or scripted
523#       using the jit
524#   - that the operation is autodifferentiated by the jit as expected
525#   - that the operator's aliases, if any, perform the same operation and that
526#       the jit understands the alias
527#   - that the operator throws the correct errors (if error_inputs is defined)
528#   - that the operator produces the same results as a NumPy reference (if ref is defined)
529#   - that the operator produces the same results as a NumPy reference on an extended
530#       set of "reference inputs" (if both ref and reference_inputs_func are defined)
531#       (NOTE: elementwise unary and elementwise binary OpInfos do this even if only
532#         ref is defined, because they effectively autogenerate reference inputs)
533#   - that the operator works on different CUDA devices
534#
535# Additional OpInfo tests are in test_jit_fuser_te.py, test_fx_experimental.py,
536#   and test_fx.py. These tests validate that operators work with NNC and FX
537#   as expected.
538#
539# For performance, some of the above tests may only run on the first
540#   SampleInput returned by an OpInfo's sample input function.
541#
542# In addition to these tests, some subclasses (discussed in the next section)
543#   define additional tests.
544#
545# Critically, as mentioned above, what's not necessarily tested is that the operator
546#   works as expected. When implementing an OpInfo an engineer must still
547#   typically write one or more tests validating the operator's behavior.
548#   The exception to this is if reference testing is sufficient, or if
549#   the operation belongs to an OpInfo subclass that has more exhaustive
550#   operator testing. Elementwise unary and elementwise binary operators,
551#   in particular, usually don't require additional testing beyond
#   writing an OpInfo.
553#
554#
555# OPINFO (SUB)CLASSES
556# ~~~~~~~~~~~~~~~~~~~
557#
558# In addition to the OpInfo base class there are several specialized OpInfo
559#   subclasses. For example, the UnaryUfuncInfo subclass is used for
560#   unary elementwise operations. These operations have a common structure
561#   that test_unary_ufuncs.py exploits with additional automated testing.
562#   The automated testing in test_unary_ufuncs.py is so thorough, comparing
563#   the operator to a NumPy reference function on a plethora of values, that
564#   just implementing an OpInfo for a unary elementwise operation is often
565#   sufficient testing.
566#
567# The ForeachFuncInfo is another OpInfo subclass that is hyper-specialized to a
568#   very unique class of operations. These OpInfos aren't included in the
569#   op_db sequence and have their own tests.
570#
571# Other OpInfo subclasses, like SpectralFuncInfo, are just for convenience
572# when writing OpInfos.
573#
574# TESTING A NEW OPERATOR
575# ~~~~~~~~~~~~~~~~~~~~~~
576#
577# If you're adding a new operator to any of the following namespaces:
578#   - torch
579#   - torch.fft
580#   - torch.linalg,
581#   - torch.special
582#   - torch.nn.functional
583# then you should typically add an OpInfo for it.
584#
585# As mentioned a couple times above, implementing an OpInfo is not
586#   usually sufficient testing (unless the operator is a unary or binary elementwise
587#   operator). The OpInfo will only test the properties described in the
588#   "WHAT'S TESTED" section. It DOES NOT necessarily verify that the operator is
589#   implemented correctly.
590#
591# TIPS FOR WRITING AN OPINFO AND OPINFO TESTS
592# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
593#
594# Writing an OpInfo can be a little daunting. Since the point of an OpInfo is to
595#   be consumed by a variety of systems it can be hard to understand how to
596#   deal with test failures or how to set the OpInfo metadata properly.
597#
598# Before adding an OpInfo it helps to look at other OpInfos. A sample inputs
599#   function must be defined, and the operator's dtypes must be specified.
600#   Once that's done you should run the operator's tests in test_ops.py
601#   (these can be filtered using the "-k" argument in pytest). Tests that
602#   fail should provide an error message that describes what to change about
603#   your OpInfo. You don't need to worry about changing an OpInfo's default
604#   values unless a test yells at you.
605#
606# Similarly, if you're writing a test that consumes OpInfos then it's critical
607#   your test provides a clear error message describing what to do when it
608#   fails. You should not assume the OpInfo implementer is familiar with your
609#   system.
610#
611# If you see a confusing error message while developing an OpInfo then please
612#   file an issue describing what happened.
613#
614# This trial-and-error approach to writing an OpInfo can be frustrating,
615#   but it's probably necessary as long as OpInfos don't require
616#   learning about all the systems that consume them. One thing that can help
617#   is the get_supported_dtypes() function defined in utils.py. This
618#   function can be used to programmatically specify the dtypes an operator
619#   supports, and is especially useful if writing an OpInfo on a machine
620#   without a CUDA device. See its documentation for more details.
621#
622# THE FUTURE OF OPINFOS AND OPINFO TESTING
623# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
624#
625# In the future we expect OpInfo coverage to improve and cover
626#   the great majority of PyTorch's (public) operators.
627#
628
629
630# Classes and methods for the operator database
631@dataclass
632class OpInfo:
633    """Operator information and helper functions for acquiring it."""
634
635    # the string name of the function
636    name: str
637
638    # An optional reference function that accepts ndarrays (AKA "NumPy arrays").
639    # If given, the op will be compared with its reference on each of its sample inputs.
640    ref: Optional[Callable] = None
641
642    # the following metadata describes the operator, its variants, and its aliases, if any
643
644    # iterable of aliases, e.g. ("absolute",) for torch.abs
645    aliases: Iterable = None
646
647    # additional string to include in the test name
648    # this is useful when an op needs multiple OpInfos,
649    # like divide does, often because it's really several
650    # different ops behind the scenes
651    variant_test_name: str = ""
652
653    # the function variant of the operation, populated as torch.<name> if None
654    op: Callable = None
655
656    # allows the method variant of this operation to be specified as follows:
657    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
658    # - if None, then the OpInfo explicitly specifies is has no associated method
659    # - if a Callable, then that callable should be the method associated with this operation
660    method_variant: Callable = _NOTHING
661
662    # allows the inplace variant of this operation to be specified as follows:
663    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
664    # - if None, then the OpInfo explicitly specifies is has no associated inplace variant
665    # - if a Callable, then that callable should be the inplace variant associated with this operation
666    inplace_variant: Callable = _NOTHING
667
668    # allows the operator variant of this operation to be specified as follows:
669    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
670    # - if None, then the OpInfo explicitly specifies is has no associated operator
671    # - if a Callable, then that callable should be the operator associated with this operation
672    operator_variant: Callable = _NOTHING
673
674    # allows the inplace operator variant of this operation to be specified as follows:
675    # - if _NOTHING (default), then the OpInfo attempts to discover the variant using its name
676    # - if None, then the OpInfo explicitly specifies is has no associated inplace operator
677    # - if a Callable, then that callable should be the inplace operator associated with this operation
678    inplace_operator_variant: Callable = _NOTHING
679
    # the following metadata are test directives for skipping or modifying tests

    # information about which tests to skip
    # (__post_init__ appends these entries to `decorators`, so both are applied alike)
    skips: Tuple = ()

    # decorators to apply to generated tests; entries are either plain decorators
    # (always applied) or DecorateInfo objects (applied only when active)
    decorators: Tuple = ()

    # the following are pointers to functions to generate certain classes of inputs
    # (__post_init__ wraps the sample/reference input functions in torch.no_grad())

    # function to generate sample inputs with strided layouts
    sample_inputs_func: Callable = None

    # function to generate a more thorough set of sample inputs with strided layouts;
    # if None, reference_inputs() falls back to sample_inputs_func
    reference_inputs_func: Callable = None

    # function to generate inputs that will throw errors
    error_inputs_func: Callable = None

    # function to generate sparse (coo, csr, csc, bsr, bsc) inputs that will throw errors
    error_inputs_sparse_func: Callable = None

    # function to generate sample inputs with sparse coo layouts
    # (if None, __post_init__ substitutes a placeholder that raises
    # NotImplementedError; the same holds for the csr/csc/bsr/bsc functions below)
    sample_inputs_sparse_coo_func: Callable = None

    # function to generate sample inputs with sparse csr layouts
    sample_inputs_sparse_csr_func: Callable = None

    # function to generate sample inputs with sparse csc layouts
    sample_inputs_sparse_csc_func: Callable = None

    # function to generate sample inputs with sparse bsr layouts
    sample_inputs_sparse_bsr_func: Callable = None

    # function to generate sample inputs with sparse bsc layouts
    sample_inputs_sparse_bsc_func: Callable = None

    # the following metadata relates to dtype support and is tested for correctness in test_ops.py

    # dtypes this function works with on the CPU,
    # inherited by other device types that don't specify their own dtypes
    # (required: __post_init__ asserts that it is not None)
    dtypes: _dispatch_dtypes = None

    # the following dtypesIf... options override the dtypes value on their respective device types

    # dtypes this function is expected to work with on CUDA (defaults to `dtypes`)
    dtypesIfCUDA: _dispatch_dtypes = None

    # dtypes this function is expected to work with on ROCM (defaults to `dtypesIfCUDA`)
    dtypesIfROCM: _dispatch_dtypes = None

    # dtypes this function is expected to work with on HPU (defaults to `dtypes`)
    dtypesIfHpu: _dispatch_dtypes = None

    # dtypes this function is expected to work with on XPU (defaults to `dtypesIfCUDA`)
    dtypesIfXPU: _dispatch_dtypes = None

    # backward dtypes this function is expected to work with (defaults to `dtypes`)
    backward_dtypes: _dispatch_dtypes = None

    # backward dtypes this function is expected to work with on CUDA
    # (defaults to `backward_dtypes`, then `dtypesIfCUDA`, then `dtypes`)
    backward_dtypesIfCUDA: _dispatch_dtypes = None

    # backward dtypes this function is expected to work with on ROCM
    # (defaults to the first specified of `backward_dtypesIfCUDA`, `backward_dtypes`,
    # `dtypesIfROCM`, `dtypesIfCUDA`, falling back to `dtypes`)
    backward_dtypesIfROCM: _dispatch_dtypes = None

    # backward dtypes this function is expected to work with on HPU
    # (defaults to `backward_dtypes`, then `dtypes`)
    backward_dtypesIfHpu: _dispatch_dtypes = None

    # the following metadata describes the operators out= support

    # whether the op supports the out kwarg
    # defaults to True, if the op does not allow the out kwarg or
    # supports it incorrectly then test_out in test_ops.py should fail
    supports_out: bool = True

    # the following metadata relates to autograd support
    # whether the operation supports backward mode AD
    # if true, gradient correctness is tested in test_ops.py
    # using the op's sample inputs
    supports_autograd: bool = True

    # whether the op supports second order gradients
    # if true, gradgrad correctness is tested in test_ops.py
    # defaults to supports_autograd's value
    # TODO: rename this to supports_bwgrad_bwgrad to be consistent with below
    supports_gradgrad: bool = None

    # whether the ops supports second order gradients via
    # forward-over-reverse. If True, forward-over-reverse gradgrad correctness
    # is tested. If False, test that forward grad is not implemented.
    # Defaults to False. Requires supports_autograd (asserted in __post_init__).
    supports_fwgrad_bwgrad: bool = False

    # whether the operation supports inplace autograd
    # if true, tested in test_ops.py
    # defaults to `supports_autograd or supports_forward_ad`
    supports_inplace_autograd: bool = None

    # Whether the operation supports forward mode AD
    # If the value is True, we check that the gradients are correct
    # If the value is False, we test that forward grad is not implemented
    supports_forward_ad: bool = False

    # Whether the operation has a varargs variant
    # (e.g. functions like ones, zeros, methods like view, permute)
    supports_varargs: bool = False

    # Whether the forward operation avoids materializing COW tensor inputs
    supports_cow_input_no_materialize_forward: bool = True

    # Whether the backward operation avoids materializing COW tensor inputs
    supports_cow_input_no_materialize_backward: bool = True

    # Whether to skip the backward part of the COW tensor input test
    skip_cow_input_backward: bool = False

    # If `supports_cow_input_no_materialize_forward == True`, this list contains
    # the arg indices or kwarg names of inputs that are expected to materialize
    allow_cow_input_materialize_forward: List[Union[int, str]] = None

    # If `supports_cow_input_no_materialize_backward == True`, this list contains
    # the arg indices or kwarg names of inputs that are expected to materialize
    allow_cow_input_materialize_backward: List[Union[int, str]] = None

    # wrapper function for gradcheck (the default simply calls `op(*args, **kwargs)`)
    gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs)

    # whether to check batched grad when doing gradcheck
    # defaults to `supports_autograd or supports_forward_ad`
    check_batched_grad: bool = None

    # whether to check batched grad grad when doing gradgradcheck
    # defaults to supports_gradgrad's value
    check_batched_gradgrad: bool = None

    # whether to check batched forward grad when doing gradcheck
    # defaults to the value of `supports_forward_ad`
    check_batched_forward_grad: bool = None

    # whether to check batched forward grad for the inplace variant of the op
    # defaults to the value of `check_batched_forward_grad`
    check_inplace_batched_forward_grad: bool = None

    # tolerance for nondeterminism while performing gradcheck
    gradcheck_nondet_tol: float = 0.0

    # Whether to use the fast implementation for gradcheck/gradgradcheck.
    # When set to None, defers to the default value provided by the wrapper
    # function around gradcheck (testing._internal.common_utils.gradcheck)
    gradcheck_fast_mode: bool = None

    # the following metadata relates to JIT support and is tested for correctness in test_ops.py

    # name of the corresponding aten:: operator (defaults to `name`)
    aten_name: str = None

    # if this is a composite implicit autograd op, the decomposed op
    decomp_aten_name: Optional[str] = None

    # name of the corresponding aten:: operator for backwards
    aten_backward_name: Optional[str] = None

    # if a op's aten::node is expected to be symbolically autodiffed
    assert_autodiffed: bool = False

    # a list of strings with node names that are expected to be in a
    # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
    # default is populated to be ['aten::(name of Python operator)']
    autodiff_nonfusible_nodes: List[str] = None

    # a list of strings with node names that are expected to be in FusionGroups
    # inside of DifferentiableGraphs when this operation is autodiffed.
    # Ex: ['aten::add', 'aten::mm'], defaults to an empty list
    # Note: currently no ops use fusible nodes
    autodiff_fusible_nodes: List[str] = None

    # the following metadata relates to sparse support and is used in test_sparse.py

    # whether the op supports sparse coo inputs; if None, it is inferred from
    # whether sample_inputs_sparse_coo_func was provided
    # TODO: rename supports_sparse to supports_sparse_coo
    supports_sparse: bool = None

    # whether the op supports TorchScript scripting; presumably when False
    # only the tracing tests are run (TODO confirm)
    supports_scripting: bool = True

    # if the operator can be traced
    supports_tracing: bool = True

    # the following metadata relates to sparse compressed support and
    # is used in test_sparse_csr.py and test_sparse.py
    # (each csr/csc/bsr/bsc flag below, if None, is inferred from whether the
    # matching sample_inputs_sparse_<layout>_func was provided)

    # whether the op supports sparse csr inputs
    supports_sparse_csr: bool = None
    # whether the op supports sparse csc inputs
    supports_sparse_csc: bool = None
    # whether the op supports sparse bsr inputs
    supports_sparse_bsr: bool = None
    # whether the op supports sparse bsc inputs
    supports_sparse_bsc: bool = None
    # whether the op supports nested jagged inputs, defaults to False
    supports_njt: bool = None

    # whether the op promotes integer inputs to float
    promotes_int_to_float: bool = False

    # the following metadata relates to complex support and is checked in test_ops.py

    test_conjugated_samples: bool = True

    test_neg_view: bool = True

    # assert that jit shape analysis fully propagates shape
    assert_jit_shape_analysis: bool = False

    # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py

    supports_expanded_weight: bool = False

    is_factory_function: bool = False
898
899    def __post_init__(self):
900        self._original_opinfo_args = asdict(self).copy()
901
902        assert self.dtypes is not None, f"OpInfo for {self.name} has no dtypes!"
903
904        dtypes_args = (
905            self.dtypes,
906            self.dtypesIfCUDA,
907            self.dtypesIfROCM,
908            self.dtypesIfXPU,
909        )
910
911        # Validates the dtypes are generated from the dispatch-related functions
912        for dtype_list in dtypes_args:
913            assert isinstance(dtype_list, (_dispatch_dtypes, type(None)))
914
915        if self.aten_name is None:
916            self.aten_name = self.name
917
918        # Attribute to verify dynamic_dtypes are used.
919        self.dynamic_dtypes = any(
920            isinstance(dtypes, utils._dynamic_dispatch_dtypes) for dtypes in dtypes_args
921        )
922
923        if self.dynamic_dtypes:
924            # Make sure `dtyesIfCUDA` is dynamic, if dynamic dispatch is used for CPU
925            # This is because, below we set dtypesIfCUDA to dtypes if they are None.
926            assert isinstance(self.dtypesIfCUDA, utils._dynamic_dispatch_dtypes), (
927                f"To use dynamic dypes for operator {self.name}, "
928                "acquire the dtypes dynamically for argument `dtypesIfCUDA`."
929                "This is to ensure that CUDA dtypes are acquired correctly as they"
930                "differ from CPU dtypes occasionally"
931            )
932
933        self.dtypes = set(self.dtypes)
934
935        # NOTE: backward dtypes must be acquired before forward dtypes
936        #   since they fallback to explicit (not implicit!) specifications of
937        #   forward dtypes
938        self.backward_dtypesIfROCM = (
939            set(self.backward_dtypesIfROCM)
940            if self.backward_dtypesIfROCM is not None
941            else (
942                self.backward_dtypesIfCUDA
943                if self.backward_dtypesIfCUDA is not None
944                else self.backward_dtypes
945                if self.backward_dtypes is not None
946                else self.dtypesIfROCM
947                if self.dtypesIfROCM is not None
948                else self.dtypesIfCUDA
949                if self.dtypesIfCUDA is not None
950                else self.dtypes
951            )
952        )
953        self.backward_dtypesIfCUDA = (
954            set(self.backward_dtypesIfCUDA)
955            if self.backward_dtypesIfCUDA is not None
956            else (
957                self.backward_dtypes
958                if self.backward_dtypes is not None
959                else self.dtypesIfCUDA
960                if self.dtypesIfCUDA is not None
961                else self.dtypes
962            )
963        )
964        self.backward_dtypesIfHpu = (
965            set(self.backward_dtypesIfHpu)
966            if self.backward_dtypesIfHpu is not None
967            else (
968                self.backward_dtypes
969                if self.backward_dtypes is not None
970                else self.dtypes
971            )
972        )
973
974        self.backward_dtypes = (
975            set(self.backward_dtypes)
976            if self.backward_dtypes is not None
977            else self.dtypes
978        )
979
980        self.dtypesIfCUDA = (
981            set(self.dtypesIfCUDA) if self.dtypesIfCUDA is not None else self.dtypes
982        )
983        self.dtypesIfROCM = (
984            set(self.dtypesIfROCM)
985            if self.dtypesIfROCM is not None
986            else self.dtypesIfCUDA
987        )
988        self.dtypesIfXPU = (
989            set(self.dtypesIfXPU) if self.dtypesIfXPU is not None else self.dtypesIfCUDA
990        )
991
992        self.dtypesIfHpu = (
993            set(self.dtypesIfHpu) if self.dtypesIfHpu is not None else self.dtypes
994        )
995
996        # NOTE: if the op is unspecified it is assumed to be under the torch namespace
997        if not self.op:
998            self.op = _getattr_qual(torch, self.name)
999
1000        if self.method_variant is _NOTHING:
1001            self.method_variant = getattr(torch.Tensor, self.name, None)
1002
1003        # attributes like real, imag are not callable
1004        if not callable(self.method_variant):
1005            self.method_variant = None
1006
1007        if self.inplace_variant is _NOTHING:
1008            inplace_name = self.name + "_"
1009            self.inplace_variant = getattr(torch.Tensor, inplace_name, None)
1010
1011        if self.operator_variant is _NOTHING:
1012            self.operator_variant = getattr(operator, self.name, None)
1013
1014        if self.inplace_operator_variant is _NOTHING:
1015            # Note: operator.i<op> will use operator.<op> and assign the result to the lhs when no
1016            # __i<op>__ method is found. This results in the appearance of an inplace operator variant which
1017            # does not have the correct inplace behavior. To avoid this, we guard automatic detection of the inplace
1018            # operator with a check that an inplace variant exists.
1019            if self.inplace_variant is not None:
1020                inplace_operator_name = "i" + self.name
1021                self.inplace_operator_variant = getattr(
1022                    operator, inplace_operator_name, None
1023                )
1024            else:
1025                self.inplace_operator_variant = None
1026
1027        self.decorators = (*self.decorators, *self.skips)
1028
1029        # Specifying sample inputs function without specifying the
1030        # corresponding layout support implies the layout support:
1031        if self.supports_sparse is None:
1032            self.supports_sparse = self.sample_inputs_sparse_coo_func is not None
1033        if self.sample_inputs_sparse_coo_func is None:
1034            self.sample_inputs_sparse_coo_func = self._sample_inputs_unspecified
1035
1036        if self.supports_sparse_csr is None:
1037            self.supports_sparse_csr = self.sample_inputs_sparse_csr_func is not None
1038        if self.sample_inputs_sparse_csr_func is None:
1039            self.sample_inputs_sparse_csr_func = self._sample_inputs_unspecified
1040
1041        if self.supports_sparse_csc is None:
1042            self.supports_sparse_csc = self.sample_inputs_sparse_csc_func is not None
1043        if self.sample_inputs_sparse_csc_func is None:
1044            self.sample_inputs_sparse_csc_func = self._sample_inputs_unspecified
1045
1046        if self.supports_sparse_bsr is None:
1047            self.supports_sparse_bsr = self.sample_inputs_sparse_bsr_func is not None
1048        if self.sample_inputs_sparse_bsr_func is None:
1049            self.sample_inputs_sparse_bsr_func = self._sample_inputs_unspecified
1050
1051        if self.supports_sparse_bsc is None:
1052            self.supports_sparse_bsc = self.sample_inputs_sparse_bsc_func is not None
1053        if self.sample_inputs_sparse_bsc_func is None:
1054            self.sample_inputs_sparse_bsc_func = self._sample_inputs_unspecified
1055
1056        if self.supports_njt is None:
1057            self.supports_njt = False
1058
1059        # We run the sampling functions without tracking the gradiends of the creation of inputs
1060        self.sample_inputs_func = torch.no_grad()(self.sample_inputs_func)
1061        self.sample_inputs_sparse_coo_func = torch.no_grad()(
1062            self.sample_inputs_sparse_coo_func
1063        )
1064        self.sample_inputs_sparse_csr_func = torch.no_grad()(
1065            self.sample_inputs_sparse_csr_func
1066        )
1067        self.sample_inputs_sparse_csc_func = torch.no_grad()(
1068            self.sample_inputs_sparse_csc_func
1069        )
1070        self.sample_inputs_sparse_bsr_func = torch.no_grad()(
1071            self.sample_inputs_sparse_bsr_func
1072        )
1073        self.sample_inputs_sparse_bsc_func = torch.no_grad()(
1074            self.sample_inputs_sparse_bsc_func
1075        )
1076        if self.reference_inputs_func is not None:
1077            self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func)
1078
1079        if not self.autodiff_fusible_nodes:
1080            self.autodiff_fusible_nodes = []
1081
1082        if self.autodiff_nonfusible_nodes is None:
1083            self.autodiff_nonfusible_nodes = ["aten::" + self.name]
1084
1085        # Autograd support
1086
1087        # Autograd flags that depend on backward AD only
1088        # - If setting has been explicitly set, raise error if inconsistent
1089        if self.supports_gradgrad is None:
1090            self.supports_gradgrad = self.supports_autograd
1091        else:
1092            assert not (self.supports_gradgrad and not self.supports_autograd), (
1093                "supports_gradgrad refines the part of autograd is supported, so it should "
1094                "not be set if supports_autograd is False"
1095            )
1096        if self.check_batched_grad is None:
1097            self.check_batched_grad = self.supports_autograd or self.supports_forward_ad
1098        else:
1099            assert not (
1100                self.check_batched_grad
1101                and not (self.supports_autograd or self.supports_forward_ad)
1102            ), (
1103                "check_batched_grad refines the part of autograd that will be checked (by gradcheck), so "
1104                "it should not be set if supports_autograd is False"
1105            )
1106        if self.check_batched_gradgrad is None:
1107            self.check_batched_gradgrad = self.supports_gradgrad
1108        else:
1109            assert not (self.check_batched_gradgrad and not self.supports_gradgrad), (
1110                "check_batched_gradgrad refines the part of autograd that will be checked (by "
1111                "gradgradcheck), so it should not be set if either supports_gradgrad or supports_autograd "
1112                "is False."
1113            )
1114        if self.check_batched_forward_grad is None:
1115            self.check_batched_forward_grad = self.supports_forward_ad
1116        else:
1117            assert not (
1118                self.check_batched_forward_grad and not self.supports_forward_ad
1119            ), (
1120                "check_batched_forward_grad should only be used when supports_forward_ad "
1121                "is True. It is used to disable the test in the specific cases "
1122                "where the op supports forward ad but fails to compute "
1123                "batched forward grad."
1124            )
1125
1126        if self.check_inplace_batched_forward_grad is None:
1127            self.check_inplace_batched_forward_grad = self.check_batched_forward_grad
1128        else:
1129            assert not (
1130                self.check_inplace_batched_forward_grad
1131                and not self.check_batched_forward_grad
1132            ), (
1133                "check_batched_forward_grad should only be used when check_batched_forward_grad "
1134                "is True. It is used to disable the test in the specific cases "
1135                "where the op supports batched forward grad but fails to compute batched forward "
1136                "grad for the inplace variant of the op."
1137            )
1138
1139        assert not (self.supports_fwgrad_bwgrad and not self.supports_autograd), (
1140            "supports_fwgrad_bwgrad enables forward-over-backward gradgrad checks and should only be "
1141            "True if backward ad is also checked, i.e., supports_forward_ad should be True.",
1142            self.name,
1143        )
1144
1145        # Autograd flags that depend on both forward AD and backward AD
1146        if self.supports_inplace_autograd is None:
1147            self.supports_inplace_autograd = (
1148                self.supports_autograd or self.supports_forward_ad
1149            )
1150        else:
1151            assert not (
1152                self.supports_inplace_autograd
1153                and not self.supports_autograd
1154                and not self.supports_forward_ad
1155            ), (
1156                "supports_inplace_autograd refines the part of autograd that is supported, so "
1157                "it should not be set if both supports_autograd and supports_forward_ad are False"
1158            )
1159
1160        if self.aliases is not None:
1161            self.aliases = tuple(AliasInfo(a) for a in self.aliases)  # type: ignore[assignment]
1162        else:
1163            self.aliases = ()
1164
1165    def __call__(self, *args, **kwargs):
1166        """Calls the function variant of the operator."""
1167        return self.op(*args, **kwargs)
1168
1169    def __str__(self):
1170        return dataclass_repr(self)
1171
1172    def get_op(self):
1173        """Returns the function variant of the operator, torch.<op_name>."""
1174        return self.op
1175
1176    def get_method(self):
1177        """Returns the method variant of the operator, torch.Tensor.<op_name>.
1178        Returns None if the operator has no method variant.
1179        """
1180        return self.method_variant
1181
1182    def get_inplace(self):
1183        """Returns the inplace variant of the operator, torch.Tensor.<op_name>_.
1184        Returns None if the operator has no inplace variant.
1185        """
1186        return self.inplace_variant
1187
1188    def get_operator(self):
1189        """Returns operator variant of the operator, e.g. operator.neg
1190        Returns None if the operator has no operator variant.
1191        """
1192        return self.operator_variant
1193
1194    def get_inplace_operator(self):
1195        """Returns the inplace operator variant of the operator, e.g operator.iadd
1196        Returns None if the operator has no inplace operator variant"""
1197        return self.inplace_operator_variant
1198
1199    def conjugate_sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
1200        """Returns an iterable of SampleInputs but with the tensor input or first
1201        tensor in a sequence input conjugated.
1202        """
1203
1204        set_seed = kwargs.pop("set_seed", True)
1205        samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
1206        conj_samples = list(samples)
1207
1208        def conjugate(tensor):
1209            _requires_grad = tensor.requires_grad
1210            tensor = tensor.conj()
1211            return tensor.requires_grad_(_requires_grad)
1212
1213        for i, sample in enumerate(samples):
1214            sample = conj_samples[i]
1215            # Note: it is assumed that the input here is either a tensor or tensorlist
1216            if isinstance(sample.input, torch.Tensor):
1217                sample.input = conjugate(sample.input)
1218            else:
1219                sample.input[0] = conjugate(sample.input[0])
1220
1221        return TrackedInputIter(
1222            iter(conj_samples),
1223            "conjugate sample input",
1224            set_seed=set_seed,
1225            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
1226        )
1227
1228    def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
1229        """
1230        Returns an iterable of SampleInputs.
1231
1232        These samples should be sufficient to test the function works correctly
1233        with autograd, TorchScript, etc.
1234        """
1235        set_seed = kwargs.pop("set_seed", True)
1236        samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
1237
1238        if kwargs.get("include_conjugated_inputs", False):
1239            conj_samples = self.conjugate_sample_inputs(
1240                device, dtype, requires_grad, **kwargs
1241            )
1242            samples_list = list(samples)
1243            samples_list.extend(conj_samples)
1244            samples = tuple(samples_list)
1245
1246        return TrackedInputIter(
1247            iter(samples),
1248            "sample input",
1249            set_seed=set_seed,
1250            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
1251        )
1252
1253    def reference_inputs(self, device, dtype, requires_grad=False, **kwargs):
1254        """
1255        Returns an iterable of SampleInputs.
1256
1257        Distinct from sample_inputs() above because this returns an expanded set
1258        of inputs when reference_inputs_func is defined. If undefined this returns
1259        the sample inputs.
1260        """
1261        set_seed = kwargs.pop("set_seed", True)
1262        if self.reference_inputs_func is None:
1263            samples = self.sample_inputs_func(
1264                self, device, dtype, requires_grad, **kwargs
1265            )
1266            return TrackedInputIter(
1267                iter(samples),
1268                "reference input",
1269                set_seed=set_seed,
1270                restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
1271            )
1272
1273        if kwargs.get("include_conjugated_inputs", False):
1274            raise NotImplementedError
1275
1276        references = self.reference_inputs_func(
1277            self, device, dtype, requires_grad, **kwargs
1278        )
1279        return TrackedInputIter(
1280            iter(references),
1281            "reference input",
1282            set_seed=set_seed,
1283            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
1284        )
1285
1286    def error_inputs(self, device, **kwargs):
1287        """
1288        Returns an iterable of ErrorInputs.
1289        """
1290        set_seed = kwargs.pop("set_seed", True)
1291        errs = self.error_inputs_func(self, device, **kwargs)
1292        return TrackedInputIter(
1293            iter(errs),
1294            "error input",
1295            callback=lambda e: e.sample_input,
1296            set_seed=set_seed,
1297            restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
1298        )
1299
1300    def error_inputs_sparse(self, device, layout, **kwargs):
1301        """
1302        Returns an iterable of ErrorInputs that contain sparse sample
1303        inputs with a specified layout.
1304        """
1305        if not self.supports_sparse_layout(layout):
1306            raise unittest.SkipTest("unsupported sparse layout")
1307        return self.error_inputs_sparse_func(self, device, layout, **kwargs)
1308
1309    def supports_sparse_layout(self, layout):
1310        """Return True if OpInfo supports the specified sparse layout."""
1311        layout_name = str(layout).split(".")[-1]
1312        # map torch.sparse_coo to OpInfo.supports_sparse:
1313        layout_name = layout_name.replace("_coo", "")
1314        return getattr(self, f"supports_{layout_name}")
1315
1316    def sample_inputs_sparse(
1317        self, layout, device, dtype, requires_grad=False, **kwargs
1318    ):
1319        """Returns an iterable of SampleInputs that contain inputs with a
1320        specified sparse layout.
1321        """
1322        layout_name = str(layout).split(".")[-1]
1323        sample_inputs_mth = getattr(self, "sample_inputs_" + layout_name)
1324
1325        def non_empty_sampler(op, generator):
1326            found_sample = False
1327            for sample in generator:
1328                found_sample = True
1329                yield sample
1330            if not found_sample:
1331                raise unittest.SkipTest("NO SAMPLES!")
1332
1333        return non_empty_sampler(
1334            self,
1335            sample_inputs_mth(device, dtype, requires_grad=requires_grad, **kwargs),
1336        )
1337
1338    def _sample_inputs_unspecified(self, *args, **kwargs):
1339        """Raises an NotImplemented exception in a OpInfo instance creation
1340        that specifies supports_sparse(|_csr|_csc|_bsr|_bsc)=True
1341        without specifying the corresponding sample function as
1342        sample_inputs_sparse_(coo|csr|csc|bsr|bsc)_func.
1343
1344        To avoid this, either define the corresponding sample function,
1345        or re-map unsupported samples to error inputs in an appropiate
1346
1347          opinfo/definitions/sparse.py:_validate_sample_input_sparse_<op>
1348
1349        function.
1350        """
1351        raise NotImplementedError("no sample function specified")
1352
1353    def sample_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
1354        """Returns an iterable of SampleInputs that contain inputs with sparse
1355        coo layout.
1356        """
1357        return self.sample_inputs_sparse_coo_func(
1358            self, device, dtype, requires_grad, **kwargs
1359        )
1360
1361    def sample_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs):
1362        """Returns an iterable of SampleInputs that contain inputs with sparse
1363        csr layout.
1364        """
1365        return self.sample_inputs_sparse_csr_func(
1366            self, device, dtype, requires_grad, **kwargs
1367        )
1368
1369    def sample_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs):
1370        """Returns an iterable of SampleInputs that contain inputs with sparse
1371        csc layout.
1372        """
1373        return self.sample_inputs_sparse_csc_func(
1374            self, device, dtype, requires_grad, **kwargs
1375        )
1376
1377    def sample_inputs_sparse_bsr(self, device, dtype, requires_grad=False, **kwargs):
1378        """Returns an iterable of SampleInputs that contain inputs with sparse
1379        bsr layout.
1380        """
1381        return self.sample_inputs_sparse_bsr_func(
1382            self, device, dtype, requires_grad, **kwargs
1383        )
1384
1385    def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs):
1386        """Returns an iterable of SampleInputs that contain inputs with sparse
1387        bsc layout.
1388        """
1389        return self.sample_inputs_sparse_bsc_func(
1390            self, device, dtype, requires_grad, **kwargs
1391        )
1392
1393    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
1394        """Returns the decorators targeting the given test."""
1395        result = []
1396        for decorator in self.decorators:
1397            if isinstance(decorator, DecorateInfo):
1398                if decorator.is_active(
1399                    test_class, test_name, device, dtype, param_kwargs
1400                ):
1401                    result.extend(decorator.decorators)
1402            else:
1403                result.append(decorator)
1404        return result
1405
1406    def supported_dtypes(self, device_type):
1407        if device_type == "privateuse1":
1408            device_type = torch._C._get_privateuse1_backend_name()
1409        device_type = torch.device(device_type).type
1410        if device_type == "cuda":
1411            return self.dtypesIfROCM if TEST_WITH_ROCM else self.dtypesIfCUDA
1412        if device_type == "xpu":
1413            return self.dtypesIfXPU
1414        if device_type == "hpu":
1415            return self.dtypesIfHpu
1416        return self.dtypes
1417
1418    def supported_backward_dtypes(self, device_type):
1419        if not self.supports_autograd:
1420            return set()
1421
1422        if device_type == "privateuse1":
1423            device_type = torch._C._get_privateuse1_backend_name()
1424        device_type = torch.device(device_type).type
1425        backward_dtypes = None
1426        if device_type == "cuda":
1427            backward_dtypes = (
1428                self.backward_dtypesIfROCM
1429                if TEST_WITH_ROCM
1430                else self.backward_dtypesIfCUDA
1431            )
1432        elif device_type == "hpu":
1433            backward_dtype = self.backward_dtypesIfHpu
1434        else:
1435            backward_dtypes = self.backward_dtypes
1436
1437        allowed_backward_dtypes = floating_and_complex_types_and(
1438            torch.bfloat16, torch.float16, torch.complex32
1439        )
1440        return set(allowed_backward_dtypes).intersection(backward_dtypes)
1441
1442    def supports_dtype(self, dtype, device_type) -> bool:
1443        return dtype in self.supported_dtypes(device_type)
1444
1445    @property
1446    def full_name(self):
1447        """Returns a full name that helps to uniquely identify this OpInfo."""
1448        variant = "." + self.variant_test_name if self.variant_test_name else ""
1449        # example: "normal.in_place" where "normal" is the name and "in_place" is the variant
1450        return f"{self.name}{variant}"
1451
1452    @property
1453    def formatted_name(self):
1454        """Returns a formatted full name for this OpInfo that can be used in test names."""
1455        return self.full_name.replace(".", "_")
1456
1457
1458def _generate_reduction_inputs(device, dtype, requires_grad, **kwargs):
1459    """Generates input tensors for testing reduction operators"""
1460    yield make_tensor([], dtype=dtype, device=device, requires_grad=requires_grad)
1461    yield make_tensor([2], dtype=dtype, device=device, requires_grad=requires_grad)
1462    yield make_tensor([3, 5], dtype=dtype, device=device, requires_grad=requires_grad)
1463    yield make_tensor(
1464        [3, 2, 1, 2], dtype=dtype, device=device, requires_grad=requires_grad
1465    )
1466
1467
1468def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
1469    """Generates a subset of all valid dim and keepdim kwargs given ndim that
1470    is appropriate for testing reduction operators.
1471    """
1472
1473    # Test default dim and keepdim
1474    yield {}
1475
1476    # Test reducing inner and outer most dimensions
1477    yield {"dim": 0, "keepdim": True}
1478    yield {"dim": -1, "keepdim": False}
1479
1480    # Test reducing middle dimension
1481    if ndim > 2:
1482        yield {"dim": ndim // 2, "keepdim": True}
1483
1484    if supports_multiple_dims:
1485        # Test reducing all dimensions
1486        yield {"dim": tuple(range(ndim)), "keepdim": False}
1487
1488        # Test reducing both first and last dimensions
1489        if ndim > 1:
1490            yield {"dim": (0, -1), "keepdim": True}
1491
1492        # Test reducing every other dimension starting with the second
1493        if ndim > 3:
1494            yield {"dim": tuple(range(1, ndim, 2)), "keepdim": False}
1495
1496
def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for reduction operators.

    For every tensor from ``_generate_reduction_inputs`` and every dim/keepdim
    combination from ``_generate_reduction_kwargs``, yields one SampleInput per
    ``(args, kwargs)`` pair produced by
    ``generate_args_kwargs(t, **reduction_kwargs)``. The dim/keepdim entries are
    merged on top of the generated kwargs.

    Recognized keyword arguments:
      supports_multiple_dims: whether tuple ``dim`` values are generated
        (default True).
      generate_args_kwargs: generator function yielding ``(args, kwargs)``
        pairs (default yields a single ``((), {})``).
    """

    # TODO(@heitorschueroff) Once all reduction operators are using
    # ReductionOpInfo use op_info.supports_multiple_dims directly.
    supports_multiple_dims: bool = kwargs.get("supports_multiple_dims", True)

    # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
    # use op_info.generate_args_kwargs directly.
    generate_args_kwargs = kwargs.get(
        "generate_args_kwargs", lambda *args, **kwargs: (yield (), {})
    )

    for t in _generate_reduction_inputs(device, dtype, requires_grad):
        for reduction_kwargs in _generate_reduction_kwargs(
            t.ndim, supports_multiple_dims
        ):
            for args, extra_kwargs in generate_args_kwargs(t, **reduction_kwargs):
                # Merge into a fresh dict rather than updating ``extra_kwargs`` in
                # place: a generator that reuses a dict would otherwise carry
                # dim/keepdim over into later samples. The distinct name also
                # avoids shadowing this function's own ``**kwargs``.
                sample_kwargs = {**extra_kwargs, **reduction_kwargs}
                yield SampleInput(
                    t.detach().requires_grad_(requires_grad),
                    args=args,
                    kwargs=sample_kwargs,
                )
1519
1520
1521# NOTE [Reductions]:
1522#
1523# For testing purposes, we relax the definition of a reduction operator
1524# as defined in the docstring below. We do this to capture operators with
1525# a similar API so they can be tested automatically. However...
1526#
1527# Strictly speaking a reduction operator is an operator that can reduce an
1528# array to a single scalar value and that can be computed from the partial
1529# result of reducing subarrays. This usually means that the reduction operation
1530# should be commutative and associative. This definition is important when it
1531# comes to implementation as it determines how a reduction can be parallelized.
1532#
1533# For example, many summary statistics such as median, mode and quantile cannot
1534# be computed from partial results because these are sorting and counting based
1535# algorithms that need information that would be lost in the reduced value.
class ReductionOpInfo(OpInfo):
    """Reduction operator information.

    An operator is a reduction operator if it reduces one or more dimensions of
    the input tensor to a single value. Reduction operators must implement the
    following signature, where ``dim`` and ``keepdim`` are passed as keyword
    arguments:

    - `op(input, *args, dim=None, keepdim=False, **kwargs) -> Tensor`

    ReductionOpInfo tests that reduction operators implement a consistent API.
    Optional features such as reducing over multiple dimensions are captured in
    the optional keyword parameters of the ReductionOpInfo constructor.

    If a reduction operator does not yet implement the full required API of
    reduction operators, this should be documented by xfailing the failing
    tests rather than adding optional parameters to ReductionOpInfo.

    Unless they are given explicitly in ``kwargs``, the constructor sets
    ``inplace_variant=None`` and a default ``sample_inputs_func`` that runs
    ``sample_inputs_reduction`` with this op's ``supports_multiple_dims`` and
    ``generate_args_kwargs``.

    NOTE
    The API for reduction operators has not yet been finalized and some
    requirements may change.

    See tests in test/test_reductions.py
    """

    def __init__(
        self,
        name,
        *,
        # The identity value for the operator if it has one.
        identity: Optional[Any] = None,
        # The nan policy for the operator if it implements one.
        # - propagate: NaN values are propagated to the output
        # - omit: NaN values are discarded during the reduction
        nan_policy: Optional[str] = None,
        # Whether the operator supports reducing multiple dimensions.
        supports_multiple_dims: bool = True,
        # Whether the operator promotes integral to floating point dtypes.
        promotes_int_to_float: bool = False,
        # Whether the operator promotes all integral dtypes to int64.
        promotes_int_to_int64: bool = False,
        # If a specific dtype is given, then the operator always returns that
        # dtype irrespective of the input dtype. If None, the operator returns
        # the dtype according to the type promotion rules above.
        result_dtype: Optional[torch.dtype] = None,
        # Casts complex results to real (e.g. linalg.norm or torch.var)
        complex_to_real: bool = False,
        # ReductionOpInfo tests generate their own input, dim and keepdim
        # arguments and call this function to generate tuples of extra args and
        # kwargs to use when calling the op. This is required for operators that
        # have other required parameters besides the input tensor.
        # The default yields a single ((), {}) pair, i.e. no extra args/kwargs.
        generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (
            yield (),
            {},
        ),
        # Options from the OpInfo base class
        **kwargs,
    ):
        # Snapshot of the constructor's arguments via locals(). It must stay the
        # first statement so that later locals (such as the nested
        # sample_inputs_func below) are not captured.
        # NOTE(review): presumably consumed elsewhere (e.g. to copy/re-create this
        # OpInfo); the consumer is not visible here.
        self._original_reduction_args = locals().copy()
        # Only these nan policies are recognized.
        assert nan_policy in (None, "propagate", "omit")

        # These are mutually exclusive options
        assert not (result_dtype and promotes_int_to_float)
        assert not (result_dtype and promotes_int_to_int64)
        assert not (result_dtype and complex_to_real)
        assert not (promotes_int_to_float and promotes_int_to_int64)

        # Default sample_inputs_func for ReductionOpInfo which augments sample
        # inputs from sample_inputs_reduction with the args and kwargs from
        # generate_args_kwargs. This is only used if sample_inputs_func is None.
        # Any caller-supplied values for these two keys are overwritten.
        def sample_inputs_func(*args, **kwargs):
            kwargs["supports_multiple_dims"] = supports_multiple_dims
            kwargs["generate_args_kwargs"] = generate_args_kwargs
            yield from sample_inputs_reduction(*args, **kwargs)

        # Override OpInfo defaults and call base class __init__
        kwargs.setdefault("inplace_variant", None)
        kwargs.setdefault("sample_inputs_func", sample_inputs_func)
        # promotes_int_to_float is forwarded to the OpInfo base class; the other
        # reduction-specific options are stored on this instance below.
        super().__init__(name, promotes_int_to_float=promotes_int_to_float, **kwargs)

        self.identity = identity
        self.nan_policy = nan_policy
        self.supports_multiple_dims = supports_multiple_dims
        self.promotes_int_to_int64 = promotes_int_to_int64
        self.complex_to_real = complex_to_real
        self.result_dtype = result_dtype
        self.generate_args_kwargs = generate_args_kwargs
1622
1623
1624# The base reference input generation for elementwise binary operations
def _reference_inputs_elementwise_binary(
    op, device, dtype, requires_grad, exclude_zero, **kwargs
):
    """Yields the base reference samples for an elementwise binary operator.

    Chains, in order: the op's own sample inputs, generic contiguous tensors,
    small-value tensors (non-bool dtypes), large-value tensors (dtypes other
    than bool/uint8/int8), broadcasting tensors, tensor/scalar samples,
    scalar type-promotion samples and, for floating point and complex dtypes
    only, extremal-value tensors.
    """
    common = dict(device=device, dtype=dtype, requires_grad=requires_grad)

    yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
    yield from generate_elementwise_binary_tensors(
        op, exclude_zero=exclude_zero, **common
    )
    if dtype is not torch.bool:
        yield from generate_elementwise_binary_small_value_tensors(op, **common)
    if dtype not in (torch.bool, torch.uint8, torch.int8):
        yield from generate_elementwise_binary_large_value_tensors(op, **common)
    yield from generate_elementwise_binary_broadcasting_tensors(
        op, exclude_zero=exclude_zero, **common
    )
    yield from generate_elementwise_binary_with_scalar_samples(op, **common)
    yield from generate_elementwise_binary_with_scalar_and_type_promotion_samples(
        op, **common
    )
    if dtype.is_floating_point or dtype.is_complex:
        yield from generate_elementwise_binary_extremal_value_tensors(op, **common)
1663
1664
# Note that some of these reference inputs use scalars for the SampleInput.input value,
1666#   and many tests require SampleInput.input be a tensor or a list of tensors
def reference_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
    """Reference inputs for elementwise binary operators.

    Yields the base reference samples, then noncontiguous versions of those
    same samples, then dedicated noncontiguous and arbitrarily strided samples.
    Zero is excluded from generated tensors when the op's
    ``rhs_make_tensor_kwargs`` sets ``exclude_zero``.
    """
    # Always bind exclude_zero: previously it was only assigned when the op had
    # rhs_make_tensor_kwargs, leaving it unbound (UnboundLocalError) otherwise.
    exclude_zero = (getattr(op, "rhs_make_tensor_kwargs", None) or {}).get(
        "exclude_zero", False
    )

    gen = partial(
        _reference_inputs_elementwise_binary,
        op,
        device,
        dtype,
        requires_grad,
        exclude_zero,
        **kwargs,
    )

    # yields "normal" samples
    yield from gen()

    # yields noncontiguous samples
    for sample in gen():
        yield sample.noncontiguous()

    yield from generate_elementwise_binary_noncontiguous_tensors(
        op,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )

    yield from generate_elementwise_binary_arbitrarily_strided_tensors(
        op,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )
1703
1704
1705# A functional that extends an elementwise binary operator's bespoke error inputs
1706#   with generic error inputs for the class of elementwise binary operations
def make_error_inputs_elementwise_binary(error_inputs_func):
    """Wraps ``error_inputs_func`` so it also yields generic binary-op error inputs.

    The returned generator function first yields everything produced by
    ``error_inputs_func`` (when it is not None). It then yields an ErrorInput
    expecting any Exception for each Python-scalar calling convention the op
    does not support: tensor x scalar, scalar x tensor and, unless
    ``skip_two_python_scalars=True`` is passed, scalar x scalar.
    """

    def error_inputs_func_wrapper(op, device, **kwargs):
        if error_inputs_func is not None:
            yield from error_inputs_func(op, device, **kwargs)

        def expect_failure(sample):
            return ErrorInput(sample, error_type=Exception, error_regex="")

        if not op.supports_rhs_python_scalar:
            yield expect_failure(
                SampleInput(torch.tensor((1, 2, 3), device=device), args=(2,))
            )

        if not op.supports_one_python_scalar:
            yield expect_failure(
                SampleInput(2, args=(torch.tensor((1, 2, 3), device=device),))
            )

        skip_two_scalars = kwargs.get("skip_two_python_scalars", False)
        if not (skip_two_scalars or op.supports_two_python_scalars):
            yield expect_failure(SampleInput(2, args=(3,)))

    return error_inputs_func_wrapper
1728
1729
1730# The following functions and classes are for testing elementwise binary operators.
1731
1732
1733# Returns a generator of pairs of contiguous tensors on the requested device
1734#   and with the requested dtype.
1735#
1736# This function is intended to test the non-vectorized and vectorized code
1737#   paths of elementwise binary functions, as well as their handling of odd tensor
1738#   sizes (like zero-dim tensors and tensors with zero elements).
1739#
# The samples include tensors with no elements,
#   a zero dim (scalar) tensor, a small 1D tensor, a medium 1D tensor, and
#   a large 2D tensor.
def generate_elementwise_binary_tensors(
    op, *, device, dtype, requires_grad=False, exclude_zero=False
):
    """Yields same-shape (lhs, rhs) SampleInputs over a fixed list of shapes.

    Shapes, in order: two with no elements, a 0-d tensor, a small and a medium
    1-d tensor, and a large 2-d tensor. The op's lhs/rhs make_tensor kwargs are
    applied to the corresponding operand.
    """
    empty_shapes = ((0,), (1, 0, 3))
    scalar_shapes = ((),)
    one_d_shapes = ((20,), (812,))  # small, medium
    two_d_shapes = ((1029, 917),)  # large

    make = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )
    for shape in empty_shapes + scalar_shapes + one_d_shapes + two_d_shapes:
        yield SampleInput(
            make(shape, **op.lhs_make_tensor_kwargs),
            args=(make(shape, **op.rhs_make_tensor_kwargs),),
        )
1771
1772
def generate_elementwise_binary_arbitrarily_strided_tensors(
    op, *, device, dtype, requires_grad=False, exclude_zero=False
):
    """Yields SampleInputs whose lhs is an arbitrary ``as_strided`` view.

    Each lhs is a view into a 500-element 1-d tensor with the listed shape,
    strides and storage offset (some cases use overlapping or zero strides).
    The rhs is a contiguous tensor of the same shape. Like the other
    elementwise binary generators, the op's lhs/rhs make_tensor kwargs are
    applied so that operand-specific constraints (e.g. value ranges) hold.
    """
    # shape, strides, offset
    strided_cases = (
        ((5, 6, 2), (1, 1, 7), 2),
        ((5, 5, 4), (1, 1, 7), 2),
        ((5, 5, 2), (4, 5, 7), 3),
        ((5, 5, 2), (5, 5, 7), 3),
        ((5, 5, 2), (5, 5, 5), 3),
        ((9, 5, 2), (0, 1, 7), 3),
    )

    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )
    for shape, strides, offset in strided_cases:
        a = make_arg(500, **op.lhs_make_tensor_kwargs).as_strided(
            shape, strides, offset
        )
        b = make_arg(shape, **op.rhs_make_tensor_kwargs)
        yield SampleInput(a, args=(b,))
1799
1800
1801# Returns a generator of pairs of contiguous tensors on the requested device and with
1802#   the requested dtype.
1803#
# Unlike generate_elementwise_binary_tensors, the values in these tensors are specified manually.
def generate_elementwise_binary_small_value_tensors(
    op, *, device, dtype, requires_grad=False, exclude_zero=None
):
    """Yields one SampleInput pairing every small "interesting" value with every other.

    The value set is chosen by ``dtype``: floats (including values around
    +-pi/2 and +-pi), complex numbers built from every pair of those floats,
    signed integers (int8/16/32/64) or uint8 values. Any other dtype raises
    ValueError. When ``exclude_zero`` is None it is read from the op's
    ``rhs_make_tensor_kwargs`` if that attribute exists; a truthy value
    replaces every zero on the rhs with one.
    """
    if exclude_zero is None and hasattr(op, "rhs_make_tensor_kwargs"):
        exclude_zero = op.rhs_make_tensor_kwargs.get("exclude_zero", False)

    # defines interesting values
    _unsigned_int_vals = (0, 1, 55, 127, 128, 190, 210, 220, 254)
    _int_vals = (0, -1, 1, -55, 55, -127, 127, -128)
    _float_vals = (
        0.0,
        -0.0,
        -0.001,
        0.001,
        -0.25,
        0.25,
        -1.0,
        1.0,
        -math.pi / 2,
        math.pi / 2,
        -math.pi + 0.00001,
        math.pi - 0.00001,
        -math.pi,
        math.pi,
        -math.pi - 0.00001,
        math.pi + 0.00001,
    )

    if dtype.is_floating_point:
        vals = _float_vals
    elif dtype.is_complex:
        # Materialized as a list so product() below can pair it with itself.
        vals = [complex(re, im) for re, im in product(_float_vals, _float_vals)]
    elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
        vals = _int_vals
    elif dtype is torch.uint8:
        vals = _unsigned_int_vals
    else:
        raise ValueError("Unsupported dtype!")

    pairs = list(product(vals, vals))
    l_vals = [lhs_val for lhs_val, _ in pairs]
    r_vals = [1 if rhs_val == 0 and exclude_zero else rhs_val for _, rhs_val in pairs]

    lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
    rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(lhs, args=(rhs,))
1863
1864
def generate_elementwise_binary_large_value_tensors(
    op, *, device, dtype, requires_grad=False
):
    """Yields one SampleInput pairing every large value with every other.

    float16 uses a reduced value set. Other floating point dtypes extend it
    with values up to +-1e20. Complex dtypes use every complex number whose
    real and imaginary parts come from the extended set. int16/int32/int64 use
    integers up to +-10701. Any other dtype raises ValueError.
    """
    _large_int_vals = (-1113, 1113, -10701, 10701)
    _large_float16_vals = (-501, 501, -1001.2, 1001.2, -13437.7, 13437.7)
    _large_float_vals = _large_float16_vals + (-4988429.2, 4988429.2, -1e20, 1e20)

    if dtype == torch.float16:
        vals = _large_float16_vals
    elif dtype.is_floating_point:
        vals = _large_float_vals
    elif dtype.is_complex:
        # Materialized as a list so product() below can pair it with itself.
        vals = [
            complex(re, im) for re, im in product(_large_float_vals, _large_float_vals)
        ]
    elif dtype in (torch.int16, torch.int32, torch.int64):
        vals = _large_int_vals
    else:
        raise ValueError("Unsupported dtype!")

    l_vals, r_vals = (list(side) for side in zip(*product(vals, vals)))

    lhs = torch.tensor(l_vals, device=device, dtype=dtype, requires_grad=requires_grad)
    rhs = torch.tensor(r_vals, device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(lhs, args=(rhs,))
1898
1899
def generate_elementwise_binary_extremal_value_tensors(
    op, *, device, dtype, requires_grad=False
):
    """Yields SampleInputs built from extremal floating point values.

    First yields every pairing of inf, -inf and nan (for complex dtypes, every
    pairing of complex numbers whose real and imaginary parts come from those
    values). Then, to test NaN propagation, yields two random (128, 128)
    tensors in which every third element (in flattened order) is NaN.
    Raises ValueError for dtypes that are neither floating point nor complex.
    """
    _float_extremals = (float("inf"), float("-inf"), float("nan"))

    if dtype.is_floating_point:
        vals = _float_extremals
    elif dtype.is_complex:
        # Materialized as a list so product() below can pair it with itself.
        vals = [complex(*x) for x in product(_float_extremals, _float_extremals)]
    else:
        raise ValueError("Unsupported dtype!")

    pairs = list(product(vals, vals))
    lhs = torch.tensor(
        [p[0] for p in pairs], device=device, dtype=dtype, requires_grad=requires_grad
    )
    rhs = torch.tensor(
        [p[1] for p in pairs], device=device, dtype=dtype, requires_grad=requires_grad
    )

    yield SampleInput(lhs, args=(rhs,))

    # Test case for NaN propagation
    nan = (
        float("nan") if dtype.is_floating_point else complex(float("nan"), float("nan"))
    )

    def make_with_nans():
        # Write the NaNs before enabling autograd: an in-place write into a view
        # of a leaf tensor that already requires grad raises a RuntimeError.
        t = make_tensor((128, 128), device=device, dtype=dtype)
        t.view(-1)[::3] = nan
        return t.requires_grad_(requires_grad)

    lhs = make_with_nans()
    rhs = make_with_nans()

    yield SampleInput(lhs, args=(rhs,))
1942
1943
1944# Returns a generator of pairs of contiguous and noncontiguous tensors that
1945#   require broadcasting
def generate_elementwise_binary_broadcasting_tensors(
    op, *, device, dtype, requires_grad=False, exclude_zero=False
):
    """Yields SampleInputs whose operands have different, broadcastable shapes.

    Every shape pair is generated once with noncontiguous operands and once
    with contiguous ones. ``broadcasts_input`` is set only when the lhs itself
    must be expanded to the broadcast result shape. Several pairs (e.g.
    ``(2, 3, 2)`` with ``()``) broadcast only the rhs.
    """
    shapes = (
        ((1,), ()),
        ((2,), ()),
        ((1,), (2,)),
        ((2, 1), (2,)),
        ((1, 2), (2,)),
        ((3, 2), (2,)),
        ((1, 3, 2), (2,)),
        ((1, 3, 2), (3, 2)),
        ((3, 1, 2), (3, 2)),
        ((2, 3, 2), ()),
        ((3, 1, 2), (1, 3, 2)),
    )

    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )
    for shape, noncontiguous in product(shapes, [True, False]):
        shape_lhs, shape_rhs = shape
        lhs = make_arg(
            shape_lhs, noncontiguous=noncontiguous, **op.lhs_make_tensor_kwargs
        )
        rhs = make_arg(
            shape_rhs, noncontiguous=noncontiguous, **op.rhs_make_tensor_kwargs
        )
        # Same computation as sample_inputs_elementwise_binary: the lhs is
        # broadcast only if its shape differs from the result shape.
        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
1980
1981
1982# Returns a generator of pairs of contiguous tensors and scalars
def generate_elementwise_binary_with_scalar_samples(
    op, *, device, dtype, requires_grad=False
):
    """Yields SampleInputs that mix tensors with Python scalars.

    For each shape, if the op supports rhs Python scalars, yields a
    tensor x scalar sample and, if it also supports one Python scalar, a
    scalar x tensor sample. Finally yields a scalar x scalar sample when the
    op supports two Python scalars. Scalars are taken with ``.item()`` from
    0-d tensors made with the op's lhs/rhs make_tensor kwargs.
    """
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    shapes = ((), (3,), (5, 3), (0, 1, 3), (1, 5))
    if op.supports_rhs_python_scalar:
        for shape in shapes:
            lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
            rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
            lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
            rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

            yield SampleInput(lhs, args=(rhs_scalar,))

            # Extends with scalar lhs. This is done per shape, matching
            # generate_elementwise_binary_with_scalar_and_type_promotion_samples,
            # so that each shape's rhs tensor and lhs scalar are actually used.
            if op.supports_one_python_scalar:
                yield SampleInput(lhs_scalar, args=(rhs,))

    if op.supports_two_python_scalars:
        lhs_scalar = make_arg((), **op.lhs_make_tensor_kwargs).item()
        rhs_scalar = make_arg((), **op.rhs_make_tensor_kwargs).item()

        yield SampleInput(lhs_scalar, args=(rhs_scalar,))
2009
2010
2011# Returns a generator of pairs of contiguous tensors and 0d tensors and scalars and type promotion
def generate_elementwise_binary_with_scalar_and_type_promotion_samples(
    op, *, device, dtype, requires_grad=False
):
    """Yields tensor/extremal-scalar samples for comparison and logical ops.

    Only ops named eq/ne/gt/ge/lt/le/logical_and/logical_or/logical_xor that
    support rhs Python scalars get samples. Each of nan, inf and -inf is used
    both as a Python float and as a 0-d tensor (created with default device
    and dtype). Each one is paired as the rhs of a length-23 tensor and, when
    the op supports one Python scalar, as the lhs of another length-23 tensor.
    """
    # add these samples only for logical and comparison ops, arithmetic ops are not happy about extremal scalars
    comparison_and_logical_ops = (
        "eq",
        "ne",
        "gt",
        "ge",
        "lt",
        "le",
        "logical_and",
        "logical_or",
        "logical_xor",
    )
    if op.name not in comparison_and_logical_ops:
        return

    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    # this shape is big enough to trigger vectorization, and has non-vectorized tail
    shape = (23,)
    python_scalars = (float("nan"), float("inf"), -float("inf"))
    tensor_scalars = tuple(torch.tensor(v) for v in python_scalars)

    if not op.supports_rhs_python_scalar:
        return

    lhs = make_arg(shape, **op.lhs_make_tensor_kwargs)
    rhs = make_arg(shape, **op.rhs_make_tensor_kwargs)
    for scalar in (*python_scalars, *tensor_scalars):
        yield SampleInput(lhs, args=(scalar,))
        # Extends with scalar lhs
        if op.supports_one_python_scalar:
            yield SampleInput(scalar, args=(rhs,))
2043
2044
2045# Returns a generator of pairs of noncontiguous tensors
def generate_elementwise_binary_noncontiguous_tensors(
    op, *, device, dtype, requires_grad=False, exclude_zero=False
):
    """Yields SampleInputs that exercise noncontiguous operands.

    Covers, in order:
      - tensors made with ``noncontiguous=True``;
      - transposed 2-d tensors;
      - views selecting index 0 of a trailing size-2 dimension;
      - tensors indexed along their second dimension;
      - expanded tensors.
    Several of these groups are yielded twice: once with cloned operands, and
    once with a contiguous lhs and the original noncontiguous rhs.
    """
    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )

    def make_pair(shape, **extra):
        # lhs is created before rhs (tuple elements evaluate left to right).
        return (
            make_arg(shape, **extra, **op.lhs_make_tensor_kwargs),
            make_arg(shape, **extra, **op.rhs_make_tensor_kwargs),
        )

    def clone_and_contiguous(lhs, rhs):
        # Both operands cloned, then a contiguous lhs with the original rhs.
        yield SampleInput(lhs.clone(), args=(rhs.clone(),))
        yield SampleInput(lhs.contiguous(), args=(rhs,))

    def strided_copy(src, shape):
        # Copy ``src`` into index 0 of a trailing size-2 dim of a fresh buffer.
        return torch.empty(shape + (2,), device=device, dtype=dtype)[..., 0].copy_(src)

    # Generic noncontiguity
    yield from clone_and_contiguous(*make_pair((1026,), noncontiguous=True))

    # Transposed
    lhs, rhs = make_pair((789, 357))
    yield SampleInput(lhs.T, args=(rhs.T,))

    # More noncontiguity
    for shape in ((5, 7), (1024,)):
        lhs, rhs = make_pair(shape)
        lhs_strided = strided_copy(lhs, shape)
        rhs_strided = strided_copy(rhs, shape)
        yield from clone_and_contiguous(lhs_strided, rhs_strided)

    # Noncontiguous indices
    lhs, rhs = make_pair((2, 2, 1, 2))
    yield from clone_and_contiguous(lhs[:, 1, ...], rhs[:, 1, ...])

    # Expanded tensors
    for shape in ((1, 3), (1, 7), (5, 7)):
        lhs, rhs = make_pair(shape)
        yield SampleInput(lhs.expand(3, -1, -1), args=(rhs.expand(3, -1, -1),))
2108
2109
2110# Sample inputs for elementwise binary operators, like add
def sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs):
    """Sample inputs for elementwise binary operators, like add.

    Yields one SampleInput per (lhs shape, rhs shape) pair below.
    ``broadcasts_input`` is set when the lhs must be expanded to the
    broadcast result shape.

    Recognized keyword arguments:
      small_inputs_only: use smaller dimension sizes (default False).
      sample_kwargs: kwargs attached to every sample (default {}).
    """
    small_inputs_only = kwargs.get("small_inputs_only", False)
    _M = S if small_inputs_only else M
    _S = XS if small_inputs_only else S

    # Always bind exclude_zero: previously it was only assigned when the op had
    # rhs_make_tensor_kwargs, leaving it unbound (UnboundLocalError) otherwise.
    exclude_zero = (getattr(op, "rhs_make_tensor_kwargs", None) or {}).get(
        "exclude_zero", False
    )

    make_arg = partial(
        make_tensor,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=exclude_zero,
    )

    shapes = (
        ((), ()),
        ((_S,), ()),
        ((_S, 1), (_S,)),
        ((_M, _S), ()),
        ((_S, _M, _S), (_M, _S)),
        ((_S, _M, _S), (_S, _M, _S)),
        ((_M, 1, _S), (_M, _S)),
        ((_M, 1, _S), (1, _M, _S)),
        ((0, 1, XS), (0, _M, XS)),
    )

    sample_kwargs = kwargs.get("sample_kwargs", {})

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs, **op.lhs_make_tensor_kwargs)
        rhs = make_arg(shape_rhs, **op.rhs_make_tensor_kwargs)
        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

        yield SampleInput(
            lhs, args=(rhs,), kwargs=sample_kwargs, broadcasts_input=broadcasts_input
        )
2148
2149
2150# Metadata class for binary "universal functions (ufuncs)" that accept two
# tensors and have common properties
class BinaryUfuncInfo(OpInfo):
    """Operator information for 'universal binary functions (binary ufuncs).'
    These are functions of two tensors with common properties like:
      - they are elementwise functions
      - the output shape is determined by the input shape
      - they typically have method and inplace variants
      - they typically support the out kwarg
      - they typically have NumPy or SciPy references
    See NumPy's universal function documentation
    (https://numpy.org/doc/stable/reference/ufuncs.html) for more details
    about the concept of ufuncs.

    In addition to the OpInfo base class behavior, the constructor:
      - skips TestCommon.test_numpy_refs (redundant with test_binary_ufuncs),
      - always wraps ``error_inputs_func`` with
        ``make_error_inputs_elementwise_binary`` (even when it is None),
      - stores ``lhs_make_tensor_kwargs``/``rhs_make_tensor_kwargs``
        (default ``{}``) so valid samples can be generated later,
      - turns ``supports_one_python_scalar`` on when
        ``supports_two_python_scalars`` is set, and asserts that
        ``supports_rhs_python_scalar`` is set whenever
        ``supports_one_python_scalar`` is.
    """

    def __init__(
        self,
        name,
        *,
        sample_inputs_func=sample_inputs_elementwise_binary,
        reference_inputs_func=reference_inputs_elementwise_binary,
        error_inputs_func=None,
        # Extra make_tensor kwargs for the lhs / rhs operand (None means {}).
        lhs_make_tensor_kwargs=None,
        rhs_make_tensor_kwargs=None,
        always_returns_bool=False,  # Set to true if the op always returns bool tensors
        supports_rhs_python_scalar=True,  # Whether the operator allows Tensor x scalar inputs
        supports_one_python_scalar=False,  # Whether the operator allows scalar x tensor and tensor x scalar inputs
        supports_two_python_scalars=False,  # Whether the operator allows scalar x scalar inputs
        **kwargs,
    ):
        # Snapshot of the constructor's arguments via locals(). It must stay the
        # first statement so that later locals (such as common_skips) are not
        # captured.
        # NOTE(review): presumably consumed elsewhere (e.g. to copy/re-create this
        # OpInfo); the consumer is not visible here.
        self._original_binary_ufunc_args = locals().copy()

        # Elementwise binary operations perform the equivalent of test_numpy_refs
        #   in test_binary_ufuncs, but with additional test granularity. So the
        #   generic test_ops.py test is skipped because it's redundant.
        common_skips = (
            DecorateInfo(
                unittest.skip("Skipping redundant test."),
                "TestCommon",
                "test_numpy_refs",
            ),
        )
        # Concatenated with "+", so a caller-provided ``skips`` is assumed to be
        # a tuple.
        kwargs["skips"] = kwargs.get("skips", ()) + common_skips
        super().__init__(
            name,
            sample_inputs_func=sample_inputs_func,
            reference_inputs_func=reference_inputs_func,
            error_inputs_func=make_error_inputs_elementwise_binary(error_inputs_func),
            **kwargs,
        )

        # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on.
        if lhs_make_tensor_kwargs is None:
            lhs_make_tensor_kwargs = {}
        self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs

        if rhs_make_tensor_kwargs is None:
            rhs_make_tensor_kwargs = {}
        self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs

        self.always_returns_bool = always_returns_bool
        self.supports_rhs_python_scalar = supports_rhs_python_scalar
        self.supports_one_python_scalar = supports_one_python_scalar
        self.supports_two_python_scalars = supports_two_python_scalars

        # Supporting scalar x scalar implies supporting one Python scalar...
        if self.supports_two_python_scalars:
            self.supports_one_python_scalar = True

        # ...which in turn requires tensor x scalar support.
        if self.supports_one_python_scalar:
            assert (
                supports_rhs_python_scalar
            ), "Can't support lhs and rhs Python scalars but not rhs scalars!"
2222
2223
2224# The following functions and classes are for testing elementwise unary operators.
def sample_inputs_elementwise_unary(
    op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
    """Sample inputs for elementwise unary operators.

    Values are drawn from ``op_info.domain``. For floating point and complex
    dtypes, each finite (non-None) domain bound is moved inwards by
    ``op_info._domain_eps``. Ops supporting any sparse compressed layout
    (CSR/CSC/BSR/BSC) get a single square 2-d tensor. All other ops get a
    1-d, an empty and a 0-d tensor. ``op_kwargs`` (default: empty dict) is
    attached to every sample.
    """
    op_kwargs = op_kwargs or {}

    size = S if kwargs.get("small_inputs_only", False) else L

    low, high = op_info.domain
    if dtype.is_floating_point or dtype.is_complex:
        if low is not None:
            low = low + op_info._domain_eps
        if high is not None:
            high = high - op_info._domain_eps

    supports_sparse_compressed = (
        op_info.supports_sparse_csr
        or op_info.supports_sparse_csc
        or op_info.supports_sparse_bsr
        or op_info.supports_sparse_bsc
    )
    if supports_sparse_compressed:
        # Tensors with dim=2 for sparse compressed testing
        shapes = [(size, size)]
    else:
        # Creates a 1D, empty, and scalar tensor
        shapes = [(size,), (1, 0, 3), ()]

    for shape in shapes:
        tensor = make_tensor(
            shape,
            device=device,
            dtype=dtype,
            low=low,
            high=high,
            requires_grad=requires_grad,
        )
        yield SampleInput(tensor, kwargs=op_kwargs)
2269
2270
2271# Replace values satisfying condition with a safe value. This is used to block
2272# out values the could cause singularity like tan(pi/2)
2273def _replace_values_in_tensor(tensor, condition, safe_value):
2274    mask = condition(tensor)
2275    tensor.masked_fill_(mask, safe_value)
2276
2277
2278# Helper to create a unary elementwise tensor with valid inputs
2279def _make_unary_elementwise_tensor(shape, *, op, dtype, **kwargs):
2280    low, high = op.domain
2281    is_floating = dtype.is_floating_point or dtype.is_complex
2282    low = low if low is None or not is_floating else low + op._domain_eps
2283    high = high if high is None or not is_floating else high - op._domain_eps
2284
2285    a = make_tensor(shape, low=low, high=high, dtype=dtype, **kwargs)
2286
2287    if op.reference_numerics_filter is not None and dtype is not torch.bool:
2288        condition, safe_value = op.reference_numerics_filter
2289        _replace_values_in_tensor(a, condition, safe_value)
2290
2291    return a
2292
2293
2294# Restricts the values in the tensor to the domain of the
2295# given elementwise unary operator
2296def _filter_unary_elementwise_tensor(a, *, op):
2297    # short-circuits for boolean tensors
2298    if a.dtype is torch.bool:
2299        return a
2300
2301    low, high = op.domain
2302    is_floating = a.dtype.is_floating_point or a.dtype.is_complex
2303    low = low if low is None or not is_floating else low + op._domain_eps
2304    high = high if high is None or not is_floating else high - op._domain_eps
2305
2306    if a.dtype is torch.uint8 and low is not None:
2307        low = max(low, 0)
2308
2309    if not a.dtype.is_floating_point and not a.dtype.is_complex:
2310        low = math.ceil(low) if low is not None else None
2311        high = math.floor(high) if high is not None else None
2312
2313    if op.reference_numerics_filter is not None:
2314        condition, safe_value = op.reference_numerics_filter
2315        _replace_values_in_tensor(a, condition, safe_value)
2316
2317    if low is not None or high is not None:
2318        if a.dtype.is_complex:
2319            a.real.clamp_(low, high)
2320            a.imag.clamp_(low, high)
2321        else:
2322            a.clamp_(min=low, max=high)
2323
2324    return a
2325
2326
2327def generate_elementwise_unary_tensors(op, *, device, dtype, requires_grad, **kwargs):
2328    # Special-cases bool
2329    if dtype is torch.bool:
2330        tensors = (
2331            torch.empty(0, device=device, dtype=torch.bool),
2332            torch.tensor(True, device=device),
2333            torch.tensor(False, device=device),
2334            torch.tensor((True, False), device=device),
2335            make_tensor((812,), device=device, dtype=dtype),
2336            make_tensor((1029, 917), device=device, dtype=dtype),
2337        )
2338        for a in tensors:
2339            yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
2340
2341    shapes = (
2342        (1029, 917),
2343        (812,),
2344        # Empty sizes
2345        (0,),
2346        (0, 3, 3),
2347        (1, 0, 5),
2348        (6, 0, 0, 0),
2349        (3, 0, 1, 0),
2350    )
2351
2352    make_arg = partial(
2353        _make_unary_elementwise_tensor,
2354        op=op,
2355        device=device,
2356        dtype=dtype,
2357        requires_grad=requires_grad,
2358    )
2359    for shape in shapes:
2360        a = make_arg(shape)
2361        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
2362
2363
2364def generate_elementwise_unary_small_value_tensors(
2365    op, *, device, dtype, requires_grad=False
2366):
2367    for sample in generate_elementwise_binary_small_value_tensors(
2368        op, device=device, dtype=dtype, requires_grad=requires_grad
2369    ):
2370        a = _filter_unary_elementwise_tensor(sample.input, op=op)
2371        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
2372
2373
2374def generate_elementwise_unary_large_value_tensors(
2375    op, *, device, dtype, requires_grad=False
2376):
2377    for sample in generate_elementwise_binary_large_value_tensors(
2378        op, device=device, dtype=dtype, requires_grad=requires_grad
2379    ):
2380        a = _filter_unary_elementwise_tensor(sample.input, op=op)
2381        yield SampleInput(sample.input, kwargs=op.sample_kwargs(device, dtype, a)[0])
2382
2383
2384def generate_elementwise_unary_extremal_value_tensors(
2385    op, *, device, dtype, requires_grad=False
2386):
2387    for sample in generate_elementwise_binary_extremal_value_tensors(
2388        op, device=device, dtype=dtype, requires_grad=requires_grad
2389    ):
2390        yield SampleInput(
2391            sample.input, kwargs=op.sample_kwargs(device, dtype, sample.input)[0]
2392        )
2393
2394
2395def generate_elementwise_unary_noncontiguous_tensors(
2396    op, *, device, dtype, requires_grad=False
2397):
2398    make_arg = partial(
2399        _make_unary_elementwise_tensor,
2400        op=op,
2401        device=device,
2402        dtype=dtype,
2403        requires_grad=requires_grad,
2404    )
2405
2406    # Generic noncontiguity
2407    t = make_arg((1026,), noncontiguous=True)
2408    yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0])
2409
2410    # Transposed
2411    t = make_arg((1024, 1024)).T
2412    yield SampleInput(t, kwargs=op.sample_kwargs(device, dtype, t)[0])
2413
2414    # Expanded tensors
2415    shapes = ((1, 3), (1, 7), (5, 7))
2416
2417    for shape in shapes:
2418        t = make_arg(shape)
2419        t_non_contig = t.expand(3, -1, -1)
2420        yield SampleInput(
2421            t_non_contig, kwargs=op.sample_kwargs(device, dtype, t_non_contig)[0]
2422        )
2423
2424
2425def generate_elementwise_unary_arbitrarily_strided_tensors(
2426    op, *, device, dtype, requires_grad=False
2427):
2428    # shape, strides, offset
2429    strided_cases = (
2430        ((5, 6, 2), (1, 1, 7), 2),
2431        ((5, 5, 4), (1, 1, 7), 2),
2432        ((5, 5, 2), (4, 5, 7), 3),
2433        ((5, 5, 2), (5, 5, 7), 3),
2434        ((5, 5, 2), (5, 5, 5), 3),
2435        ((9, 5, 2), (0, 1, 7), 3),
2436    )
2437
2438    make_arg = partial(
2439        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
2440    )
2441    for shape, strides, offset in strided_cases:
2442        a = make_arg(
2443            500,
2444        ).as_strided(shape, strides, offset)
2445        yield SampleInput(a, kwargs=op.sample_kwargs(device, dtype, a)[0])
2446
2447
2448# Reuses the elementwise binary generators for consistency
2449# TODO: in the future generalize the reference generators to handle n-ary elementwise operations
2450def _reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs):
2451    yield from op.sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
2452
2453    yield from generate_elementwise_unary_tensors(
2454        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2455    )
2456
2457    if dtype is not torch.bool:
2458        yield from generate_elementwise_unary_small_value_tensors(
2459            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2460        )
2461    if dtype not in (torch.bool, torch.uint8, torch.int8) and (
2462        op.handles_large_floats
2463        or (not dtype.is_floating_point and not dtype.is_complex)
2464    ):
2465        yield from generate_elementwise_unary_large_value_tensors(
2466            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2467        )
2468
2469    if dtype.is_floating_point or (
2470        op.handles_complex_extremal_values and dtype.is_complex
2471    ):
2472        yield from generate_elementwise_unary_extremal_value_tensors(
2473            op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2474        )
2475
2476
2477def reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs):
2478    gen = partial(
2479        _reference_inputs_elementwise_unary, op, device, dtype, requires_grad, **kwargs
2480    )
2481
2482    # yields "normal" samples
2483    yield from gen()
2484
2485    # yields noncontiguous samples
2486    for sample in gen():
2487        yield sample.noncontiguous()
2488
2489    yield from generate_elementwise_unary_noncontiguous_tensors(
2490        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2491    )
2492
2493    yield from generate_elementwise_unary_arbitrarily_strided_tensors(
2494        op, device=device, dtype=dtype, requires_grad=requires_grad, **kwargs
2495    )
2496
2497
2498# Metadata class for unary "universal functions (ufuncs)" that accept a single
2499# tensor and have common properties like:
2500class UnaryUfuncInfo(OpInfo):
2501    """Operator information for 'universal unary functions (unary ufuncs).'
2502    These are functions of a single tensor with common properties like:
2503      - they are elementwise functions
2504      - the input shape is the output shape
2505      - they typically have method and inplace variants
2506      - they typically support the out kwarg
2507      - they typically have NumPy or SciPy references
2508    See NumPy's universal function documentation
2509    (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
2510    about the concept of ufuncs.
2511    """
2512
2513    def __init__(
2514        self,
2515        name,  # the string name of the function
2516        *,
2517        dtypes=floating_types(),
2518        domain=(None, None),  # the [low, high) domain of the function
2519        handles_complex_extremal_values=True,  # whether the op correctly handles extremal values (like nan/inf)
2520        handles_large_floats=True,  # whether the op correctly handles large float values (like 1e20)
2521        supports_complex_to_float=False,  # op supports casting from complex input to real output safely eg. angle
2522        sample_inputs_func=sample_inputs_elementwise_unary,
2523        reference_inputs_func=reference_inputs_elementwise_unary,
2524        sample_kwargs=lambda device, dtype, input: ({}, {}),
2525        reference_numerics_filter=None,  # Filters values in the range of the domain specified above but that should not be tested
2526        **kwargs,
2527    ):
2528        self._original_unary_ufunc_args = locals().copy()
2529
2530        super().__init__(
2531            name,
2532            dtypes=dtypes,
2533            sample_inputs_func=sample_inputs_func,
2534            reference_inputs_func=reference_inputs_func,
2535            **kwargs,
2536        )
2537        self.domain = domain
2538        self.handles_complex_extremal_values = handles_complex_extremal_values
2539        self.handles_large_floats = handles_large_floats
2540        self.supports_complex_to_float = supports_complex_to_float
2541        self.reference_numerics_filter = reference_numerics_filter
2542
2543        # test_unary_ufuncs.py generates its own inputs to test the consistency
2544        # of the operator on sliced tensors, non-contig tensors, etc.
2545        # `sample_kwargs` is a utility function to provide kwargs
2546        # along with those inputs if required (eg. clamp).
2547        # It should return two dictionaries, first holding kwarg for
2548        # torch operator and second one for reference NumPy operator.
2549        self.sample_kwargs = sample_kwargs
2550
2551        # Epsilon to ensure grad and gradgrad checks don't test values
2552        #   outside a function's domain.
2553        self._domain_eps = 1e-5
2554
2555
2556def sample_inputs_spectral_ops(self, device, dtype, requires_grad=False, **kwargs):
2557    is_fp16_or_chalf = dtype == torch.complex32 or dtype == torch.half
2558    if not is_fp16_or_chalf:
2559        nd_tensor = partial(
2560            make_tensor,
2561            (S, S + 1, S + 2),
2562            device=device,
2563            dtype=dtype,
2564            requires_grad=requires_grad,
2565        )
2566        oned_tensor = partial(
2567            make_tensor, (31,), device=device, dtype=dtype, requires_grad=requires_grad
2568        )
2569    else:
2570        # cuFFT supports powers of 2 for half and complex half precision
2571        # NOTE: For hfft, hfft2, hfftn, irfft, irfft2, irfftn with default args
2572        # where output_size n=2*(input_size - 1), we make sure that logical fft size is a power of two
2573        low = None
2574        high = None
2575        if self.name in ["fft.hfft", "fft.irfft", "_refs.fft.hfft", "_refs.fft.irfft"]:
2576            shapes = ((2, 9, 9), (33,))
2577        elif self.name in [
2578            "fft.hfft2",
2579            "fft.irfft2",
2580            "_refs.fft.hfft2",
2581            "_refs.fft.irfft2",
2582        ]:
2583            shapes = ((2, 8, 9), (33,))
2584        elif self.name in [
2585            "fft.hfftn",
2586            "fft.irfftn",
2587            "_refs.fft.hfftn",
2588            "_refs.fft.irfftn",
2589        ]:
2590            shapes = ((2, 2, 33), (33,))
2591            # Adjusting the limits because the test would be flaky due to over-saturation of float16
2592            # See: https://github.com/pytorch/pytorch/pull/81416
2593            low = -1.0
2594            high = 1.0
2595        else:
2596            shapes = ((2, 8, 16), (32,))
2597        nd_tensor = partial(
2598            make_tensor,
2599            shapes[0],
2600            device=device,
2601            low=low,
2602            high=high,
2603            dtype=dtype,
2604            requires_grad=requires_grad,
2605        )
2606        oned_tensor = partial(
2607            make_tensor,
2608            shapes[1],
2609            device=device,
2610            low=low,
2611            high=high,
2612            dtype=dtype,
2613            requires_grad=requires_grad,
2614        )
2615
2616    if self.ndimensional == SpectralFuncType.ND:
2617        yield SampleInput(
2618            nd_tensor(),
2619            s=(3, 10) if not is_fp16_or_chalf else (4, 8),
2620            dim=(1, 2),
2621            norm="ortho",
2622        )
2623        yield SampleInput(nd_tensor(), norm="ortho")
2624        yield SampleInput(nd_tensor(), s=(8,))
2625        yield SampleInput(oned_tensor())
2626        yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3, (0, -1)])
2627    elif self.ndimensional == SpectralFuncType.TwoD:
2628        yield SampleInput(
2629            nd_tensor(),
2630            s=(3, 10) if not is_fp16_or_chalf else (4, 8),
2631            dim=(1, 2),
2632            norm="ortho",
2633        )
2634        yield SampleInput(nd_tensor(), norm="ortho")
2635        yield SampleInput(nd_tensor(), s=(6, 8) if not is_fp16_or_chalf else (4, 8))
2636        yield SampleInput(nd_tensor(), dim=0)
2637        yield SampleInput(nd_tensor(), dim=(0, -1))
2638        yield SampleInput(nd_tensor(), dim=(-3, -2, -1))
2639    else:
2640        yield SampleInput(
2641            nd_tensor(),
2642            n=10 if not is_fp16_or_chalf else 8,
2643            dim=1,
2644            norm="ortho",
2645        )
2646        yield SampleInput(nd_tensor(), norm="ortho")
2647        yield SampleInput(nd_tensor(), n=7 if not is_fp16_or_chalf else 8)
2648        yield SampleInput(oned_tensor())
2649        yield from (SampleInput(nd_tensor(), dim=dim) for dim in [-1, -2, -3])
2650
2651
2652SpectralFuncType = Enum("SpectralFuncType", ("OneD", "TwoD", "ND"))
2653
2654
2655# Metadata class for Fast Fourier Transforms in torch.fft.
2656class SpectralFuncInfo(OpInfo):
2657    """Operator information for torch.fft transforms."""
2658
2659    def __init__(
2660        self,
2661        name,  # the string name of the function
2662        *,
2663        ref=None,  # Reference implementation (probably in np.fft namespace)
2664        dtypes=floating_and_complex_types(),
2665        ndimensional: SpectralFuncType,
2666        sample_inputs_func=sample_inputs_spectral_ops,
2667        decorators=None,
2668        **kwargs,
2669    ):
2670        self._original_spectral_func_args = dict(locals()).copy()
2671        self._original_spectral_func_args.update(kwargs)
2672
2673        decorators = list(decorators) if decorators is not None else []
2674        decorators += [
2675            skipCPUIfNoFFT,
2676            DecorateInfo(
2677                toleranceOverride({torch.chalf: tol(4e-2, 4e-2)}),
2678                "TestCommon",
2679                "test_complex_half_reference_testing",
2680            ),
2681        ]
2682
2683        super().__init__(
2684            name=name,
2685            dtypes=dtypes,
2686            decorators=decorators,
2687            sample_inputs_func=sample_inputs_func,
2688            **kwargs,
2689        )
2690        self.ref = ref
2691        self.ndimensional = ndimensional
2692
2693
2694class ShapeFuncInfo(OpInfo):
2695    """Early version of a specialized OpInfo for Shape manipulating operations like tile and roll"""
2696
2697    def __init__(
2698        self,
2699        name,  # the string name of the function
2700        *,
2701        ref,  # a reference function
2702        dtypes=floating_types(),
2703        dtypesIfCUDA=None,
2704        dtypesIfROCM=None,
2705        dtypesIfXPU=None,
2706        sample_inputs_func=None,
2707        **kwargs,
2708    ):
2709        super().__init__(
2710            name,
2711            dtypes=dtypes,
2712            dtypesIfCUDA=dtypesIfCUDA,
2713            dtypesIfROCM=dtypesIfROCM,
2714            dtypesIfXPU=dtypesIfXPU,
2715            sample_inputs_func=sample_inputs_func,
2716            **kwargs,
2717        )
2718        self.ref = ref
2719
2720
2721def sample_inputs_foreach(
2722    self,
2723    device,
2724    dtype,
2725    N,
2726    *,
2727    noncontiguous=False,
2728    same_size=False,
2729    low=None,
2730    high=None,
2731    zero_size: bool,
2732    requires_grad: bool,
2733    # mutually exclusive from same_size and zero_size, which are all or nothing
2734    intersperse_empty_tensors: bool = False,
2735):
2736    if zero_size:
2737        return [torch.empty(0, dtype=dtype, device=device) for _ in range(N)]
2738    if same_size:
2739        return [
2740            make_tensor(
2741                (N, N),
2742                dtype=dtype,
2743                device=device,
2744                noncontiguous=noncontiguous,
2745                low=low,
2746                high=high,
2747                requires_grad=requires_grad,
2748            )
2749            for _ in range(N)
2750        ]
2751    else:
2752        # interweave some empty tensors + have the last 2 tensors be empty (see #100701)
2753        return [
2754            torch.empty(0, dtype=dtype, device=device, requires_grad=requires_grad)
2755            if (i % 3 == 0 or i >= N - 2) and intersperse_empty_tensors
2756            else make_tensor(
2757                (N - i, N - i),
2758                dtype=dtype,
2759                device=device,
2760                noncontiguous=noncontiguous,
2761                low=low,
2762                high=high,
2763                requires_grad=requires_grad,
2764            )
2765            for i in range(N)
2766        ]
2767
2768
2769def get_foreach_method_names(name):
2770    # get torch inplace reference function
2771    op_name = "_foreach_" + name
2772    inplace_op_name = op_name + "_"
2773
2774    op = getattr(torch, op_name, None)
2775    inplace_op = getattr(torch, inplace_op_name, None)
2776
2777    ref = getattr(torch, name, None)
2778    ref_inplace = getattr(torch.Tensor, name + "_", None)
2779    return op, inplace_op, ref, ref_inplace
2780
2781
2782@dataclass
2783class ForeachFuncInfo(OpInfo):
2784    """Early version of a specialized OpInfo for foreach functions
2785
2786    The main differences from the parent class are (a) `dtypes`, `dtypesIfCUDA`, and `dtypesIfROCM`
2787    are set to `get_all_dtypes(include_qint=False)`, and (b) the following arguments.
2788
2789    ``supports_alpha_param=True`` means that the function supports a python scalar (``numbers.Number``)
2790    as the last keyword argument such as `_foreach_add`.
2791    ``supports_scalar_self_arg=True`` means that the function can take a python scalar as its first argument.
2792    Currently only `_foreach_pow` supports this.
2793    ``backward_requires_result=True``, which could sound self-explanatory, means that the function uses
2794    the forward result for its backward computation.
2795    """
2796
2797    supports_alpha_param: bool = False
2798    supports_scalar_self_arg: bool = False
2799    backward_requires_result: bool = False
2800
2801    def __post_init__(self):
2802        (
2803            foreach_method,
2804            foreach_method_inplace,
2805            torch_ref_method,
2806            torch_ref_inplace,
2807        ) = get_foreach_method_names(self.name)
2808        if not self.supports_out:
2809            # note(crcrpar): `foreach_method` for `"zero"` is `None` but `None` would call
2810            # `_getattr_qual` in `OpInfo.__post_init__` which should fail since `_foreach_zero`
2811            # is not defined at the moment. Thus to skip the qualification, set a similar torch
2812            # function.
2813            assert foreach_method is None
2814            assert torch_ref_method is None
2815            foreach_method = foreach_method_inplace
2816            torch_ref_method = torch_ref_inplace
2817
2818        self.dtypes = _dispatch_dtypes(get_all_dtypes(include_qint=False))
2819
2820        self.op = foreach_method
2821        self.method_variant = foreach_method
2822        self.ref = torch_ref_method
2823        self.inplace_variant = foreach_method_inplace
2824        self.ref_inplace = torch_ref_inplace
2825        self.has_no_in_place = self.inplace_variant is None
2826
2827        name = self.name
2828        self.name = f"_foreach_{name}"
2829        if name == "norm":
2830            self.ref = torch.linalg.vector_norm
2831        elif name == "minimum":
2832            # because minimum ref does not support inplace or scalar
2833            self.ref = torch.clamp_max
2834            self.ref_inplace = torch.Tensor.clamp_max_
2835        elif name == "maximum":
2836            # because maximum ref does not support inplace or scalar
2837            self.ref = torch.clamp_min
2838            self.ref_inplace = torch.Tensor.clamp_min_
2839
2840        # The following sets `dtypesIfCUDA` and `dtypesIfROCM` accordingly.
2841        super().__post_init__()
2842
2843    def sample_zero_size_inputs(self, device, dtype, requires_grad=False, **kwargs):
2844        if not hasattr(self.sample_inputs_func, "sample_zero_size_tensor_inputs"):
2845            return []
2846        return self.sample_inputs_func.sample_zero_size_tensor_inputs(
2847            self, device, dtype, requires_grad, **kwargs
2848        )
2849
2850
2851def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
2852    """Gradcheck wrapper for functions that take Hermitian matrices as input.
2853
2854    They require a modified function because the finite-difference algorithm
2855    for calculating derivatives does not preserve the Hermitian property of the input.
2856    """
2857    return op(input + input.mH, *args, **kwargs)
2858
2859
2860def gradcheck_wrapper_triangular_input(op, *args, upper=False, idx=0, **kwargs):
2861    """Gradcheck wrapper for functions that take lower or upper triangular matrices as input.
2862
2863    They require a modified function because the finite-difference algorithm
2864    for calculating derivatives does not preserve the triangular property of the input.
2865    `idx` is used to specific which `args[idx]` is to be triangularized.
2866    """
2867    triangular_arg = args[idx].triu() if upper else args[idx].tril()
2868    return op(*args[:idx], triangular_arg, *args[idx + 1 :], upper, **kwargs)
2869
2870
2871def gradcheck_wrapper_triangular_input_real_positive_diagonal(
2872    op, *args, upper=False, idx=0, **kwargs
2873):
2874    """Gradcheck wrapper for functions that take lower/upper triangular matrices
2875    with real and positive diagonals, for example, cholesky-like operations.
2876    """
2877    arg = args[idx]
2878    arg_diag = arg.diagonal(0, -2, -1)
2879    arg_diag_embed = torch.diag_embed(arg_diag)
2880    id_diag_tensor = torch.ones_like(arg_diag)
2881    id_tensor = torch.diag_embed(id_diag_tensor)
2882    # new_arg = arg - diag(arg) + I
2883    new_arg = arg - arg_diag_embed + id_tensor
2884    return gradcheck_wrapper_triangular_input(
2885        op, *args[:idx], new_arg, *args[idx + 1 :], upper=upper, idx=idx, **kwargs
2886    )
2887
2888
2889def gradcheck_wrapper_masked_operation(op, input, *args, **kwargs):
2890    """Gradcheck wrapper for masked operations.
2891
2892    When mask is specified, replaces masked-out elements with zeros.
2893
2894    Use for operations that produce non-finite masked-out elements,
2895    for instance, for minimum and maximum reductions.
2896    """
2897    output = op(input, *args, **kwargs)
2898    mask = kwargs.get("mask")
2899    if mask is not None:
2900        output_mask = torch.masked._output_mask(op, input, *args, **kwargs)
2901        output = torch.where(output_mask, output, output.new_zeros([]))
2902    return output
2903
2904
2905def gradcheck_wrapper_masked_pointwise_operation(op, input, *args, **kwargs):
2906    """Gradcheck wrapper for masked pointwise operations. Assumes that the result
2907    will be masked iff both tensors are masked at a specific index
2908
2909    When mask is specified, replaces masked-out elements with zeros.
2910
2911    Use for operations that produce non-finite masked-out elements,
2912    for instance, for minimum and maximum reductions.
2913    """
2914    output = op(input, *args, **kwargs)
2915    input_mask = kwargs.get("input_mask")
2916    other_mask = kwargs.get("other_mask")
2917    if input_mask is not None and other_mask is not None:
2918        combined_mask = torch.logical_and(input_mask, other_mask)
2919        new_kwargs = dict(mask=combined_mask, **kwargs)
2920        output_mask = torch.masked._input_mask(input, *args, **new_kwargs)
2921        output = torch.where(output_mask, output, output.new_zeros([]))
2922    return output
2923
2924
2925def clone_sample(sample, **kwargs):
2926    """
2927    Given a SampleInput, this function analyzes its input, args and kwargs,
2928    and produces a copy with each non-Tensor entry being copied by reference,
2929    and with each Tensor entry cloned with `t.clone().requires_grad_(t.requires_grad)`
2930    """
2931
2932    def clone_tensor(t):
2933        if isinstance(t, torch.Tensor):
2934            return t.detach().clone().requires_grad_(t.requires_grad)
2935        else:
2936            return t
2937
2938    sample_kwargs = kwargs if kwargs else sample.kwargs
2939
2940    return SampleInput(
2941        clone_tensor(sample.input),
2942        args=tuple(map(clone_tensor, sample.args)),
2943        kwargs={k: clone_tensor(v) for k, v in sample_kwargs.items()},
2944    )
2945