# mypy: ignore-errors

from functools import wraps, partial
from itertools import product, chain, islice
import itertools
import functools
import copy
import operator
import random
import unittest
import math
import enum

import torch
import numpy as np
from torch import inf, nan

from typing import Any, Dict, List, Tuple, Union, Sequence
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
    all_types, empty_types, complex_types_and, integral_types, custom_types,
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
    _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
    make_fullrank_matrices_with_distinct_singular_values,
    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
    torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
    GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
    TEST_WITH_TORCHINDUCTOR
)
from torch.testing._utils import wrapper_set_seed

import torch._refs as refs  # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims  # noqa: F401
from torch.utils import _pytree as pytree


from packaging import version

from torch.testing._internal.opinfo.core import (  # noqa: F401
    L,
    M,
    S,
    XS,
    _NOTHING,
    _getattr_qual,
    DecorateInfo,
    SampleInput,
    ErrorInput,
    AliasInfo,
    NumericsFilter,
    OpInfo,
    _generate_reduction_inputs,
    _generate_reduction_kwargs,
    sample_inputs_reduction,
    ReductionOpInfo,
    reference_inputs_elementwise_binary,
    make_error_inputs_elementwise_binary,
    generate_elementwise_binary_tensors,
    generate_elementwise_binary_arbitrarily_strided_tensors,
    generate_elementwise_binary_small_value_tensors,
    generate_elementwise_binary_large_value_tensors,
    generate_elementwise_binary_extremal_value_tensors,
    generate_elementwise_binary_broadcasting_tensors,
    generate_elementwise_binary_with_scalar_samples,
    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
    generate_elementwise_binary_noncontiguous_tensors,
    sample_inputs_elementwise_binary,
    BinaryUfuncInfo,
    sample_inputs_elementwise_unary,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
    reference_inputs_elementwise_unary,
    UnaryUfuncInfo,
    sample_inputs_spectral_ops,
    SpectralFuncType,
    SpectralFuncInfo,
    ShapeFuncInfo,
    sample_inputs_foreach,
    ForeachFuncInfo,
    gradcheck_wrapper_hermitian_input,
    gradcheck_wrapper_triangular_input,
    gradcheck_wrapper_triangular_input_real_positive_diagonal,
    gradcheck_wrapper_masked_operation,
    gradcheck_wrapper_masked_pointwise_operation,
    clone_sample,
)
from torch.testing._internal.opinfo.refs import (  # NOQA: F401
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
    ReductionPythonRefInfo,
    ElementwiseUnaryPythonRefInfo,
    ElementwiseBinaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
    np_unary_ufunc_integer_promotion_wrapper,
    reference_reduction_numpy,
    prod_numpy
)
from torch.testing._internal import opinfo
from torch.testing._internal.opinfo.definitions.linalg import (
    sample_inputs_linalg_cholesky,
    sample_inputs_linalg_cholesky_inverse,
    sample_inputs_cross,
    sample_inputs_linalg_qr_geqrf,
    sample_inputs_linalg_invertible,
    sample_inputs_lu_solve,
    sample_inputs_legacy_solve,
    sample_inputs_svd,
    sample_inputs_linalg_det_logdet_slogdet,
    sample_inputs_linalg_lu,
    sample_inputs_diagonal_diag_embed,
    error_inputs_diagonal_diag_embed,
)
from torch.testing._internal.opinfo.definitions.special import (
    sample_inputs_i0_i1,
    sample_inputs_polygamma,
    reference_polygamma,
)
from torch.testing._internal.opinfo.definitions._masked import (
    sample_inputs_softmax_variant,
)
from torch.testing._internal.opinfo.definitions.sparse import (
    error_inputs_sparse_like_fns,
    sample_inputs_sparse_like_fns,
    error_inputs_sparse_mul,
    sample_inputs_sparse_mul,
    error_inputs_sparse_reduction_sum,
    sample_inputs_sparse_reduction_sum
)

if TEST_SCIPY:
    from scipy import stats
    import scipy.spatial
    import scipy.special


# test if a tensor is close to an integer
def close_to_int(x, eps=0.1):
    if x.is_complex():
        y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x))))
    else:
        y = torch.abs(torch.frac(x))
    return (y < eps) | (y > (1 - eps))
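# Editorial sketch (comment only, not used by the suite): for real inputs the check above
# is equivalent to asking whether the fractional part lies within `eps` of 0 or 1, e.g.
#   close_to_int(torch.tensor([1.95, 2.5, -3.01]))  ->  tensor([True, False, True])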


def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs):

    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    yield SampleInput(make_input(3), 0)

    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2)

    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3)

    yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2)


def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    args_cases = (
        # Cases with tensor indices.
        (torch.tensor([1, 2, 3]),),
        (torch.tensor(1),),
        (torch.tensor([1, 2, 3]), 1),
        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
        # Cases with list of indices.
        ((2, 4),),
        ((2, 4), 1),
        ((2, 4), -1),
        # Cases with integer section.
        (3,),
        (3, 1),
        (3, -1),
    )

    for args in args_cases:
        yield SampleInput(make_input((S, S, S)), args=args)


def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6, S), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
    yield SampleInput(make_arg(S, S, 6), 2)

def error_inputs_hsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, "
                "but got a tensor with 0 dimensions!")
    yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.hsplit attempted to split along dimension 1, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2)

    # Incorrect type for the indices_or_sections argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(
        SampleInput(make_arg((S, S, S)), "abc"),
        error_type=TypeError, error_regex=err_msg3)

def error_inputs_vsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.vsplit attempted to split along dimension 0, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0),
                     error_regex=err_msg2)

    # Incorrect type for the indices_or_sections argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"),
                     error_type=TypeError, error_regex=err_msg3)

def error_inputs_dsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.dsplit attempted to split along dimension 2, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2)
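# Editorial note (comment only, not part of the tests): for inputs with 2+ dimensions,
# torch.hsplit / vsplit / dsplit split along dims 1 / 0 / 2 and require at least
# 1 / 2 / 3 dimensions respectively, which is what the error cases above exercise. E.g.
#   torch.hsplit(torch.arange(6).reshape(2, 3), 3)  ->  three tensors of shape (2, 1)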


def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = (
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
    )

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        kwargs = dict(storage_offset=storage_offset)
        yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)

def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
    def make_arg():
        base = make_tensor((20,), device=device, dtype=dtype)
        return base[5:15].requires_grad_(requires_grad)

    # as_strided on offset, partial views
    yield SampleInput(make_arg(), (2, 2), (1, 2))
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)

def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = [
        ((1,), (), (), 0),
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((3, 3), (2, 2), (2, 1), 0),
        # Scatter to larger dimensions
        ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
        # Scatter to larger dimensions with strides inverted
        ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
    ]

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        input_src = make_arg(output_shape)
        yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)


def error_inputs_as_strided_scatter(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)

    # Create a small tensor and try to scatter it out of bounds
    input_t = make_arg([4, 4])
    input_src = make_arg([2, 2])
    yield ErrorInput(
        SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0),
        error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
    )
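# Editorial note (comment-only arithmetic check): for the error case above the required
# storage is storage_offset + 1 + (2 - 1) * 200 + (2 - 1) * 200 = 401 float32 elements,
# i.e. 401 * 4 = 1604 bytes, while the (4, 4) float32 base tensor only provides
# 16 * 4 = 64 bytes -- which is exactly what the error message asserts.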


def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
    inputs = (
        (0,),
        (0, 1),
        (0, 1, 2, 3),
    )

    rvals = [1, 2, 4]

    products = product(inputs, rvals, [False, True])

    for input_data, r, with_replacement in products:
        input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(input_t, r=r, with_replacement=with_replacement)
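# Editorial note (comment only): torch.combinations on n elements choosing r returns a
# tensor of shape (C(n, r), r), or (C(n + r - 1, r), r) with replacement; e.g. the
# 4-element input above with r=2 yields 6 rows, or 10 rows with with_replacement=True.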

def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # constructs 1-D tensors with varying number of elements
    a = make_arg((0,))
    b = make_arg((0, 1))
    c = make_arg((0, 1, 2, 3))

    # sample with only 1 tensor
    yield SampleInput(a)

    # sample with 2 tensors
    yield SampleInput(a, b)

    # sample with 3 tensors
    yield SampleInput(a, b, c)

def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input_shape, dict of dim and eps
    cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
        ((S, S), {'dim': 1}),
        ((S, 2), {'dim': -1}),
        ((S,), {'dim': 0, 'eps': 0.5}),
        ((), {'dim': 0}),
        ((S, S, M), {'dim': 2}),
        ((S, S), {})
    )

    for input_shape, kwargs in cases:
        yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
    # Test for Broadcasting
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
    yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
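# Editorial note (comment only): F.cosine_similarity broadcasts its two inputs and then
# reduces over `dim`, so the (1, 2, 3) vs (2, 1, 3) cases above broadcast to (2, 2, 3);
# for the dim=-1 case the output therefore has shape (2, 2).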


def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)

    cases = (
        (),
        (()),
        (1),
        ((1,)),
    )

    for shape in cases:
        yield SampleInput(make_arg(shape))

def error_inputs_item(op, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)

    cases = (
        (M),
        ((S,)),
        (S, S),
        (S, M, L),
    )

    for shape in cases:
        yield ErrorInput(
            SampleInput(make_arg(shape)), error_type=RuntimeError,
            error_regex="elements cannot be converted to Scalar")


def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    # Ordered as: input shape, kwargs for training, momentum, eps
    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
        ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}),
        ((3, 2, 4), {'training': False, 'momentum': -1.2}),
        ((3, 1), {'training': True, 'momentum': 0.0}),
        ((0,), {'training': True}),
        ((0,), {'training': False}),
        ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}),
        ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}),
        ((2, 1), {}),
    )

    for input_shape, kwargs in cases:
        # args: running_mean, running_var, weight and bias must all have shape (channels,)
        channels = input_shape[1] if len(input_shape) > 1 else 0
        weight = make_arg(channels) if channels > 0 else None
        bias = make_arg(channels) if channels > 0 else None
        running_mean = make_arg_without_requires_grad(channels, low=0)
        running_var = make_arg_without_requires_grad(channels, low=0)

        yield SampleInput(
            make_arg(input_shape),
            args=(
                running_mean,
                running_var,
                weight,
                bias
            ),
            kwargs=kwargs
        )

    # Checking for permutations of weights and biases as `None`
    weights = [channels, None, None]
    biases = [None, channels, None]
    is_training = [True, False, False]

    for weight, bias, training in zip(weights, biases, is_training):
        yield SampleInput(
            make_arg(input_shape),
            args=(
                running_mean,
                running_var,
                make_arg(channels),
                make_arg(channels)
            ),
            kwargs={'training': training}
        )

    # Test case for no optional kwargs
    # running_mean and running_var are required in evaluation mode (training: False) but not in training mode
    yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})
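# Editorial note (comment only): the positional args above follow the
# torch.nn.functional.batch_norm signature, i.e.
#   batch_norm(input, running_mean, running_var, weight=None, bias=None,
#              training=False, momentum=0.1, eps=1e-05)
# so the running stats come first and weight/bias follow.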

def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    cases = [
        ((S,), 0),
        ((S, S), 0),
        ((S, M, S), -1),
    ]
    input_dtypes = [dtype]
    if dtype == torch.float and device == 'cuda':
        input_dtypes += [torch.float16]

    for (shape, dim), input_dtype in product(cases, input_dtypes):
        input = make_arg(shape)
        output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype)
        yield SampleInput(make_arg(shape), output, dim, input_dtype)

def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
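# Editorial note (comment only): native_batch_norm orders its arguments as
# (input, weight, bias, running_mean, running_var, training, momentum, eps), which is why
# the samples above move args[2:4] (weight/bias) ahead of args[0:2] (running stats)
# relative to the F.batch_norm ordering produced by sample_inputs_batch_norm.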


def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        if args[0] is not None and args[1] is not None:
            yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
        else:
            yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))

def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        if any(args[i] is None for i in range(4)):
            continue
        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))

def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )

    for shape in cases:
        yield SampleInput(make_arg(shape))

def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs):
    op_kwargs = op_info.sample_kwargs(device, dtype, None)[0]
    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad,
                                               op_kwargs=op_kwargs)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )

    for shape in cases:
        for weight in [-1., 0., 0.8, 1.]:
            weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad)
            yield SampleInput(make_arg(shape), args=(weight_tensor,))

        channel_size = shape[1] if len(shape) >= 2 else 1
        yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),))

    weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,))
    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),))

def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs)
    yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs)

def sample_kwargs_prelu_scalar_weight(device, dtype, input):
    weight = torch.rand((), device=device, dtype=dtype)
    # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case
    if dtype == torch.bfloat16:
        weight_cpu = weight.to(dtype=torch.float32, device="cpu")
    else:
        weight_cpu = weight.cpu()
    np_weight = weight_cpu.numpy()
    return ({'weight': weight}, {'weight': np_weight})

def error_inputs_prelu(op, device):
    # Weight has numel != 1, but the input (self) is a zero-dim tensor
    inp = make_tensor((), device=device, dtype=torch.float32)
    weight = make_tensor((2,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Not allow zero-dim input tensor.")

    # Weight has numel != 1, but numel does not match channel size
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((9,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Mismatch of parameter numbers and input channel size.")

    # Weight is neither a scalar nor 1-D tensor
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((2, 4), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2")
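# Editorial note (comment only): F.prelu accepts a weight that is either a scalar or a
# 1-D tensor whose length matches the channel dimension (dim 1) of the input; the three
# error cases above cover the zero-dim input, channel-mismatch, and 2-D weight rejections.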

    # src and index tensors must have the same # of dimensions
def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # ord = inf is tested in sample_inputs_norm_inf as it fails on some tests
    cases = [
        ((S, S), (2,), '2'),
        ((S, S), (0,), '0'),
        ((S, S), (0.5,), '0_5'),
        ((S, S), (1,), '1'),
        ((S, S), (3,), '3'),
        ((S, S), (-1,), 'neg_1'),
        ((S, S), (-2,), 'neg_2'),
        ((S, S), (-0.5,), 'neg_0_5'),
        ((S, S), (-1.5,), 'neg_1_5'),
    ]

    cases_nonzero_input = (
        ((S, S, S), (1.5,), '1_5_default'),
        ((S, S, S), (1.5, 1), '1_5_dim'),
        ((S, S, S), (1.5, -1), '1_5_neg_dim'),
        ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'),
        ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'),
    )

    cases_posdim = (
        ((S, S), (-2, 1,), 'neg_2_dim'),
        ((S, S), (-1, 1,), 'neg_1_dim'),
        ((S, S), (0, 1,), '0_dim'),
        ((S, S), (1, 1,), '1_dim'),
        ((S, S), (2, 1,), '2_dim'),
        ((S, S), (3, 1,), '3_dim'),
        ((S, S, S), (2, 1), '2_dim'),
        ((S, S, S), (3, 1), '3_dim'),
        ((S, S, S), (2, 1, True), 'keepdim_2_dim'),
        ((S, S, S), (3, 1, True), 'keepdim_3_dim'),
        ((), (2, 0), '2_dim_scalar'),
        ((), (3, 0), '3_dim_scalar'),
        ((), (2, 0, True), 'keepdim_2_dim_scalar'),
        ((), (3, 0, True), 'keepdim_3_dim_scalar'),
    )

    cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim"))
                    for shape, args, name in cases_posdim)

    for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim):
        yield SampleInput(make_arg(shape), args=args, name=name)

    for shape, args, name in cases_nonzero_input:
        yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name)


def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), (), 'default'),
        ((S, S), ('fro',), 'fro_default'),
        ((S, S), ('fro', [0, 1],), 'fro'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), ('nuc',), 'nuc'),
        ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), (-inf,), '-inf'),
        ((S, S), (inf,), 'inf'),
        ((S, S), (inf, 1,), 'inf_2_dim'),
        ((S, S), (inf, -1,), 'inf_2_neg_dim'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        ((), ()),
        ((S,), ()),
        ((), (S,)),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, S), (S, S))
    )

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)
        rhs = make_arg(shape_rhs)
        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
        if shape_lhs == shape_rhs:
            yield SampleInput(lhs, args=(lhs.clone().detach_(),))


def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )

    num_inputs = kwargs.get('num_inputs')
    sample_kwargs = kwargs.get('sample_kwargs', {})

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)

        args = []
        for i in range(num_inputs - 1):
            args.append(make_arg(shape_rhs))
        broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))

        yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)

def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((S, 1), S),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )

    for shape in shapes:
        inp, *arg0 = shape
        yield SampleInput(inp, args=tuple(arg0))

def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)

    # Adds alpha kwarg cases
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True})
    neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})

def error_inputs_arange(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
    yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')

def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
    int_samples = (
        # positive direction
        (-1, 2, 2),
        # negative direction
        (2, -3, -1),
        # start == end
        (1, 1, 1),
        (1, 1, -1),
        # divides evenly
        (0, -8, -4),
        (1, 5, 2),
        # bool
        (False, True, True),
        # default step
        (0, 1, None),
        # default start
        (None, 3, None),
    )

    def to_float(start, end, step):
        start = start + 0.1 if start is not None else None
        end = end + 0.1
        step = float(step) if step is not None else None
        return start, end, step

    float_samples = (
        # includes endpoint
        (0., -8. - 1e-6, -4.),
        (1., 5. + 1e-6, 2.),
        (0., -8., -4.),
        (1., 5., 2.),
        *(to_float(start, end, step) for (start, end, step) in int_samples),
    )

    large_samples = (
        (0, 10000, None),
    )

    samples = int_samples + float_samples
    if dtype not in (torch.int8, torch.uint8):
        samples += large_samples

    for start, end, step in samples:
        if start is None:
            assert step is None
            # Pass end as positional arg
            yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
            # (Similar to) calling torch.arange(end=3)
            yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
        elif step is None:
            yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(2)
    yield SampleInput(1, args=(3, 1))
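# Editorial note (comment only): torch.arange(start, end, step) produces
# ceil((end - start) / step) elements, so the (-1, 2, 2) case above yields
# tensor([-1, 1]) and the (2, -3, -1) case yields tensor([2, 1, 0, -1, -2]).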

def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)

    shapes = (
        (M,),
        (S, S)
    )

    for shape in shapes:
        yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad))

def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((S, S), 0, 5),
        ((S, S, S), -2, 0.5),
    )
    for shape, mean, std in samples:
        yield SampleInput(make_arg(shape), args=(mean, std))

def error_inputs_normal(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_std = -1
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_std)),
        error_type=RuntimeError,
        error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}",
    )

def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.5),
        ((S, S), 0, 1),
        ((S, S, S), -2, 1),
    )
    for shape, median, gamma in samples:
        yield SampleInput(make_arg(shape), args=(median, gamma))


def error_inputs_cauchy(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_scale = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_scale,)),
        error_type=RuntimeError,
        error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}",
    )


def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.5),
        ((S, S), 1),
        ((S, S, S), 1.5),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))


def error_inputs_exponential(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_rate = 0
    yield ErrorInput(
        SampleInput(t, args=(invalid_rate,)),
        error_type=RuntimeError,
        error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}",
    )


def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.2),
        ((S, S), 0.5),
        ((S, S, S), 0.8),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))


def error_inputs_geometric(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    neg_prob = -1
    yield ErrorInput(
        SampleInput(t, args=(neg_prob,)),
        error_type=RuntimeError,
        error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}",
    )


def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.25),
        ((S, S), 0.5, 1),
        ((S, S, S), 0, 0.5),
    )
    for shape, mean, std in samples:
        yield SampleInput(make_arg(shape), args=(mean, std))


def error_inputs_log_normal(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_std = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_std)),
        error_type=RuntimeError,
        error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}",
    )


def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), -100, 100),
        ((S, S), 0, 1),
        ((S, S, S), 1, 2),
    )
    for shape, hi, lo in samples:
        yield SampleInput(make_arg(shape), args=(hi, lo))

def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
    # this is a bit messy, as we want the args to be tuples
    # so if we pass size as a tuple, we have a tuple containing a tuple
    sizes = (
        (M,),
        (S, S),
    )
    for size in sizes:
        yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})

def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], dtype=dtype, device="cpu").item()

    sizes = (
        (M,),
        (S, S),
    )
    fill_values = [get_val(dtype), get_val(torch.int)]

    for size, fill_value in product(sizes, fill_values):
        yield SampleInput(size, fill_value, dtype=dtype, device=device)


def error_inputs_uniform(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    yield ErrorInput(
        SampleInput(t, args=(3, -1)),
        error_type=RuntimeError,
        error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1",
    )


def error_inputs_linspace(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative')
    yield ErrorInput(
        SampleInput(0, args=(3, 1.)),
        error_type=TypeError,
        error_regex="received an invalid combination of arguments - got \\(int, int, float",
    )
    yield ErrorInput(
        SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)),
        error_type=RuntimeError,
        error_regex="only supports 0-dimensional start and end tensors"
    )


def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1, 4, 50)
    starts = (-2., 0, 4.3, 50)
    nsteps = (0, 1, 50)
    # Extra case to replicate off-by-one issue on CUDA
    cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)]
    for start, end, nstep in cases:
        if dtype == torch.uint8 and (end < 0 or start < 0):
            continue
        yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(1, args=(3, 1))


def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1, 4, 50)
    starts = (-2., 0, 4.3, 50)
    nsteps = (0, 1, 50)
    is_start_end_tensors = ((True, True), (True, False), (False, True))
    make_arg = partial(torch.tensor, device=device, requires_grad=False)

    # Extra case to replicate off-by-one issue on CUDA
    cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))]
    for start, end, nstep, (is_start_tensor, is_end_tensor) in cases:
        if dtype == torch.uint8 and (end < 0 or start < 0):
            continue

        tensor_options = {"dtype": dtype, "device": device}
        if is_start_tensor:
            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
        if is_end_tensor:
            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)

        yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)

    yield SampleInput(1, args=(3, 1))


def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1.2, 2, 4)
    starts = (-2., 0, 1, 2, 4.3)
    nsteps = (0, 1, 2, 4)
    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
    for start, end, nstep, base in product(starts, ends, nsteps, bases):
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
            # https://github.com/pytorch/pytorch/issues/82242
            continue
        if base is None:
            yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(1, args=(3, 1, 2.))
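# Editorial note (comment only): torch.logspace(start, end, steps, base) is
# base ** torch.linspace(start, end, steps); e.g. torch.logspace(0, 2, steps=3)
# gives tensor([1., 10., 100.]) with the default base of 10.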


def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1.2, 2, 4)
    starts = (-2., 0, 1, 2, 4.3)
    nsteps = (0, 1, 2, 4)
    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
    is_start_end_tensors = ((True, True), (True, False), (False, True))
    make_arg = partial(torch.tensor, device=device)
    for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors):
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
            # https://github.com/pytorch/pytorch/issues/82242
            continue

        tensor_options = {"dtype": dtype, "device": device}

        if (is_start_tensor):
            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
        if (is_end_tensor):
            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)

        if base is None:
            yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
        else:
            yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options)

    yield SampleInput(1, args=(3, 1, 2.))


def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)

    # Creates additional inputs to test the rtol, atol, and equal_nan params
    rtols = [0., 1e-7]
    atols = [0., 1e-7]
    equal_nans = [False, True]

    products = product(rtols, atols, equal_nans)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for rtol, atol, equal_nan in products:
        lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
        rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)

        yield SampleInput(lhs, args=(rhs,),
                          kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))


def error_inputs_isclose(op, device, **kwargs):
    make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)

    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='rtol must be greater than or equal to zero')

    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='atol must be greater than or equal to zero')


def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg((1, 2)))
    yield SampleInput(make_arg((2,)))
    yield SampleInput(make_arg(()))


def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)

    first_shape, second_shape = (S, M), (M, S)

    yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),))

    if dtype.is_complex:
        yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),))

    # Matmul of empty matrices
    yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),))
    yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),))


def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
    beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2)
    tests_list = [
        ((2, 3), (2, 2), (2, 3), False),
        ((3, 3), (3, 3), (3, 3), False),
    ]
    tests_with_lhs_broadcasting = [
        ((1,), (2, 2), (2, 3), True),
        ((), (2, 2), (2, 3), True),
    ]
    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]

    kwargs = dict(alpha=alpha_val, beta=beta_val)
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
        yield SampleInput(
            make_arg(shape_a),
            make_arg(shape_b),
            make_arg(shape_c),
            **kwargs,
        ).with_metadata(broadcasts_input=broadcasts_input)

    if dtype.is_complex:
        shape = (3, 3)
        yield SampleInput(
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            make_arg(shape),
            **kwargs,
        )
        yield SampleInput(
            make_arg(shape),
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            **kwargs,
        )
    # addmm of empty matrices
    if dtype.is_floating_point:
        yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs)
        # empty matmul with broadcastable input
        yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True)
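# Editorial note (comment only): torch.addmm(input, mat1, mat2, beta=b, alpha=a)
# computes b * input + a * (mat1 @ mat2), which is why the cases above with 1-D or
# scalar `input` shapes are tagged broadcasts_input=True.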

def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha = 2 + 3j if dtype.is_complex else 0.6
    beta = 1 + 2j if dtype.is_complex else 0.2
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C
    for m, n, k in itertools.product([0, 5], repeat=3):
        yield SampleInput(
            torch.eye(m, n, device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            make_arg((k, n)),
            alpha=alpha,
            beta=beta,
        )

def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    reductions = ["sum", "mean", "amax", "amin"]
    for m, k, reduce in product([5, 7], [3, 11], reductions):
        yield SampleInput(
            torch.eye(m, m)
            .to(device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            reduce,
        )


def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, M), make_arg(M))

def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(M, S, M), make_arg(M, M, S))

def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)

    yield SampleInput(make_arg((S, )), make_arg((S, )))
    if dtype.is_complex:
        # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor)
        # is tested in test_conj_view (which tests operations with only conjugated input tensor
        # -- not conjugated arg tensors)
        yield SampleInput(make_arg((S, )), make_arg_conj((S, )))


def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)

    if not is_ref:
        yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)),
                         error_regex='dot : expected both vectors to have same dtype')
    yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)),
                     error_regex='1D tensors expected')
    yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)),
                     error_regex='inconsistent tensor size')
    if device != "cpu" and not is_ref:
        yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)),
                         error_regex='Expected all tensors to be on the same device')


def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    test_cases = (((S,), (S, M), (M,), 1, 1, False),
                  ((S,), (S, M), (M,), 0.2, 0.6, False),
                  )

    test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True),
                                 ((1,), (S, M), (M,), 0.2, 0.6, True),
                                 ((), (S, M), (M,), 1, 1, True),
                                 ((), (S, M), (M,), 0.2, 0.6, True),
                                 )

    cases = test_cases + test_cases_with_broadcast

    # addmv performs: beta * M + alpha * (mat @ vec)
    for size, mat, vec, beta, alpha, broadcasts_input in cases:
        yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input)

def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting
    test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]

    for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases:
        if dtype.is_complex:
            beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j)
            yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                              kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting)
        yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting)

def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    test_cases = [(((S, S), (S, S), (S, S)), False),
                  (((S, S), (S, 1), (1, S)), False),
                  (((1,), (S, S, 1), (1, S)), True),
                  (((), (), ()), False),
                  (((S, S), (), ()), True),
                  (((), (S, S, 1), (1, S)), True)
                  ]

    for input_args, broadcasts_input in test_cases:
        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input)

        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(
            *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3
        ).with_metadata(broadcasts_input=broadcasts_input)
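# Editorial note (comment only): addcmul(input, t1, t2, value=v) computes
# input + v * t1 * t2 and addcdiv computes input + v * t1 / t2, hence the
# exclude_zero above keeps addcdiv denominators away from zero.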

def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_addcmul_addcdiv(
        op_info, device, dtype, requires_grad, **kwargs)

    # type promotion cases
    supported_dtypes = op_info.supported_dtypes(device)
    make_arg = partial(make_tensor, device=device, requires_grad=requires_grad)

    types = (
        (torch.float64, torch.complex128),
        (torch.bfloat16, torch.float32),
    )

    values = (
        None,
        True, False,
        3.14, 3,
        1.0, 1,
        0.0, 0,
        -3.14, -3,
        3.14 + 2.71j,
    )

    for (type2, type3), value in product(types, values):
        if (type2 not in supported_dtypes or
                type3 not in supported_dtypes):
            continue

        # RuntimeError: value cannot be converted without overflow
        if (type(value) is complex and
                type2 is not torch.complex128):
            continue

        arg1 = make_arg([5, 5], dtype=dtype)
        arg2 = make_arg([5, 5], dtype=type2)
        arg3 = make_arg([1, 5], dtype=type3)

        # TypeError: addcdiv(): argument 'value' must be Number, not NoneType
        if value is not None:
            yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value))
        else:
            yield SampleInput(arg1, args=(arg2, arg3))

def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases:
        yield SampleInput(
            make_arg(input_shape),
            make_arg(batch1_shape),
            make_arg(batch2_shape),
            beta=beta,
            alpha=alpha
        ).with_metadata(broadcasts_input=broadcasts_input)

        if dtype.is_complex:
            yield SampleInput(
                make_arg(input_shape),
                make_arg(batch1_shape),
                make_arg(batch2_shape),
                beta=beta * (1 + 2j),
                alpha=alpha * (2 + 3j),
            ).with_metadata(broadcasts_input=broadcasts_input)

    if dtype.is_complex:
        shapes = [(S, S, S), (S, M, S), (S, S, M)]
        args = tuple(make_arg(s) for s in shapes)
        yield SampleInput(
            args[0].transpose_(-1, 1),
            args[1].transpose(-1, 1).conj().requires_grad_(requires_grad),
            args[2].transpose(-1, 1).conj().requires_grad_(requires_grad),
            beta=beta * (1 + 2j),
            alpha=alpha * (2 + 3j),
        )
1420
1421# TODO: add reduction kwargs
1422def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
1423    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1424
1425    shapes = (
1426        (S,),
1427        (S, S),
1428    )
1429
1430    for shape in shapes:
1431        # Produce one with weight and one without.
1432        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={})
1433        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),),
1434                          kwargs={'weight': _make_tensor(shape, requires_grad=False)})
1435
1436def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
1437    make_arg = partial(
1438        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None
1439    )
1440    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M))
1441
1442    yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True)
1443
1444    if dtype.is_complex:
1445        alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
1446    elif dtype.is_floating_point:
1447        alpha, beta = 0.2, 0.6
1448    else:
1449        alpha, beta = 2, 3
1450
1451    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha)
1452
1453    yield SampleInput(
1454        make_arg(),
1455        make_arg(S),
1456        make_arg(M),
1457        beta=beta,
1458        alpha=alpha,
1459    ).with_metadata(broadcasts_input=True)
1460
1461    # These samples fail gradcheck
1462    if dtype.is_floating_point and not requires_grad:
1463        tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad)
1464        yield SampleInput(
1465            torch.tensor([[math.nan]], **tensor_options),
1466            torch.tensor([0.0], **tensor_options),
1467            torch.tensor([0.0], **tensor_options),
1468            beta=0.0,
1469            alpha=0.0,
1470        ).with_metadata(broadcasts_input=True)
1471
1472        yield SampleInput(
1473            torch.tensor([[0.0]], **tensor_options),
1474            torch.tensor([math.nan], **tensor_options),
1475            torch.tensor([math.nan], **tensor_options),
1476            beta=0.0,
1477            alpha=0.0,
1478        ).with_metadata(broadcasts_input=True)
1479
1480def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
1481    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1482
1483    cases = ((), (S, S, S), (S,))
1484
1485    for shape in cases:
1486        yield SampleInput(make_arg(shape))
1487
1488def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
1489    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1490    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
1491    make_weight = partial(_make_tensor, requires_grad=False)
1492
1493    inputs = (
1494        ((), make_target([], low=0, high=1), {}),
1495        ((S,), make_target([], low=0, high=S), {"p": 1}),
1496        ((S,), make_target([1], low=0, high=S), {"p": 2}),
1497        ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
1498        ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
1499        ((M, S), make_target([M], low=0, high=S), {"weight": None}),
1500        ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
1501        ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
1502        ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
1503        ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
1504    )
1505
1506    for input_shape, target, kwargs in inputs:
1507        yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
1508
1509
1510def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
1511    yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
1512    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1513    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
1514    make_weight = partial(_make_tensor, requires_grad=False)
1515
1516    inputs = (
1517        ((), make_target([], low=0, high=1)),
1518        ((S,), make_target([], low=0, high=S)),
1519        ((S,), make_target([1], low=0, high=S)),
1520        ((M, S), make_target([M], low=0, high=S)),
1521    )
1522    ps = (1, 2)
1523    margins = (0, 7, -3.14)
1524    weights = (False, True)
1525    reductions = (None, "none", "mean", "sum")
1526
1527    for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
1528        input = _make_tensor(input_shape)
1529        weight_shape = [input.size(-1)] if input.ndim > 0 else [1]
1530        weight = make_weight(weight_shape, low=-10., high=10.) if weight else None
1531        kwargs = {"p": p, "margin": margin, "weight": weight}
1532        if reduction is not None:
1533            kwargs["reduction"] = reduction
1534        yield SampleInput(input, args=(target,), kwargs=kwargs)
1535
1536
1537def error_inputs_multi_margin_loss(op, device, **kwargs):
1538    make_input = partial(make_tensor, device=device, dtype=torch.float32)
1539    # invalid reduction
1540    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}),
1541                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
1542    # invalid input
1543    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}),
1544                     error_type=RuntimeError,
1545                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
1546    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}),
1547                     error_type=RuntimeError,
1548                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
1549    # invalid target
1550    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}),
1551                     error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]')
1552    # invalid target dtype
1553    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
1554                     error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
1555    # invalid weight
1556    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
1557                     error_type=ValueError, error_regex='weight must be one-dimensional')
1558    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
1559                     error_type=ValueError, error_regex='weight must be one-dimensional')
1560    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
1561                     error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
1562    # invalid p
1563    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
1564                     error_type=ValueError, error_regex='only p == 1 and p == 2 supported')
1565
1566
1567def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
1568    inputs = (
1569        ((), (0,), True),
1570        ((S, S), (1,), True),
1571        ((S, S), (1,), False),
1572        ((S, S), (-2,), False),
1573        ((S, S), (0, 1), False),
1574    )
1575    # Test large inputs to check numerical stability
1576    # Test large inputs to check numerical stability (see the illustrative sketch after this function)
1577    for low in lows:
1578        high = low * 2 if low is not None else None
1579        for shape, dim, keepdim in inputs:
1580            t = make_tensor(shape, dtype=dtype, device=device,
1581                            low=low, high=high,
1582                            requires_grad=requires_grad)
1583            yield SampleInput(t, dim, keepdim)
1584
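# Illustrative sketch (not part of the OpInfo samples): shows why the large-value
# samples above matter. With entries around 1e3, a naive log(sum(exp(x))) overflows
# to inf in float32, while torch.logsumexp stays finite (about 1001.7 here) because
# it uses the standard max-subtraction trick.
def _logsumexp_stability_sketch():
    x = torch.tensor([1000.0, 1000.5, 1001.0])
    naive = torch.log(torch.exp(x).sum())   # exp(1000.) overflows to inf
    stable = torch.logsumexp(x, dim=0)      # finite
    return naive, stable
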
1585def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs):
1586    yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs)
1587
1588    # https://github.com/pytorch/pytorch/issues/91843
1589    t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad)
1590    yield SampleInput(t, 0, False)
1591
1592    t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad)
1593    yield SampleInput(t, 0, False)
1594
1595    # tests masking
1596    # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073
1597    t = torch.tensor(float("inf"))
1598    yield SampleInput(t, 0, True)
1599
1600def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
1601    inputs = [
1602        ((), {}),
1603        ((S, S), {}),
1604        ((0, S, 0), {}),
1605        ((S,), {'dtype': dtype, 'device': device}),
1606        # Hard-code some dtypes/devices. We want to test cases where the
1607        # (dtype, device) is different from the input's (dtype, device)
1608        ((S,), {'dtype': torch.double}),
1609        ((S,), {'device': 'cpu'}),
1610        ((S,), {'dtype': torch.double, 'device': 'cpu'}),
1611    ]
1612    if torch.cuda.is_available():
1613        inputs.append(((S,), {'device': 'cuda'}))
1614
1615    for shape, kwargs in inputs:
1616        t = make_tensor(shape, dtype=dtype, device=device,
1617                        low=None, high=None,
1618                        requires_grad=requires_grad)
1619        yield SampleInput(t, **kwargs)
1620
1621def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs):
1622    yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs)
1623
1624    # shape
1625    cases = (
1626        (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1)
1627    )
1628
1629    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
1630    for shape in cases:
1631        yield SampleInput(make_arg(shape))
1632        yield SampleInput(make_arg(shape).transpose(0, -1))
1633        yield SampleInput(make_arg(shape, noncontiguous=True))
1634        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))
1635
1636def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
1637    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1638    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
1639
1640    inputs = (
1641        ([], make_target([], low=0, high=1), {}),
1642        ([S], make_target([S], low=0, high=S), {}),
1643        ([M, S], make_target([M, S], low=0, high=S), {}),
1644        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}),
1645        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}),
1646        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}),
1647    )
1648
1649    for shape, target, kwargs in inputs:
1650        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
1651
1652
1653def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
1654    yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
1655    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1656    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
1657    make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False)
1658
1659    inputs = (
1660        # random tests including -1 target labels
1661        ([], make_target([], low=-1, high=1)),
1662        ([S], make_target([S], low=-1, high=S)),
1663        ([M, S], make_target([M, S], low=-1, high=S)),
1664        # repeated target labels and -1 (labels after the first -1 are ignored)
1665        ([], make_target_tensor(-1)),
1666        ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])),
1667        ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])),
1668    )
1669    reductions = (None, "none", "mean", "sum")
1670
1671    for (shape, target), reduction in product(inputs, reductions):
1672        kwargs = {}
1673        if reduction is not None:
1674            kwargs["reduction"] = reduction
1675        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
1676
1677
1678def error_inputs_multilabel_margin_loss(op, device, **kwargs):
1679    make_input = partial(make_tensor, device=device, dtype=torch.float32)
1680    # invalid reduction
1681    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
1682                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
1683    # invalid input
1684    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}),
1685                     error_type=RuntimeError,
1686                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
1687    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}),
1688                     error_type=RuntimeError,
1689                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
1690    # invalid target
1691    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}),
1692                     error_type=RuntimeError,
1693                     error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]')
1694    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}),
1695                     error_type=RuntimeError,
1696                     error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]')
1697
1698
1699def get_independent_tensor(tensor):
1700    return tensor.clone().requires_grad_(tensor.requires_grad)
1701
1702def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs):
1703    low = 2
1704    high = 10
1705
1706    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
1707        sample.kwargs.setdefault('device', device)
1708        # With high
1709        yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs)
1710        # With low and high
1711        yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs)
1712
1713def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
1714    low = 2
1715    high = 10
1716
1717    for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
1718        # With high
1719        yield SampleInput(
1720            sample.input,
1721            high,
1722            *sample.args,
1723            **sample.kwargs)
1724        # With low and high
1725        yield SampleInput(
1726            get_independent_tensor(sample.input),
1727            low,
1728            high,
1729            *sample.args,
1730            **sample.kwargs)
1731
1732def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs):
1733    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1734
1735    shapes = (
1736        (),
1737        (S,),
1738        (S, S),
1739        (S, S, S),
1740    )
1741
1742    margins = (0., 1.)
1743    reductions = ('sum', 'mean', 'none')
1744
1745    for shape in shapes:
1746        for margin, reduction in product(margins, reductions):
1747            kwargs = {'margin': margin, 'reduction': reduction}
1748            yield SampleInput(_make_tensor(shape),
1749                              args=(_make_tensor(shape, requires_grad=False),
1750                                    _make_tensor(shape, requires_grad=False)),
1751                              kwargs=kwargs)
1752
1753def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs):
1754    yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs)
1755    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
1756
1757    for reduction in ('sum', 'mean', 'none'):
1758        if dtype.is_floating_point:  # the op only supports ints and floats; the NaN/inf cases only apply to floats
1759            # NaN propagation
1760            inp1 = make_input((10, ))
1761            inp1[2] = float('nan')
1762            inp2 = make_input((10, ))
1763            inp2[4] = float('nan')
1764            target = make_input((10, ))
1765            target[9] = float('nan')
1766            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
1767
1768            # Inf handling
1769            inp1 = make_input((10, ))
1770            inp1[1] = float('inf')
1771            inp2 = make_input((10, ))
1772            inp2[4] = float('inf')
1773            target = make_input((10, ))
1774            target[7] = float('inf')
1775            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
1776
1777        # Broadcasting
1778        inp1 = make_input((5, 2))
1779        inp2 = make_input((5, 1))
1780        target = make_input((1, 2))
1781        yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})
1782
1783def error_inputs_margin_ranking_loss(op, device, **kwargs):
1784    make_input = partial(make_tensor, device=device, dtype=torch.float32)
1785    # invalid reduction value.
1786    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}),
1787                     error_type=ValueError, error_regex='is not a valid value')
1788    # invalid input shapes
1789    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)),
1790                     error_regex='margin_ranking_loss : All input tensors should')
1791
1792def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs):
1793    # input_shape, output_shape, strides, kwargs
1794    # lengths of output_shape and strides must be equal
1795    inputs = [
1796        ((), (), (), {}),
1797        ((S, S), (2, 0), (3, 4), {}),
1798        ((0, S, 0), (3, 2, 2), (1, 2, 3), {}),
1799        ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}),
1800        # Hard-code some dtypes/devices. We want to test cases where the
1801        # (dtype, device) is different from the input's (dtype, device)
1802        ((S,), (10,), (S,), {'dtype': torch.double}),
1803        ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}),
1804        ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}),
1805    ]
1806    if torch.cuda.is_available():
1807        inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'}))
1808
1809    for input_shape, output_shape, strides, kwargs in inputs:
1810        t = make_tensor(input_shape, dtype=dtype, device=device,
1811                        low=None, high=None,
1812                        requires_grad=requires_grad)
1813        if is_strided:
1814            yield SampleInput(t, output_shape, strides, **kwargs)
1815        else:
1816            yield SampleInput(t, output_shape, **kwargs)
1817
1818def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs):
1819
1820    inputs = [
1821        ((), (), {'dtype': dtype, 'device': device}),
1822        ((S,), (4,), {'dtype': dtype, 'device': device}),
1823        ((S, S), (2, 1), {'dtype': dtype, 'device': device}),
1824        ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}),
1825    ]
1826
1827    for shape, strides, kwargs in inputs:
1828        yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs)
1829
1830def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
1831    # shape
1832    cases = (
1833        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
1834    )
1835
1836    for case in cases:
1837        yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)
1838
1839def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs):
1840    # shape
1841    cases = (
1842        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
1843    )
1844
1845    for case in cases:
1846        for layout in itertools.permutations(range(len(case))):
1847            yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad)
1848
1849def error_inputs_empty_permuted(op_info, device, **kwargs):
1850    yield ErrorInput(
1851        SampleInput((2,), args=((0, 1),)),
1852        error_type=RuntimeError,
1853        error_regex="Number of dimensions in size does not match the length of the physical_layout"
1854    )
1855    yield ErrorInput(
1856        SampleInput((2,), args=((3,),)),
1857        error_type=RuntimeError,
1858        error_regex="Dimension out of range"
1859    )
1860    yield ErrorInput(
1861        SampleInput((2, 3), args=((0, 0),)),
1862        error_type=RuntimeError,
1863        error_regex="Duplicate dim not allowed"
1864    )
1865
1866def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
1867    # Not including a scalar tensor in vals because meta tests start failing due to
1868    # lack of meta support for _local_scalar_dense
1869    # torch.tensor(2, device=device)
1870    vals = (-5, 0, 1)
1871
1872    for item in vals:
1873        yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad)
1874
1875def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs):
1876    # both n and m must be ints >= 0; m may be omitted
1877    sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S)
1878
1879    for n, m in product(sizes, sizes):
1880        if n is None:
1881            continue
1882
1883        # TODO: no layout
1884        _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad}
1885        if m is None:
1886            yield SampleInput(n, args=(), kwargs=_kwargs)
1887        else:
1888            yield SampleInput(n, args=(m,), kwargs=_kwargs)
1889
1890def error_inputs_eye(op_info, device, **kwargs):
1891    # TODO: no layout
1892    _kwargs = {'device': device, 'dtype': torch.float32}
1893
1894    yield ErrorInput(
1895        SampleInput(-1, args=(), kwargs=_kwargs),
1896        error_regex="n must be greater or equal to 0, got -1"
1897    )
1898
1899    yield ErrorInput(
1900        SampleInput(-7, args=(42,), kwargs=_kwargs),
1901        error_regex="n must be greater or equal to 0, got -7"
1902    )
1903
1904    yield ErrorInput(
1905        SampleInput(0, args=(-3,), kwargs=_kwargs),
1906        error_regex="m must be greater or equal to 0, got -3"
1907    )
1908
1909
1910def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
1911    def get_val(dtype):
1912        return make_tensor([], dtype=dtype, device="cpu").item()
1913
1914    for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
1915        # The scalar we are passing to new_full must have the same dtype
1916        # as the resulting tensor
1917        use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
1918        yield SampleInput(
1919            sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
1920
1921def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
1922    def get_val(dtype):
1923        return make_tensor([], dtype=dtype, device="cpu").item()
1924
1925    inputs = [
1926        ((), get_val(dtype), {}),
1927        ((S, S), get_val(dtype), {}),
1928        ((0, S, 0), get_val(dtype), {}),
1929        ((S,), get_val(dtype), {'dtype': dtype, 'device': device}),
1930        # Hard-code some dtypes/devices. We want to test cases where the
1931        # (dtype, device) is different from the input's (dtype, device)
1932        ((S,), get_val(torch.double), {'dtype': torch.double}),
1933        ((S,), get_val(dtype), {'device': 'cpu'}),
1934        ((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}),
1935    ]
1936    if torch.cuda.is_available():
1937        inputs.append(((S,), get_val(dtype), {'device': 'cuda'}))
1938
1939    for shape, fill_value, kwargs in inputs:
1940        t = make_tensor(shape, dtype=dtype, device=device,
1941                        low=None, high=None,
1942                        requires_grad=requires_grad)
1943        yield SampleInput(t, fill_value, **kwargs)
1944
1945def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs):
1946    cases = [
1947        ([3], 3, {}),
1948        ([10], 3, {}),
1949        ([3, 10], 3, {}),
1950        ([3], 3, dict(replacement=False)),
1951        ([3], 3, dict(replacement=True)),
1952        ([3, 4], 4, dict(replacement=True)),
1953        ([3, 4], 4, dict(replacement=False)),
1954    ]
1955
1956    for shape, num_samples, kwargs in cases:
1957        t = make_tensor(shape, dtype=dtype, device=device,
1958                        low=0, high=None,
1959                        requires_grad=requires_grad)
1960        yield SampleInput(t, num_samples, **kwargs)
1961
1962def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs):
1963    def get_value_or_make_tensor(value_or_shape):
1964        if isinstance(value_or_shape, list):
1965            return make_tensor(value_or_shape, dtype=dtype, device=device,
1966                               low=0, high=None,
1967                               requires_grad=requires_grad)
1968        return value_or_shape
1969
1970    for value_or_mean_shape, value_or_std_shape, kwargs in cases:
1971        mean = get_value_or_make_tensor(value_or_mean_shape)
1972        std = get_value_or_make_tensor(value_or_std_shape)
1973        yield SampleInput(mean, std, **kwargs)
1974
1975def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs):
1976    # value_or_size, value_or_size, kwargs
1977    cases = [
1978        ([], [], {}),
1979        ([3], [3], {}),
1980        ([3, 4, 2], [3, 4, 2], {}),
1981        ([2, 3], 1.1, {}),
1982        ([1, 2, 3], [5, 2, 3], {}),  # broadcasting
1983    ]
1984
1985    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
1986
1987def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
1988    yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, device=device)
1989    yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device)
1990    yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad))
1991
1992def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
1993    shapes = [
1994        [3],
1995        [],
1996        [0, 3],
1997        [2, 3, 4],
1998    ]
1999
2000    for shape in shapes:
2001        t = make_tensor(shape, dtype=dtype, device=device,
2002                        low=0, high=1,
2003                        requires_grad=requires_grad)
2004        yield SampleInput(t)
2005
2006def error_inputs_bernoulli(op_info, device, **kwargs):
2007    # more than one element of the written-to tensor refers to a single memory location
2008    x = torch.rand((1,), device=device).expand((6,))
2009    err_msg = 'unsupported operation'
2010    yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}),
2011                     error_regex=err_msg)
2012
2013def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs):
2014    inputs = (
2015        ((S, S, S), 0),
2016        ((S, S, S), 1),
2017        ((), 0),
2018    )
2019
2020    for large_number in (True, False):
2021        for shape, dim in inputs:
2022            t = make_tensor(shape, dtype=dtype, device=device,
2023                            low=None, high=None,
2024                            requires_grad=requires_grad)
2025
2026            if large_number and t.dim() > 0:
2027                t[0] = 10000
2028            yield SampleInput(t, dim)
2029
2030def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs):
2031    yield SampleInput(
2032        make_tensor((S, S), dtype=dtype, device=device,
2033                    low=None, high=None,
2034                    requires_grad=requires_grad))
2035
2036
2037def error_inputs_trace(op, device):
2038    yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix")
2039
2040
2041def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs):
2042    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2043    cases = (((S, S, S), (2, 1, 0.5)),
2044             ((S, S, S), (2, -1, 0.5)),
2045             ((S, S, S), (1, 2, 3)),
2046             ((S, S, S), (float('inf'), 2, 0.5)),
2047             )
2048
2049    for shape, args in cases:
2050        yield SampleInput(make_arg(shape), args=args)
2051
2052
2053def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs):
2054    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2055
2056    cases = (((1, 2, 3), (-1, -2)),
2057             ((1, 2, 3), (-1, 2)),
2058             ((1, 2, 3), (1, -2)),
2059             ((1, 2, 3), (1, 2)),
2060             ((), (0, 0)),
2061             ((1, ), (0, 0)),
2062             ((M, M), (0, 1)),
2063             ((S, S, S), (2, 0)), )
2064
2065    for shape, args in cases:
2066        yield SampleInput(make_arg(shape), args=args)
2067
2068def _numpy_ref_transpose(a, dim0, dim1):
2069    if a.ndim <= 1:
2070        return a
2071
2072    return np.swapaxes(a, dim0, dim1)
2073
2074def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs):
2075    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2076
2077    shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S))
2078    return (SampleInput(make_arg(shape)) for shape in shapes)
2079
2080def sample_inputs_T(self, device, dtype, requires_grad, **kwargs):
2081    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2082
2083    shapes = ((M, M), (M, L))
2084    return (SampleInput(make_arg(shape)) for shape in shapes)
2085
2086def error_inputs_T(self, device, has_ndims_error=False):
2087    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
2088
2089    # Deprecated behavior in regular PyTorch, but throws an error in primTorch:
2090    # https://github.com/pytorch/pytorch/issues/86968
2091    if has_ndims_error:
2092        # ndims == 1
2093        yield ErrorInput(SampleInput(make_arg(M)),
2094                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
2095                                      r'to reverse their shape is not supported\.'))
2096
2097        # ndims > 2
2098        yield ErrorInput(SampleInput(make_arg(M, S, L)),
2099                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
2100                                      r'to reverse their shape is not supported\.'))
2101
2102
2103def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False):
2104    """
2105    This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n).
2106    Their matrix product can be used to generate a tensor of shape (*, m, n) of rank k; see the sketch after this function.
2107    """
2108
2109    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2110    batches = [(), (2,)]
2111    size = [3, 4]
2112    for batch, m, n in product(batches, size, size):
2113        k = 2
2114        a = make_arg((*batch, m, k))
2115        b = make_arg((*batch, n, k))
2116        yield a, b
2117
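# Illustrative sketch (not used by the tests): multiplying the two factors yielded
# above produces a (*, m, n) matrix whose rank is (almost surely) k, as the docstring
# describes. op_info is unused by the generator, so None is passed here.
def _singular_matrix_factors_rank_sketch(device='cpu', dtype=torch.float64):
    a, b = next(sample_inputs_singular_matrix_factors(None, device, dtype))
    low_rank = a @ b.mH    # shape (m, n) for the first (batch-free) sample
    return torch.linalg.matrix_rank(low_rank).item(), a.shape[-1]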
2118
2119def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
2120    # Function that's well defined on the outputs for complex inputs
2121    def fn(usv):
2122        U, S, V = usv
2123        return U @ V.mH, S
2124
2125    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
2126        *batch, m, k = a.shape
2127        n = b.shape[-2]
2128
2129        # NOTE: since svd_lowrank relies on a non-rank-revealing SVD,
2130        # it inherits the problem of unstable behavior with repeated
2131        # singular values, including zeros.
2132        # Since we want to avoid (repeated) zeros as singular values,
2133        # we can only use k for q.
2134        # These issues could be resolved by using a rank-revealing SVD,
2135        # which does not include "zero" singular values.
2136        yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn)
2137
2138    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
2139        *batch, m, k = a.shape
2140        n = b.shape[-2]
2141        M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad)
2142        yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn)
2143
2144def chunk_iter(iterable, size):
2145    it = iter(iterable)
2146    while True:
2147        chunk = tuple(islice(it, size))
2148        if not chunk:
2149            break
2150        yield chunk
2151
2152def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
2153    # we reuse samples from svd_lowrank which come in groups of two with
2154    # kwarg['M'] = None and with kwarg['M'] = <some tensor>
2155    samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs)
2156    for s1, s2 in chunk_iter(samples, 2):
2157        del s1.kwargs['M']
2158        del s2.kwargs['M']
2159        s1.kwargs['center'] = False
2160        s2.kwargs['center'] = True
2161        yield s1
2162        yield s2
2163
2164def np_sinc_with_fp16_as_fp32(x):
2165    # Wraps numpy's sinc function so that fp16 values are promoted to fp32
2166    # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated
2167    # at 0 for fp16.
2168    if x.dtype == np.float16:
2169        return np.sinc(x.astype(np.float32))
2170    else:
2171        return np.sinc(x)
2172
2173def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs):
2174    test_cases = (
2175        ((S, 1, 1), (S, S, S)),
2176        ((S, 1, S), (S, S, S)),
2177        ((S, 1), (S, S, S)),
2178        ((1,), (S, S, S)),
2179        ((1, S), (1, 1, S)),
2180        ((), ()),
2181        ((), (1, 3, 2)),
2182    )
2183
2184    return (
2185        SampleInput(
2186            make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad),
2187            shape,
2188        ) for size, shape in test_cases)
2189
2190def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs):
2191    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2192    test_cases: Tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),)
2193
2194    for shape, *other_shapes in test_cases:
2195        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
2196
2197def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs):
2198    yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs)
2199
2200    m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2201    n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
2202
2203    cases = (
2204        ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)),
2205        ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6))
2206    )
2207
2208    for a, b, c, d in cases:
2209        yield SampleInput(m(a), args=(m(b), m(c), m(d)))
2210        yield SampleInput(n(a), args=(n(b), n(c), n(d)))
2211
2212def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs):
2213    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2214    test_cases: Tuple[tuple] = (
2215        ((1, S), (2, S), (3, S),),
2216        ((S, 1), (S, 2), (S, 3),),
2217        ((1,), (2,), (3,),),
2218        ((2, S), (S,))
2219    )
2220
2221    for shape, *other_shapes in test_cases:
2222        yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes))
2223        # We also want to test mixed complex-non-complex inputs to block_diag
2224        if dtype == torch.complex32 or dtype == torch.complex64:
2225            non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64
2226            make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad)
2227            yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes))
2228
2229def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs):
2230    small_S = 2
2231    test_cases = (
2232        ((S, S, 2), (S, S + 1, 2)),
2233        ((S, S), (S, S)),
2234        ((S, S, S), (S, S, S)),
2235        ((3, 5), (3, 5)),
2236        ((2, 3, 5), (2, 3, 5)),
2237        ((1, 2, 3), (1, 2, 3)),
2238        ((1, 1), (S, 1)),
2239        ((0, 5), (4, 5)),
2240        ((4, 5), (0, 5)),
2241        ((0, 4, 5), (3, 5)),
2242        ((4, 5), (0, 3, 5)),
2243        ((0, 4, 5), (1, 3, 5)),
2244        ((1, 4, 5), (0, 3, 5)),
2245        # Using S here would make this one test take 9s
2246        ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)),
2247        ((small_S, 1, 1, small_S), (1, small_S, small_S)),
2248        ((1, 1, small_S), (small_S, 1, small_S, small_S)),
2249    )
2250
2251    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2252    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
2253        # FIXME add an override for JIT and revert 0. back to 0
2254        # since it's accepted by eager
2255        for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]:
2256            for t1_size, t2_size in test_cases:
2257                # The args should never be non-contiguous as this is not supported in the backward
2258                yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm)
2259
2260def _fill_np(a, value):
2261    a = a.copy()
2262    a.fill(value)
2263    return a
2264
2265def _fill_sample_kwargs(device, dtype, input):
2266    if dtype is torch.bool:
2267        value = True
2268    else:
2269        value = 3
2270
2271    return ({'value': value}, {'value': value})
2272
2273def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs):
2274    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)
2275
2276    # Adds a sample input where both tensors have the same values
2277    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2278
2279    lhs = make_arg((S, S))
2280    yield SampleInput(lhs, args=(lhs.clone(),))
2281
2282def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs):
2283    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2284
2285    # shape x number of tensors
2286    cases = (
2287        ((3, 4), 1),
2288        ((1, 2, 1, 4), 3),
2289        ((0, 1, 0), 2),)
2290
2291    for shape, num_tensors in cases:
2292        tensors = []
2293        for _ in range(num_tensors):
2294            tensors.append(make_arg(shape))
2295        for dim in range(-1, len(shape) - 1):
2296            yield SampleInput(tensors, args=(dim,))
2297
2298
2299def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
2300    # 1. If the input tensors have different ndims, dim must be non-negative and less than the ndim of every input tensor.
2301    #    If all input tensors have the same ndim, both negative and non-negative dim values are supported.
2302    # 2. For wrapped_dim, all tensors must have the same size for dimensions 0, ..., wrapped_dim - 1.
2303    #    There is no requirement on the remaining dimensions (wrapped_dim, ...).
2304    # 3. num_chunks must be positive.
2305    # 4. The input tensor list must be non-empty, and every tensor must have at least one element. (See the sketch after this function.)
2306    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2307    same_ndim_cases = (
2308        (
2309            [
2310                torch.Size([1, 2, 3]),
2311                torch.Size([1, 2, 3]),
2312            ], -1, 5
2313        ),
2314        (
2315            [
2316                torch.Size([1, 2, 129]),
2317                torch.Size([1, 2, 297]),
2318            ], -1, 5
2319        ),
2320        (
2321            [
2322                torch.Size([1, 2, 3]),
2323                torch.Size([1, 2, 3]),
2324            ], 1, 5
2325        ),
2326        (
2327            [
2328                torch.Size([3, 3, 2, 1]),
2329                torch.Size([1, 4, 2, 2]),
2330                torch.Size([2, 1, 3, 3]),
2331            ], 0, 2
2332        ),
2333    )
2334    for sizes, dim, num_chunks in same_ndim_cases:
2335        tensors = []
2336        for size in sizes:
2337            tensors.append(make_arg(size))
2338        yield SampleInput(tensors, args=(dim, num_chunks))
2339
2340    different_ndim_case = [
2341        torch.Size([2, 3, 3]),
2342        torch.Size([2, 3, 1, 2]),
2343        torch.Size([2, 3]),
2344        torch.Size([2, 3, 2]),
2345        torch.Size([2, 3, 271]),
2346    ]
2347    max_dim, num_chunks = 2, 3
2348    for dim in range(max_dim):
2349        tensors = []
2350        for size in different_ndim_case:
2351            tensors.append(make_arg(size))
2352        yield SampleInput(tensors, args=(dim, num_chunks))
2353
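# Minimal illustrative sketch of the constraints listed in sample_inputs_chunk_cat.
# It assumes the private torch._chunk_cat op takes (tensors, dim, num_chunks)
# positionally, exactly as the SampleInputs above pass them; the shapes are arbitrary.
def _chunk_cat_constraints_sketch(device='cpu'):
    # Same ndim for every input, so a negative dim is allowed; the leading dims
    # (0, ..., wrapped_dim - 1) match, and num_chunks is positive.
    tensors = [torch.randn(2, 3, 5, device=device), torch.randn(2, 3, 7, device=device)]
    return torch._chunk_cat(tensors, -1, 2).shape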
2354
2355def error_inputs_chunk_cat(op_info, device, **kwargs):
2356    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
2357
2358    # input tensors have different ndims but dim is negative
2359    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3
2360    tensors = [make_arg(size) for size in sizes]
2361    yield ErrorInput(
2362        SampleInput(tensors, args=(dim, num_chunks)),
2363        error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims',
2364    )
2365
2366    # input tensors have different ndims but dim >= ndim of some input tensors
2367    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3
2368    tensors = [make_arg(size) for size in sizes]
2369    yield ErrorInput(
2370        SampleInput(tensors, args=(dim, num_chunks)),
2371        error_regex='_chunk_cat expects dim < ndim for all input tensors',
2372    )
2373
2374    # some tensors have different sizes for 0, ..., dim-1 dimensions.
2375    sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3
2376    tensors = [make_arg(size) for size in sizes]
2377    yield ErrorInput(
2378        SampleInput(tensors, args=(dim, num_chunks)),
2379        error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors',
2380    )
2381
2382    # negative num_chunks
2383    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1
2384    tensors = [make_arg(size) for size in sizes]
2385    yield ErrorInput(
2386        SampleInput(tensors, args=(dim, num_chunks)),
2387        error_regex='_chunk_cat expects positive num_chunks',
2388    )
2389
2390    # zero as num_chunks
2391    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0
2392    tensors = [make_arg(size) for size in sizes]
2393    yield ErrorInput(
2394        SampleInput(tensors, args=(dim, num_chunks)),
2395        error_regex='_chunk_cat expects positive num_chunks',
2396    )
2397
2398    # empty input tensor list
2399    dim, num_chunks = 0, 1
2400    yield ErrorInput(
2401        SampleInput([], args=(dim, num_chunks)),
2402        error_regex='_chunk_cat expects a non-empty input tensor list',
2403    )
2404
2405    # empty input tensor with 0 elements
2406    sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1
2407    tensors = [make_arg(size) for size in sizes]
2408    yield ErrorInput(
2409        SampleInput(tensors, args=(dim, num_chunks)),
2410        error_regex='_chunk_cat expects non-empty tensor',
2411    )
2412
2413
2414def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
2415    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2416
2417    cases: Tuple[tuple, tuple, dict] = (  # type: ignore[assignment]
2418        ((S, S), (S, S), {'dim': -1}),
2419        ((S, S), (S, S), {'dim': 1}),
2420        ((M, S), (S, S), {'dim': 0}),  # different shapes
2421        ((1, 2, 3), (1, 2, 3), {'dim': -2}),
2422        ((0,), (0,), {'dim': 0}),  # empty tensor
2423        ((0,), (S, S), {'dim': 1}),  # empty tensor with a non-empty one and dim=1 (special case for legacy_cat_wrap_dim)
2424        ((0, S), (S, S), {'dim': 0}),
2425        ((1,), (1,), {})  # dim not passed, fallback to default
2426    )
2427
2428    for input_shape1, input_shape2, kwargs in cases:
2429        yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs)
2430
2431    # from coat_lite_mini
2432    yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),)
2433
2434def error_inputs_cat(op_info, device, **kwargs):
2435
2436    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
2437
2438    # error input where more than one element of the written-to tensor refers to a single memory location
2439    yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))],
2440                                 kwargs={'out': make_arg((1, S)).expand((2 * S, S))}),
2441                     error_regex='unsupported operation')
2442
2443    # error inputs for empty tensors
2444    yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
2445                     error_regex='non-empty list of Tensors')
2446
2447    # error inputs for different sizes
2448    yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
2449                     error_regex='Sizes of tensors must match except in dimension')
2450    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}),
2451                     error_regex='Sizes of tensors must match except in dimension')
2452
2453    # error inputs for different dimensions
2454    yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
2455                     error_regex='Tensors must have same number of dimensions')
2456    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}),
2457                     error_regex='Tensors must have same number of dimensions')
2458
2459    # error inputs for same memory locations
2460    x = torch.zeros((0), device=device)
2461    y = torch.randn((4, 6), device=device)
2462
2463    err_msg = "the written-to tensor refer to a single memory location"
2464
2465    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}),
2466                     error_regex=err_msg)
2467    yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}),
2468                     error_regex=err_msg)
2469
2470    z = torch.zeros((4, 6), device=device)
2471    yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}),
2472                     error_regex=err_msg)
2473
2474    # error inputs for different devices
2475    if torch.device(device).type == 'cuda':
2476        x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32)
2477        y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32)
2478        yield ErrorInput(SampleInput((x_cuda, y_cpu)),
2479                         error_regex='Expected all tensors to be on the same device')
2480
2481    # error inputs for different input sizes for more than 2 tensors
2482    yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]),
2483                     error_regex='Tensors must have same number of dimensions')
2484
2485    yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))],
2486                                 kwargs={'dim': 1}),
2487                     error_regex='Sizes of tensors must match')
2488
2489    # error inputs for None input
2490    yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError,
2491                     error_regex='got None')
2492
2493    # error inputs for zero-dimensional tensors
2494    yield ErrorInput(SampleInput([make_arg(()), make_arg(())]),
2495                     error_regex='zero-dimensional.*cannot be concatenated')
2496
2497    # error inputs for different dtype of out tensors
2498    d = make_tensor((2, 3), device=device, dtype=torch.double)
2499    x = make_tensor((2, 3), device=device, dtype=torch.float32)
2500    yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError,
2501                     error_regex='invalid combination of arguments')
2502
2503def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs):
2504    yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs)
2505
2506    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2507
2508    # Noncontiguous type promoting tensors
2509    a = make_arg((3, 4, 2))
2510    b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double)
2511    c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2)
2512
2513    yield SampleInput((a, b, c), kwargs={'dim': 1})
2514
2515    # Special case: a 1D tensor whose only dim has length 0
2516    a = make_arg((0,))
2517    b = make_arg((3, 2, 2))
2518
2519    yield SampleInput((a, b, a))
2520    yield SampleInput((a, a, a))
2521
2522def _elementwise_type_promo_np(*args, type_promotion_kind):
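    # Mirrors torch's elementwise type promotion for the numpy reference
    # implementations below: ndarrays are converted to tensors, prims computes the
    # promoted torch dtype, and the result is mapped back to a numpy dtype.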
2523    def _maybe_torch(x):
2524        if isinstance(x, np.ndarray):
2525            return torch.from_numpy(x)
2526        return x
2527
2528    flattened = pytree.arg_tree_leaves(*args)
2529    transformed = tuple(_maybe_torch(a) for a in flattened)
2530    result_dtype, _ = prims.utils.elementwise_dtypes(
2531        *transformed,
2532        type_promotion_kind=type_promotion_kind)
2533    return torch_to_numpy_dtype_dict[result_dtype]
2534
2535def _cat_np(input_seq, dim=0):
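    # Legacy torch.cat skips empty 1-D tensors (see the legacy_cat_wrap_dim note in
    # sample_inputs_cat_concat), so they are filtered out here before delegating to
    # np.concatenate; an all-empty input falls back to an empty result with the
    # promoted dtype.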
2536    inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0))
2537
2538    if len(inputs) == 0:
2539        np_dtype = _elementwise_type_promo_np(
2540            input_seq,
2541            type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH)
2542        return np.empty(0, dtype=np_dtype)
2543
2544    return np.concatenate(inputs, axis=dim)
2545
2546def _floor_divide_np(a, b):
2547    dtype = _elementwise_type_promo_np(
2548        a,
2549        b,
2550        type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
2551    if isinstance(a, np.ndarray):
2552        a = a.astype(dtype)
2553    if isinstance(b, np.ndarray):
2554        b = b.astype(dtype)
2555    return np.floor_divide(a, b)
2556
2557def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs):
2558    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
2559    tensor_shapes = (
2560        # First Tensor being 1-D is special
2561        # case for hstack
2562        ((S,), (S,), (S,)),
2563        ((S, S), (S, S), (S, S)),
2564    )
2565    for s1, s2, s3 in tensor_shapes:
2566        tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
2567        yield SampleInput(tensors)
2568
2569def error_inputs_hstack_dstack_vstack(op, device):
2570    make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
2571    tensor_shapes = (
2572        ((S,), (S, S, S, S), (S,)),
2573    )
2574    for s1, s2, s3 in tensor_shapes:
2575        tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3))
2576        # Different dimension tensor
2577        yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions")
2578
2579    # empty tensor list
2580    yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList")
2581
2582def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs):
2583    # Note: we don't do any tests where we unbind along 0-length dims
2584    # because in that case unbind returns an empty tuple, and that breaks
2585    # some assumptions in some backward tests in test_ops.py
2586    shape_dims = (((S,), 0),
2587                  ((S, S), 0),
2588                  ((S, S), 1),
2589                  ((S, S), -1),
2590                  ((S, 0, S), 0),
2591                  ((S, S, S), 1),
2592                  )
2593    for shape, dim in shape_dims:
2594        yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
2595                                      requires_grad=requires_grad),
2596                          args=(dim,))
2597
2598def error_inputs_unbind(op_info, device):
2599    make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False)
2600    yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError,
2601                     error_regex="Dimension specified as 0 but tensor has no dimensions")
2602    yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError,
2603                     error_regex="Dimension out of range")
2604
2605def reference_unbind(t, dim):
2606    """A numpy implementation of torch.unbind"""
2607    return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim))
2608
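# Illustrative sketch (hypothetical helper, not consumed by any OpInfo entry):
# reference_unbind should match torch.unbind slice for slice on a plain CPU tensor.
def _example_check_unbind_reference():
    t = torch.arange(6).reshape(2, 3)
    for ref_s, out_s in zip(reference_unbind(t.numpy(), 1), torch.unbind(t, 1)):
        np.testing.assert_array_equal(ref_s, out_s.numpy())
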
2609def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
2610    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
2611    yield SampleInput(
2612        make_arg((M, S)),
2613        0,
2614        gather_variable((S, S), 1, M, True, device=device))
2615    yield SampleInput(
2616        make_arg((M, S)),
2617        1,
2618        gather_variable((M, S // 2), 0, S, True, device=device))
2619    # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006
2620    yield SampleInput(
2621        make_arg((S,)),
2622        0,
2623        torch.tensor([], dtype=torch.uint8, device=device))
2624    yield SampleInput(
2625        make_arg((S,)),
2626        0,
2627        torch.tensor([[], []], dtype=torch.uint8, device=device))
2628    # 0D tensor case
2629    yield SampleInput(
2630        make_arg(()),
2631        0,
2632        torch.tensor([0], dtype=torch.int64, device=device))
2633    yield SampleInput(
2634        make_arg(()),
2635        0,
2636        torch.tensor(0, dtype=torch.int64, device=device))
2637
2638def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o):
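    # Fill the 3-D index tensor `idx` so that every line along `dim` holds
    # `elems_per_row` distinct indices drawn from range(dim_size). m, n and o
    # bound the loops over the three dimensions; the loop over `dim` itself
    # collapses to a single whole-line assignment.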
2639    for i in range(1 if dim == 0 else m):
2640        for j in range(1 if dim == 1 else n):
2641            for k in range(1 if dim == 2 else o):
2642                ii = [i, j, k]
2643                ii[dim] = slice(0, idx.size(dim) + 1)
2644                idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
2645
2646def error_inputs_gather(op_info, device, **kwargs):
2647    # src is [1, 2]
2648    #        [3, 4]
2649    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
2650
2651    # idx is [0, 0]
2652    #        [1, 0]
2653    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
2654
2655    # Index should be smaller than self except on dimension 1
2656    bad_src = make_tensor((1, 1), device=device, dtype=torch.float32)
2657    yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
2658                     error_regex="Size does not match at dimension 0")
2659
2660    # Index must have long dtype
2661    bad_idx = idx.to(torch.int32)
2662    yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
2663                     error_regex="Expected dtype int64 for index")
2664
2665    # TODO: FIXME
2666    # out.dtype must match src.dtype
2667    # Creates new src & idx since SampleInputs can't share tensors
2668    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
2669    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
2670    out = torch.empty((2, 2), device=device, dtype=torch.float64)
2671    yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}),
2672                     error_regex="Expected out tensor to have dtype")
2673
2674    # src and index tensors must have the same # of dimensions
2675    # idx too few dimensions
2676    src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
2677    idx = torch.tensor((0, 0), device=device, dtype=torch.long)
2678    yield ErrorInput(SampleInput(src, args=(1, idx)),
2679                     error_regex="Index tensor must have the same number of dimensions")
2680
2681    # src too few dimensions
2682    src = torch.tensor((1, 2), device=device, dtype=torch.float32)
2683    idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
2684    yield ErrorInput(SampleInput(src, args=(0, idx)),
2685                     error_regex="Index tensor must have the same number of dimensions")
2686
2687    # index out of bounds
2688    # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
2689    if torch.device(device).type == 'cpu':
2690        src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32)
2691        idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long)
2692        yield ErrorInput(SampleInput(src, args=(1, idx,)),
2693                         error_regex="index 23 is out of bounds for dimension")
2694
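    # out= variants that alias their inputs or that internally overlap (e.g. via
    # expand) are expected to be rejected with the generic memory-overlap error.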
2695    x = torch.rand((1,), device=device).expand((3,))
2696    src = torch.rand((6,), device=device)
2697    ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
2698
2699    yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)),
2700                     error_type=RuntimeError,
2701                     error_regex='unsupported operation')
2702
2703    yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)),
2704                     error_type=RuntimeError,
2705                     error_regex='unsupported operation')
2706
2707    yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])),
2708                     error_type=RuntimeError,
2709                     error_regex='unsupported operation')
2710
2711def error_inputs_take(op_info, device, **kwargs):
2712    x = torch.rand((1,), device=device).expand((3,))
2713    src = torch.rand((6,), device=device)
2714    ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64)
2715
2716    yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)),
2717                     error_type=RuntimeError,
2718                     error_regex='unsupported operation')
2719
2720    yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)),
2721                     error_type=RuntimeError,
2722                     error_regex='unsupported operation')
2723
2724    yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])),
2725                     error_type=RuntimeError,
2726                     error_regex='unsupported operation')
2727
2728# Error inputs for scatter
2729def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
2730    # Error when self.dtype != src.dtype (and src is not a scalar)
2731    src = make_tensor((2, 5), device=device, dtype=torch.float32)
2732    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
2733    dst = torch.zeros((3, 5), device=device, dtype=torch.double)
2734    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
2735                     error_regex="Expected self.dtype to be equal to src.dtype")
2736
2737    # Index dtype must be long
2738    src = make_tensor((2, 5), device=device, dtype=torch.float32)
2739    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
2740    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
2741    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
2742                     error_regex="Expected dtype int64 for index")
2743
2744    # Index and destination must have the same number of dimensions
2745    src = make_tensor((2, 5), device=device, dtype=torch.float32)
2746    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
2747    dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32)
2748    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
2749                     error_regex="Index tensor must have the same number of dimensions as self tensor")
2750
2751    # Index and src must have the same number of dimensions when src is not a scalar
2752    src = make_tensor((2, 5, 2), device=device, dtype=torch.float32)
2753    idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
2754    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
2755    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
2756                     error_regex="Index tensor must have the same number of dimensions as src tensor")
2757
2758    # Index out of bounds
2759    # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices
2760    if torch.device(device).type == 'cpu':
2761        src = make_tensor((2, 5), device=device, dtype=torch.float32)
2762        idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long)
2763        dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
2764        yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
2765                         error_regex="index 34 is out of bounds for dimension 0 with size 3")
2766
2767def error_inputs_renorm(op_info, device, **kwargs):
2768    zero_d = torch.randn((), device=device)
2769    yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError,
2770                     error_regex="needs at least 2 dimensions, got 0 dimensions")
2771
2772
2773def error_inputs_ormqr(op_info, device, **kwargs):
2774    zero_d = torch.randn((), device=device)
2775    yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError,
2776                     error_regex="input must have at least 2 dimensions")
2777
2778    # https://github.com/pytorch/pytorch/issues/85218
2779    tensor_0 = torch.full((5, 0,), 1, device=device)
2780    tensor_1 = torch.full((5,), 1, device=device)
2781    tensor_2 = torch.full((5, 5,), 1, device=device)
2782    bool_3 = True
2783    bool_4 = True
2784    yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError,
2785                     error_regex=r"tau.shape\[-1\] must be less than or equal to input.shape\[-1\]")
2786
2787
2788def error_inputs_diag(op_info, device, **kwargs):
2789    zero_d = torch.randn((), device=device)
2790    yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
2791                     error_regex="1D or 2D")
2792    zero_d = torch.randn(1, 1, 1, device=device)
2793    yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError,
2794                     error_regex="1D or 2D")
2795
2796def error_inputs_embedding(op_info, device, **kwargs):
2797    indices = torch.rand(2, 2, device=device).long()
2798    weights = [
2799        torch.tensor(1.0, device=device),
2800        torch.tensor(1.0, device=device).reshape(1, 1, 1),
2801    ]
2802
2803    for weight in weights:
2804        yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError,
2805                         error_regex="'weight' must be 2-D")
2806
2807
2808def error_inputs_t(op_info, device, **kwargs):
2809    yield ErrorInput(
2810        SampleInput(torch.randn(2, 3, 4, 5, device=device)),
2811        error_regex="expects a tensor with <= 2",
2812    )
2813
2814
2815def error_inputs_multinomial(op_info, device, **kwargs):
2816    x = torch.empty(1, 2, 3, dtype=torch.double, device=device)
2817    yield ErrorInput(SampleInput(x, args=(2,)),
2818                     error_regex="prob_dist must be 1 or 2 dim")
2819
2820    x = torch.empty(1, 2, dtype=torch.long, device=device)
2821    yield ErrorInput(SampleInput(x, args=(2,)),
2822                     error_regex="multinomial only supports floating-point dtypes for input")
2823
2824    x = torch.empty(1, 2, dtype=torch.double, device=device)
2825    y = torch.empty(1, 2, dtype=torch.double, device=device)
2826    yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)),
2827                     error_regex="multinomial expects Long tensor out")
2828
2829    x = torch.empty(2, dtype=torch.double, device=device)
2830    yield ErrorInput(SampleInput(x, args=(0,)),
2831                     error_regex="cannot sample n_sample <= 0 samples")
2832
2833    x = torch.empty(2, dtype=torch.double, device=device)
2834    yield ErrorInput(SampleInput(x, args=(-1,)),
2835                     error_regex="cannot sample n_sample <= 0 samples")
2836
2837    x = torch.empty(2, dtype=torch.double, device=device)
2838    yield ErrorInput(SampleInput(x, args=(3, False,)),
2839                     error_regex="cannot sample n_sample > prob_dist")
2840
2841    x = torch.empty(16777217, dtype=torch.double, device=device)
2842    yield ErrorInput(SampleInput(x, args=(3,)),
2843                     error_regex="number of categories cannot exceed")
2844
2845    inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan))
2846
2847    err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0"
2848    err_msg2 = "invalid multinomial distribution"
2849
2850    rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,)
2851
2852    if torch.device(device).type == 'cpu':
2853        for rep in rep_arg:
2854            kwargs = {'num_samples': 2, 'replacement': rep}
2855
2856            for shape in inputs:
2857                # error case when input tensor contains `inf`, `nan` or negative element
2858                yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs),
2859                                 error_regex=err_msg1 if rep is False else err_msg2)
2860
2861            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input
2862            x = torch.zeros(3, device=device)
2863            yield ErrorInput(SampleInput(x, kwargs=kwargs),
2864                             error_regex=err_msg2)
2865
2866            # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input
2867            x = torch.zeros(3, 3, device=device)
2868            yield ErrorInput(SampleInput(x, kwargs=kwargs),
2869                             error_regex=err_msg2)
2870
2871            # error case for the invalid multinomial distribution
2872            x[1, :] = 1
2873            yield ErrorInput(SampleInput(x, kwargs=kwargs),
2874                             error_regex=err_msg2)
2875
2876def error_inputs_gradient(op_info, device, **kwargs):
2877    for dtype in [torch.long, torch.float32, torch.complex64]:
2878        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype)
2879
2880        dim = (1, 0)
2881        spacing = [0.1]
2882        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
2883                         error_type=RuntimeError,
2884                         error_regex='torch.gradient expected spacing to be unspecified, a scalar ')
2885
2886        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)),
2887                         error_type=RuntimeError,
2888                         error_regex='torch.gradient only supports edge_order=1 and edge_order=2.')
2889
2890        dim = (1, 1)
2891        spacing = 0.1
2892        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)),
2893                         error_type=RuntimeError,
2894                         error_regex='dim 1 appears multiple times in the list of dims')
2895
2896        dim = (0, 1)
2897        coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')]
2898        yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)),
2899                         error_type=RuntimeError,
2900                         error_regex='torch.gradient expected each tensor to be on the same device,')
2901
2902        yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)),
2903                         error_type=IndexError, error_regex='')
2904
2905        t = torch.tensor([[1], [2], [3]])
2906        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)),
2907                         error_type=RuntimeError,
2908                         error_regex='torch.gradient expected each dimension size to be at least')
2909
2910        t = torch.tensor([[1, 2], [3, 4]])
2911        yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)),
2912                         error_type=RuntimeError,
2913                         error_regex='torch.gradient expected each dimension size to be at least')
2914
2915def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs):
2916    yield from sample_inputs_elementwise_unary(
2917        op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True))
2918
2919    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
2920    yield SampleInput(make_arg(S))
2921    yield SampleInput(make_arg(S), training=False)
2922
2923def error_inputs_rrelu(op_info, device, **kwargs):
2924    input = make_tensor((S, S), device=device, dtype=torch.float32)
2925    yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}),
2926                     error_regex='Lower bound should be less than or equal to the upper bound')
2927
2928def error_inputs_masked_select(op_info, device, **kwargs):
2929    x = torch.rand((1,), device=device).expand((3,))
2930    y = torch.rand((6,), device=device)
2931    mask = torch.tensor([True, False, True, True, False, False], device=device)
2932
2933    yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)),
2934                     error_type=RuntimeError,
2935                     error_regex='unsupported operation')
2936
2937    yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)),
2938                     error_type=RuntimeError,
2939                     error_regex='unsupported operation')
2940
2941    yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)),
2942                     error_type=RuntimeError,
2943                     error_regex='unsupported operation')
2944
2945def error_inputs_median(op_info, device, **kwargs):
2946    x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan],
2947                               [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device)
2948    if device == 'cuda':
2949        yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))),
2950                         error_type=RuntimeError,
2951                         error_regex='CUDA Tensors cannot have more than 25 dimensions')
2952    else:
2953        return
2954
2955
2956def error_inputs_index_select(op_info, device, **kwargs):
2957    x = torch.rand((1, 6), device=device).expand((2, 6))
2958    y = torch.rand((3, 6), device=device)
2959    ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
2960
2961    yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)),
2962                     error_type=RuntimeError,
2963                     error_regex='unsupported operation')
2964
2965def error_inputs_index_add(op_info, device, **kwargs):
2966    result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]])
2967    source = torch.tensor([2., 4.])
2968
2969    yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)),
2970                     error_type=RuntimeError,
2971                     error_regex=r'source tensor shape must match self tensor shape, '
2972                     r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]')
2973
2974def error_inputs_logcumsumexp(op_info, device, **kwargs):
2975    dim = 3
2976    srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)]
2977    for src in srcs:
2978        yield ErrorInput(SampleInput(src, args=(dim,)),
2979                         error_type=IndexError,
2980                         error_regex='Dimension out of range')
2981
2982def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
2983    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
2984    yield SampleInput(
2985        make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0)
2986
2987    # `indices` broadcast
2988    yield SampleInput(
2989        make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1)
2990
2991    # `self` broadcast
2992    yield SampleInput(
2993        make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1)
2994
2995    # without `dim` arg
2996    yield SampleInput(
2997        make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))
2998
2999
3000def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):
3001
3002    # Error Inputs for tensors with a zero-sized dimension, when the 'dim' arg is not provided.
3003    shape = (S, 0, S)
3004    err_msg_amax_amin = "reduction"
3005    err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
3006    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
3007        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
3008    elif op_info.name in ['aminmax']:
3009        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)
3010
3011    # Error Inputs for tensors with more than 64 dimensions
3012    sizes = [1] * 65
3013    err_msg1 = "only tensors with up to 64 dims are supported"
3014    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}),
3015                     error_regex=err_msg1)
3016    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}),
3017                     error_regex=err_msg1)
3018
3019    # Error Inputs for repeated 'dim'
3020    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
3021        dims = [(0, 0), (0, -4)]
3022        err_msg2 = "in the list of dims"
3023        x = torch.randn(S, S, S, S, device=device)
3024        for dim in dims:
3025            yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)
3026
3027    # Error Input for illegal dtype
3028    input5 = torch.randn(L, L, dtype=torch.float32, device=device)
3029    max_values = torch.empty(L, dtype=torch.float32, device=device)
3030    min_values = torch.empty(L, dtype=torch.double, device=device)
3031    illegal_values = torch.empty(L, dtype=torch.int, device=device)
3032
3033    # Unlike regular PyTorch, amax and amin refs don't require input and out
3034    # dtypes to match exactly:
3035    # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824
3036    if is_ref:
3037        err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype "
3038                              "torch.int32, but this can't be cast because it is not safe!")
3039    else:
3040        err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float "
3041                              "for input's dtype and Int for out's dtype.")
3042    err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"
3043
3044    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
3045        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
3046                         error_regex=err_msg_amax_amin2)
3047    elif op_info.name in ['aminmax']:
3048        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
3049                         error_regex=err_msg_aminmax2)
3050
3051    # Error Inputs where the specified reduction dim has zero size
3052    err_msg3 = "reduction"
3053    # FIXME: eager and ref impl throw different types of errors
3054    error_type = IndexError if 'refs' not in op_info.name else RuntimeError
3055    yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
3056                     error_type=error_type, error_regex=err_msg3)
3057
3058def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
3059    test_cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
3060        ((S, S, S), {}),
3061        ((S, S, S), {'dim': 1}),
3062        ((S, S, S), {'dim': 1, 'keepdim': True}),
3063        ((), {'dim': 0}),
3064        ((), {}),
3065        ((), {'dim': 0, 'keepdim': True}),
3066        ((S, 0, S), {'dim': 0}),
3067    )
3068
3069    for shape, kwargs in test_cases:
3070        yield SampleInput(
3071            make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad),
3072            **kwargs)
3073
3074def error_inputs_diff(op_info, device, **kwargs):
3075    t = torch.rand((1, 3), device=device)
3076    n = -1
3077    yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs),
3078                     error_type=RuntimeError,
3079                     error_regex=f'order must be non-negative but got {n}')
3080
3081def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs):
3082    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3083
3084    test_cases = (
3085        ((1,), 0, None, None),
3086        ((S,), 0, None, None),
3087        ((S, 1), 0, None, None),
3088        ((S, 1), 1, None, None),
3089        ((S, S), 0, None, None),
3090        ((S, S), 1, None, None),
3091        ((S, S), 0, (1, S), (2, S)),
3092        ((S, S), 0, None, (2, S)),
3093        ((XS, XS, XS), 1, None, None),
3094        ((XS, XS, XS), 2, None, None),
3095        ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)),
3096        ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)),
3097        ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),)
3098
3099    sample_inputs = []
3100    for size, dim, size_prepend, size_append in test_cases:
3101        prepend_size = 0 if (size_prepend is None) else size_prepend[dim]
3102        append_size = 0 if (size_append is None) else size_append[dim]
3103        dim_size = size[dim] + prepend_size + append_size
3104        for n in range(dim_size):
3105            input_tensor = make_arg(size)
3106            prepend = make_arg(size_prepend) if size_prepend else None
3107            append = make_arg(size_append) if size_append else None
3108            yield SampleInput(input_tensor, n, dim, prepend, append)
3109
3110    # add some samples with n > dim_size
3111    yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1)
3112    yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS)))
3113
3114def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs):
3115    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3116
3117    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
3118
3119    for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]):
3120        input_tensor = make_arg(size)
3121        weight_tensor = make_arg(size) if weighted else None
3122
3123        yield SampleInput(input_tensor, bin_ct,
3124                          weight=weight_tensor, density=density)
3125
3126        bins_tensor = make_arg((bin_ct + 1,))
3127        sorted_bins, bins_indices = torch.sort(bins_tensor)
3128        yield SampleInput(input_tensor, sorted_bins,
3129                          weight=weight_tensor, density=density)
3130
3131def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs):
3132    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3133
3134    sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S))
3135    bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3))
3136
3137    for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]):
3138        input_tensor = make_arg(size)
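        # histogramdd expects one bin count per coordinate, i.e. per entry of the
        # innermost dimension, so truncate the pattern to that length.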
3139        bin_ct = bin_ct_pattern[:size[-1]]
3140        weight_tensor = make_arg(size[:-1]) if weighted else None
3141
3142        yield SampleInput(input_tensor, bin_ct,
3143                          weight=weight_tensor, density=density)
3144
3145        bins_tensor = [make_arg(ct + 1) for ct in bin_ct]
3146        yield SampleInput(input_tensor, bins_tensor,
3147                          weight=weight_tensor, density=density)
3148
3149def error_inputs_histogramdd(opinfo, device, **kwargs):
3150    invalid_bins = [1, 1, 1, 1, 1]
3151    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
3152    msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input."
3153    yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg)
3154
3155def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs):
3156    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3157
3158    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
3159
3160    for size, min, max in product(sizes, [0, -10], [0, 10]):
3161        # construct sample input omitting bins arg
3162        yield SampleInput(make_arg(size), min=min, max=max)
3163
3164        # construct sample inputs with a few different bins values
3165        for bins in [1, 3, 10]:
3166            yield SampleInput(make_arg(size), bins=bins, min=min, max=max)
3167
3168def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs):
3169    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3170
3171    for size, weighted in product((S, M), [False, True]):
3172        input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device)
3173        weight_tensor = make_arg((size,)) if weighted else None
3174
3175        max_val = int(input_tensor.max().item())
3176
3177        for minlength in [0, max_val // 2, max_val, 2 * max_val]:
3178            yield SampleInput(
3179                input_tensor, weights=weight_tensor, minlength=minlength)
3180
3181def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs):
3182    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3183
3184    sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S))
3185
3186    if reference_inputs_mode:
3187        sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33))
3188
3189    for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]):
3190        input_tensor = make_arg(input_shape)
3191        boundaries = make_arg(nb).msort()
3192
3193        yield SampleInput(input_tensor, boundaries,
3194                          out_int32=out_int32, right=right)
3195
3196reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True)
3197
3198def error_inputs_bucketize(opinfo, device, **kwargs):
3199    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
3200    yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))),
3201                     error_regex="boundaries tensor must be 1 dimension")
3202
3203def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs):
3204    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3205
3206    # (unsorted tensor size, (input sizes,), is_scalar)
3207    sizes = (
3208        ((0,), ((0,),), False),
3209        ((M,), ((), (M,), (M, M)), False),
3210        ((0, 0), ((0, 0),), False),
3211        ((M, M), ((M, M),), False),
3212        ((0, 0, 0), ((0, 0, 0),), False),
3213        ((M, M, M), ((M, M, M),), False),
3214        ((L,), ((),), True),
3215    )
3216
3217    for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product(
3218        sizes, [False, True], [False, True], [False, True]
3219    ):
3220        unsorted_tensor = make_arg(size, noncontiguous=noncontiguous)
3221        for input_size in input_sizes:
3222            input = make_arg(input_size, noncontiguous=noncontiguous)
3223            if is_scalar:
3224                input = input.item()
3225            if np.prod(size) == 0:
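                # an empty boundary tensor is vacuously sorted; reuse it directly
                # and pair it with a dummy sorter of matching shape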
3226                boundary_tensor = unsorted_tensor
3227                sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous)
3228            else:
3229                boundary_tensor, sorter = torch.sort(unsorted_tensor)
3230            side = "right" if right else "left"
3231
3232            yield SampleInput(boundary_tensor, input, out_int32=out_int32, right=right)
3233            yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side)
3234
3235            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter)
3236            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter)
3237
3238def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs):
3239    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
3240    test_cases_float = (
3241        ((S,), None, None, 1),
3242        ((S,), 2., None, 1),
3243        ((S, S), None, None, 2),
3244        ((S, S), [2.0, 2.1], None, 1),
3245        ((S, S), [2.0, 2.1], (0, 1), 1),
3246        ((4, 4, 4), [2., 1.], (0, 1), 2),
3247    )
3248    for size, spacing, dim, edge_order in test_cases_float:
3249        t = make_arg(size)
3250        yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order)
3251
3252    test_cases_tensor = (
3253        ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1),
3254        ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2),
3255    )
3256    for size, coordinates, dim, edge_order in test_cases_tensor:
3257        t = make_arg(size)
3258        coordinates_tensor_list = []
3259        for coords in coordinates:
3260            # `coords` always contains floating point values, and Python 3.10 no longer supports
3261            # implicitly converting them to an integer via `__int__`
3262            # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed
3263            a = torch.tensor(coords, device=device)
3264            coordinates_tensor_list.append(a.to(dtype))
3265        yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order)
3266
3267def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
3268    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3269    test_args = [
3270        ([1, 2],),
3271        (slice(0, 3),),
3272        ([slice(0, 3), 1],),
3273        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
3274        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
3275        ([slice(None), slice(None), [0, 3]],),
3276        ([slice(None), [0, 3], slice(None)],),
3277        ([[0, 3], slice(None), slice(None)],),
3278        ([[0, 3], [1, 2], slice(None)],),
3279        ([[0, 3], ],),
3280        ([[0, 3], slice(None)],),
3281        ([[0, 3], Ellipsis],),
3282        ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
3283        (index_variable(2, S, device=device),),
3284        (mask_not_all_zeros((S,)),),
3285    ]
3286
3287    for args in test_args:
3288        yield SampleInput(make_arg((S, S, S)), args=args)
3289
3290    yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],))
3291
3292def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
3293    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3294
3295    for accumulate in [False, True]:
3296        # Test with indices arg
3297        yield SampleInput(
3298            make_arg((S, S,)),
3299            (index_variable(2, S, device=device),),
3300            make_arg((2, S)),
3301            accumulate=accumulate)
3302
3303        # Test with mask arg
3304        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
3305        yield SampleInput(
3306            make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)
3307
3308def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
3309    def small_3d_unique():
3310        res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S)
3311        res = res.to(dtype).requires_grad_(requires_grad)
3312        return res
3313
3314    def large_1d_unique():
3315        res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
3316        res = res.to(dtype).requires_grad_(requires_grad)
3317        return res
3318
3319    # Test case for large tensor.
3320    yield SampleInput(large_1d_unique())
3321
3322    # Test cases for small 3d tensors.
3323    # Imitates legacy tests from test/test_torch.py
3324    dims = range(-3, 3)
3325    flag = [True, False]
3326    for dim, descending, stable in product(dims, flag, flag):
3327        # default schema without stable sort
3328        yield SampleInput(small_3d_unique(), dim, descending)
3329        # schema with stable sort, no CUDA support yet
3330        if torch.device(device).type == 'cpu':
3331            yield SampleInput(
3332                small_3d_unique(), dim=dim, descending=descending, stable=stable)
3333
3334    # Test cases for scalar tensor
3335    tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad)
3336    yield SampleInput(torch.tensor(1, **tensor_opt))
3337    yield SampleInput(torch.tensor(1, **tensor_opt), 0)
3338    yield SampleInput(torch.tensor(1, **tensor_opt), 0, True)
3339
3340    # Test cases for empty tensor
3341    yield SampleInput(torch.tensor((), **tensor_opt))
3342    yield SampleInput(torch.tensor((), **tensor_opt), 0)
3343    yield SampleInput(torch.tensor((), **tensor_opt), 0, True)
3344
3345    # Test cases for stable sort
3346    yield SampleInput(small_3d_unique(), stable=True)
3347    yield SampleInput(small_3d_unique(), dim=0, stable=True)
3348    yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True)
3349
3350def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs):
3351    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
3352    sizes = ((), (S,), (S, S), (S, S, S))
3353    for x_size in sizes:
3354        # the threshold and value args must be numbers
3355        yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item())
3356
3357def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs):
3358    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3359    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
3360
3361    for shape, sorted, return_inverse, return_counts, dim in \
3362            product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]):
3363        # torch.unique cannot be called if the input tensor has a zero-sized dimension that isn't the selected dim
3364        if 0 in shape and shape.index(0) != dim:
3365            continue
3366
3367        # skip invalid dim args
3368        if dim is not None and (dim < -len(shape) or dim >= len(shape)):
3369            continue
3370
3371        kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
3372
3373        # construct a test case with only one distinct value
3374        input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
3375        yield SampleInput(input_t, **kwargs)
3376
3377        # construct a test case with mixed 0s and 1s
3378        input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\
3379            .to(dtype).requires_grad_(requires_grad)
3380        yield SampleInput(input_t, **kwargs)
3381
3382        # construct a test case with many different values
3383        yield SampleInput(make_arg(shape), **kwargs)
3384
3385def sample_inputs_unique_consecutive(*args, **kwargs):
3386    for sample_input in sample_inputs_unique(*args, **kwargs):
3387        if not sample_input.kwargs["sorted"]:
3388            sample_input.kwargs.pop("sorted")
3389            yield sample_input
3390
3391def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs):
3392    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3393
3394    # Ordered as (input shape, output size)
3395    cases = (
3396        ((0, 8, 8), (5,)),
3397        ((3, 8, 8), 5),
3398        ((3, 8, 8), 1)
3399    )
3400
3401    for input_shape, output_size in cases:
3402        # Batched
3403        yield SampleInput(make_arg(input_shape), args=(output_size,))
3404        # Unbatched
3405        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
3406
3407
3408def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs):
3409    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3410
3411    # error inputs for empty output
3412    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
3413                     error_regex="'output_size' should contain one int")
3414
3415    # error inputs for output_size less than 0
3416    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
3417                     error_regex="elements of output_size must be greater than or equal to 0")
3418
3419
3420def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs):
3421    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3422
3423    # Ordered as (input shape, output size)
3424    cases = (
3425        ((1, 8, 8, 8), (5, 7)),
3426        ((2, 8, 8, 8), (None, 7)),
3427        ((1, 8, 4, 3), (5, None)),
3428        ((1, 8, 4, 3), (None, None)),
3429        ((1, 8, 4, 3), (5)),
3430    )
3431
3432    for input_shape, output_size in cases:
3433        # Batched
3434        yield SampleInput(make_arg(input_shape), args=(output_size,))
3435        # Unbatched
3436        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
3437
3438
3439def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs):
3440    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3441
3442    # error inputs for incorrect input dimension
3443    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
3444                     error_type=ValueError, error_regex="Input dimension should be at least 3")
3445
3446    # error inputs for empty output
3447    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
3448                     error_regex="output_size must be 2")
3449
3450    # error inputs for output_size less than 0
3451    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
3452                     error_regex="elements of output_size must be greater than or equal to 0")
3453
3454
3455def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs):
3456    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3457
3458    # Ordered as (input shape, output size)
3459    cases = (
3460        ((0, 8, 8, 8, 8), (5, 7, 4)),
3461        ((1, 8, 4, 3, 7), (None, None, None)),
3462        ((1, 8, 4, 3, 7), (1, 1, 1)),
3463        ((3, 3, 8, 8, 6), (5, 7, None)),
3464        ((1, 3, 8, 8, 6), (5, None, 2)),
3465        ((3, 3, 8, 8, 6), (None, 3, 2)),
3466    )
3467
3468    for input_shape, output_size in cases:
3469        # Batched
3470        yield SampleInput(make_arg(input_shape), args=(output_size,))
3471        # Unbatched
3472        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))
3473
3474
3475def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs):
3476    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3477
3478    # error inputs for incorrect input dimension
3479    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
3480                     error_type=ValueError, error_regex="Input dimension should be at least 4")
3481
3482    # error inputs for empty output
3483    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
3484                     error_regex="output_size must be 3")
3485
3486    # error inputs for output_size less than 0
3487    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
3488                     error_regex="elements of output_size must be greater than or equal to 0")
3489
3490
3491def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs):
3492    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3493
3494    # Ordered as (input shape, output size)
3495    cases = (
3496        # ((0, 8, 8), (5,)),
3497        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
3498        ((3, 4, 4), 3),
3499        ((3, 4, 4), 1)
3500    )
3501
3502    for shapes, return_idx in product(cases, (True, False)):
3503        # Batched
3504        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
3505        # Unbatched
3506        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
3507
3508
3509def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs):
3510    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3511
3512    # error inputs for empty output
3513    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
3514                     error_regex="'output_size' should contain one int")
3515
3516    # error inputs for output_size less than 0
3517    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
3518                     error_regex="Trying to create tensor with negative dimension")
3519
3520def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
3521    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3522
3523    # Ordered as (input shape, output size)
3524    cases = (
3525        # ((0, 8, 8, 8), (5, 7)),
3526        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
3527        ((1, 4, 4, 4), (2, 3)),
3528        ((2, 4, 4, 4), (None, 3)),
3529        ((2, 4, 4, 4), (1, 1)),
3530        ((1, 4, 4, 3), (3, None)),
3531        ((1, 4, 4, 3), (None, None)),
3532        ((1, 4, 4, 3), (3)),
3533    )
3534
3535    for shapes, return_idx in product(cases, (True, False)):
3536        # Batched
3537        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
3538        # Unbatched
3539        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
3540
3541def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs):
3542    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3543
3544    # error inputs for incorrect input dimension
3545    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
3546                     error_type=ValueError, error_regex="Input dimension should be at least 3")
3547
3548    # error inputs for empty output
3549    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
3550                     error_regex="internal error")
3551
3552    # error inputs for output_size less than 0
3553    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
3554                     error_regex="Trying to create tensor with negative dimension")
3555
3556
3557def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
3558    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3559
3560    # Ordered as (input shape, output size)
3561    cases = (
3562        # ((0, 8, 8, 8, 8), (5, 7, 4)),
3563        # 0 batch size doesn't work: cannot reshape a tensor of 0 elements into shape [0, 8, -1]
3564        ((1, 4, 4, 3, 5), (None, None, None)),
3565        ((1, 4, 4, 3, 5), (1, 1, 1)),
3566        ((3, 3, 4, 4, 6), (2, 3, None)),
3567        ((1, 3, 4, 4, 6), (3, None, 2)),
3568        ((3, 3, 4, 4, 6), (None, 3, 2)),
3569    )
3570
3571    for shapes, return_idx in product(cases, (True, False)):
3572        # Batched
3573        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
3574        # Unbatched
3575        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))
3576
3577def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs):
3578    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
3579
3580    # error inputs for incorrect input dimension
3581    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
3582                     error_type=ValueError, error_regex="Input dimension should be at least 4")
3583
3584    # error inputs for empty output
3585    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
3586                     error_regex="internal error")
3587
3588    # error inputs for output_size less than 0
3589    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
3590                     error_regex="Trying to create tensor with negative dimension")
3591
3592
3593class _TestParamsMaxPoolBase:
3594
3595    def __init__(self) -> None:
3596        self.kwargs = {
3597            'kernel_size': [3],
3598            'stride': [2, None],
3599            'ceil_mode': [True, False],
3600            'padding': [0, 1],
3601            'dilation': [1],
3602            'return_indices': [True, False]
3603        }
3604
3605        self.shapes = [
3606            [1, 2, None],  # batch
3607            [2],  # channels
3608            [3, 6]  # signal
3609        ]
3610
3611    def _gen_shape(self):
3612        for shape in product(*self.shapes):
3613            # shape[0] being None indicates a missing batch dimension
3614            if shape[0] is None:
3615                shape = shape[1:]
3616
3617            yield shape, torch.contiguous_format
3618            # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format
3619            if len(self.shapes) == 4 and len(shape) == 4:
3620                yield shape, torch.channels_last
3621
3622    def _gen_kwargs(self):
3623        keys = self.kwargs.keys()
3624        for values in product(*self.kwargs.values()):
3625            yield dict(zip(keys, values))
3626
3627    def gen_input_params(self):
3628        yield from product(self._gen_shape(), self._gen_kwargs())
3629
3630class _TestParamsMaxPool1d(_TestParamsMaxPoolBase):
3631
3632    def __init__(self) -> None:
3633        super().__init__()
3634        self.kwargs['kernel_size'] += [(3,)]
3635        self.kwargs['stride'] += [(2,)]
3636        self.kwargs['padding'] += [(1,)]
3637        self.kwargs['dilation'] += [(1,)]
3638
3639class _TestParamsMaxPool2d(_TestParamsMaxPoolBase):
3640
3641    def __init__(self) -> None:
3642        super().__init__()
3643        self.kwargs['kernel_size'] += [(3, 2)]
3644        self.kwargs['stride'] += [(2, 1)]
3645        self.kwargs['padding'] += [(1, 1)]
3646        self.kwargs['dilation'] += [(1, 2)]
3647
3648        self.shapes.append([6])
3649
3650class _TestParamsMaxPool3d(_TestParamsMaxPoolBase):
3651
3652    def __init__(self) -> None:
3653        super().__init__()
3654        self.kwargs['kernel_size'] += [(3, 2, 3)]
3655        self.kwargs['stride'] += [(2, 1, 2)]
3656        self.kwargs['dilation'] += [(1, 2, 1)]
3657
3658        self.shapes.append([6])
3659        self.shapes.append([5])
3660
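# Illustrative sketch (comments only, assuming CPU float tensors) of how the
# parameter generators above are consumed: each yielded
# ((shape, memory_format), kwargs) pair maps directly onto the functional
# max-pool call, e.g. for the 2d case
#
#   params = _TestParamsMaxPool2d()
#   for (shape, memory_format), kwargs in params.gen_input_params():
#       x = torch.randn(shape).to(memory_format=memory_format)
#       torch.nn.functional.max_pool2d(x, **kwargs)
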
3661def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs):
3662    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
3663
3664    params_generator_type_dict = {
3665        'nn.functional.max_pool1d': _TestParamsMaxPool1d,
3666        'nn.functional.max_pool2d': _TestParamsMaxPool2d,
3667        'nn.functional.max_pool3d': _TestParamsMaxPool3d,
3668        'max_pool2d_with_indices_backward': _TestParamsMaxPool2d,
3669    }
3670
3671    params_generator = params_generator_type_dict[op_info.name]()
3672    for (shape, memory_format), kwargs in params_generator.gen_input_params():
3673        arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad)
3674        yield SampleInput(arg, kwargs=kwargs)
3675
3676def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs):
3677    out, indices = torch.nn.functional.max_pool2d_with_indices(
3678        *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True)
3679    grad_out = torch.ones_like(out)
3680    if stride is None:
3681        stride = kernel_size
3682    out_b = torch.ops.aten.max_pool2d_with_indices_backward.default(
3683        grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices)
3684    return out_b
3685
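# Minimal usage sketch (comments only): the wrapper above recomputes the forward
# pass with indices and feeds a ones-like gradient into the aten backward
# overload, so the result has the same shape as the input, e.g.
#
#   x = torch.randn(1, 1, 4, 4)
#   gx = max_pool2d_backward(x, kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1)
#   assert gx.shape == x.shape
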
3686def error_inputs_max_pool1d(op_info, device, **kwargs):
3687    # Toggle requires_grad because `max_pool1d` takes a different path
3688    # depending on whether `requires_grad` is set.
3689    for requires_grad in (True, False):
3690        make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad)
3691        # error inputs when pad is negative
3692        x = make_arg((0, 1, 49))
3693        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
3694                         error_regex='pad must be non-negative')
3695
3696        # error inputs when pad > kernel_size / 2
3697        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
3698                         error_regex='pad should be at most half of effective kernel size')
3699
3700        # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default
3701        yield ErrorInput(SampleInput(x,
3702                         kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}),
3703                         error_regex='pad should be at most half of effective kernel size')
3704
3705        # error inputs for a 0-dim input tensor
3706        error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input'
3707        yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}),
3708                         error_regex=error_msg)
3709
3710        # error inputs for empty input
3711        yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad),
3712                                     kwargs={'kernel_size': 1}),
3713                         error_regex=error_msg)
3714
3715        # error: unbatched input with 0 sized non-batch dims.
3716        yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad),
3717                                     kwargs={'kernel_size': 1}),
3718                         error_regex=error_msg)
3719
3720        # error: batched input with 0 sized non-batch dims.
3721        yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad),
3722                                     kwargs={'kernel_size': 1}),
3723                         error_regex=error_msg)
3724
3725        # error inputs when stride=0
3726        error_msg = 'stride must be greater than zero, but got 0'
3727        yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
3728                         error_regex=error_msg)
3729
3730        # error inputs when dilation=0
3731        error_msg = 'dilation must be greater than zero, but got 0'
3732        yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
3733                                     kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
3734                         error_regex=error_msg)
3735
3736        # error inputs for invalid output size
3737        error_msg = 'Invalid computed output size: -2'
3738        yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
3739                                     kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
3740                         error_regex=error_msg)
3741
3742        # error inputs when kernel_size=0
3743        error_msg = 'kernel_size must be greater than zero'
3744        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
3745                         error_regex=error_msg)
3746
3747        # error inputs for stride=0 on an input with a 0-sized batch
3748        error_msg = 'stride must be greater than zero'
3749        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
3750                         error_regex=error_msg)
3751
3752
3753def error_inputs_max_pool2d(op_info, device, **kwargs):
3754    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
3755    # error inputs when pad is negative
3756    x = make_arg((0, 1, 49))
3757    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
3758                     error_regex='pad must be non-negative')
3759    # 2-dimensional kernel
3760    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}),
3761                     error_regex='pad must be non-negative')
3762
3763    # error inputs when pad > kernel_size / 2 (kernel_size : int)
3764    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
3765                     error_regex='pad should be at most half of effective kernel size')
3766
3767    # error inputs when pad > kernel_size / 2 (kernel_size : tuple)
3768    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}),
3769                     error_regex='pad should be at most half of effective kernel size')
3770
3771    # error: unbatched input with 0 sized non-batch dims.
3772    err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input'
3773    yield ErrorInput(SampleInput(make_arg((1, 0, 10)),
3774                                 kwargs={'kernel_size': 1}),
3775                     error_regex=err_msg)
3776
3777    # error: batched input with 0 sized non-batch dims.
3778    yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)),
3779                                 kwargs={'kernel_size': 1}),
3780                     error_regex=err_msg)
3781
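# A minimal illustrative sketch (the helper name and the way the generator is
# consumed here are assumptions, not how the OpInfo test framework runs these):
# each ErrorInput yielded above pairs a SampleInput with a regex that the raised
# error message is expected to match.
def _demo_check_max_pool2d_error_input(device='cpu'):
    import re
    error_input = next(iter(error_inputs_max_pool2d(None, device)))
    sample = error_input.sample_input
    try:
        torch.nn.functional.max_pool2d(sample.input, *sample.args, **sample.kwargs)
        raise AssertionError("expected max_pool2d to raise")
    except RuntimeError as e:
        assert re.search(error_input.error_regex, str(e))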
3782
3783def error_inputs_max_pool3d(op_info, device, **kwargs):
3784    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
3785    # error inputs when pad is negative
3786    x = make_arg((0, 1, 49, 50))
3787    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
3788                     error_regex='pad must be non-negative')
3789    # 3-dimensional kernel
3790    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
3791                                            'padding': -1, 'return_indices': True}),
3792                     error_regex='pad must be non-negative')
3793
3794    # error inputs when pad > kernel_size / 2 (kernel_size: int)
3795    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
3796                     error_regex='pad should be at most half of effective kernel size')
3797
3798    # error inputs when pad > kernel_size / 2 (kernel_size: tuple)
3799    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
3800                                            'padding': 4, 'return_indices': True}),
3801                     error_regex='pad should be at most half of effective kernel size')
3802
3803    # error: unbatched input with 0 sized non-batch dims.
3804    err_msg = r'Expected input\'s non-batch dimensions to have positive length'
3805    yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)),
3806                                 kwargs={'kernel_size': 1}),
3807                     error_regex=err_msg)
3808
3809    # error: batched inputs with 0 sized non-batch dims.
3810    yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)),
3811                                 kwargs={'kernel_size': 1}),
3812                     error_regex=err_msg)
3813
3814
3815def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
3816    make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)
3817
3818    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
3819                                     ((2, 1, 4, 5), {'p': 1., 'dim': 2}),
3820                                     ((2, 3, 4, 5), {'p': 2., 'dim': 1}),
3821                                     ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}),
3822                                     ((1, 3, 4, 5), {'p': -1., 'dim': 1}),
3823                                     ((1, 3, 4, 5), {'p': 0., 'dim': -1}),
3824                                     ((), {'p': 1.2, 'dim': 0}),
3825                                     ((2, 3, 4, 5), {}),
3826                                     ((2, 3, 4, 5), {'eps': 1e-4}))
3827
3828    for input_shape, kwargs in cases:
3829        yield SampleInput(make_arg(input_shape), kwargs=kwargs)
3830
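# Illustrative sketch (shapes are arbitrary assumptions; not used by the OpInfo
# machinery): F.normalize rescales along `dim` so that each slice has unit
# p-norm, which is the behaviour the samples above exercise for various p/dim.
def _demo_normalize_unit_norm():
    x = torch.randn(3, 4)
    out = torch.nn.functional.normalize(x, p=2.0, dim=1)
    # every row of the result has (approximately) unit L2 norm
    torch.testing.assert_close(out.norm(p=2, dim=1), torch.ones(3))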
3831
3832def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups):
3833    # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
3834    # Using the three-multiplication (Gauss/Karatsuba) trick with
3835    #   a = conv(Wr, xr, br),  b = conv(Wi, xi, 0),
3836    #   c = conv(Wr + Wi, xr + xi, br + bi),
3837    # conv(W, x, b) = (a - b) + i(c - a - b).  See the sketch after this function.
3838
3839    grad_output_ = torch.view_as_real(grad_output)
3840    grad_output_r = grad_output_[..., 0]
3841    grad_output_i = grad_output_[..., 1]
3842
3843    weight_ = torch.view_as_real(weight)
3844    weight_r = weight_[..., 0]
3845    weight_i = weight_[..., 1]
3846
3847    a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups)
3848    b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups)
3849    c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups)
3850
3851    return (a - b) + 1j * (c - a - b)
3852
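# Illustrative numeric check (added for clarity; not used anywhere) of the
# three-multiplication identity that `complex_conv` above relies on, shown with
# plain elementwise products instead of convolutions. Sizes are arbitrary.
def _demo_gauss_complex_product_identity():
    W = torch.randn(8, dtype=torch.complex64)
    x = torch.randn(8, dtype=torch.complex64)
    Wr, Wi, xr, xi = W.real, W.imag, x.real, x.imag
    a = Wr * xr
    b = Wi * xi
    c = (Wr + Wi) * (xr + xi)
    # (a - b) + i(c - a - b) == (Wr*xr - Wi*xi) + i(Wr*xi + Wi*xr) == W * x
    torch.testing.assert_close((a - b) + 1j * (c - a - b), W * x)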
3853
3854def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
3855                       output_padding=0, dilation=1, groups=1,
3856                       fn=None):
3857    # The derivative of `conv` is `conv_transpose`.
3858    # To verify the correctness of `conv_transpose`, we rely on the
3859    # `torch.nn.grad` implementation (which is tested in test_nn.py)
3860    # for floating dtypes. (See the illustrative sketch after this function.)
3861
3862    assert fn is not None
3863
3864    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
3865                   torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
3866                   torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
3867    batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
3868                       torch.nn.functional.conv_transpose2d: 4,
3869                       torch.nn.functional.conv_transpose3d: 5}
3870
3871    # Inputs to this reference function are numpy ndarrays.
3872    input, weight = torch.from_numpy(input), torch.from_numpy(weight)
3873
3874    is_batched = len(input.shape) == batched_dim_map[fn]
3875    if not is_batched:
3876        input = input.unsqueeze(0)
3877
3878    if bias is not None:
3879        bias = torch.from_numpy(bias)
3880        unsqueeze_dims = input.ndim - 2
3881        for _ in range(unsqueeze_dims):
3882            bias = bias.unsqueeze(1)
3883
3884    grad_output = input
3885    # Get the input shape for grad_fn.
3886    conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None,
3887                               stride=stride, padding=padding, output_padding=output_padding,
3888                               groups=groups, dilation=dilation)
3889    input_size = conv_transpose_output.shape
3890
3891    grad_fn = grad_fn_map[fn]
3892    if weight.dtype.is_complex:
3893        out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups)
3894    else:  # Floating
3895        out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups)
3896
3897    if bias is not None:
3898        out = out + bias
3899
3900    return out.squeeze(0) if not is_batched else out
3901
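# Illustrative sketch of the identity `conv_transpose_ref` above is built on
# (the helper name and shapes are assumptions for the demo): a transposed
# convolution equals the input-gradient of the corresponding convolution.
def _demo_conv_transpose_is_conv_input_grad():
    x = torch.randn(1, 2, 5)
    # conv_transpose1d weight layout: (in_channels, out_channels, kernel_width)
    w = torch.randn(2, 3, 3)
    out = torch.nn.functional.conv_transpose1d(x, w)
    # torch.nn.grad.conv1d_input computes d(conv1d)/d(input); with `x` playing
    # the role of grad_output it reproduces the transposed convolution.
    ref = torch.nn.grad.conv1d_input(out.shape, w, x, stride=1, padding=0, dilation=1, groups=1)
    torch.testing.assert_close(out, ref)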
3902
3903def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs):
3904    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3905
3906    # Ordered as shapes for input, weight, bias
3907    # and a dict of values of (stride, padding, output_padding, groups, dilation)
3908    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
3909        ((1, 3, 4), (3, 3, 3), (3,),
3910         {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}),
3911        ((2, 2, 4), (2, 2, 4), (4,),
3912         {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}),
3913        ((1, 1, 4), (1, 1, 4), (1,),
3914         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}),
3915        ((1, 1, 4), (1, 2, 3), None,
3916         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
3917        ((1, 4, 5), (4, 8, 3), None,
3918         {})
3919    )
3920
3921    for input_shape, weight, bias, kwargs in cases:
3922        # Batched
3923        yield SampleInput(make_arg(input_shape), args=(
3924            make_arg(weight),
3925            make_arg(bias) if bias is not None else bias
3926        ), kwargs=kwargs)
3927        # Unbatched
3928        yield SampleInput(make_arg(input_shape[1:]), args=(
3929            make_arg(weight),
3930            make_arg(bias) if bias is not None else bias
3931        ), kwargs=kwargs)
3932
3933
3934def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
3935    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3936
3937    # Ordered as shapes for input, weight, bias
3938    # and a dict of values of (stride, padding, output_padding, groups, dilation)
3939    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
3940        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
3941         {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
3942        ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
3943         {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
3944        ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
3945         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
3946        ((1, 1, 4, 3), (1, 2, 3, 4), None,
3947         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
3948        ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}),
3949        ((1, 2, 5, 5), (2, 4, 3, 3), None, {})
3950    )
3951
3952    for input_shape, weight, bias, kwargs in cases:
3953        # Batched
3954        yield SampleInput(make_arg(input_shape), args=(
3955            make_arg(weight),
3956            make_arg(bias) if bias is not None else bias
3957        ), kwargs=kwargs)
3958        # Unbatched
3959        yield SampleInput(make_arg(input_shape[1:]), args=(
3960            make_arg(weight),
3961            make_arg(bias) if bias is not None else bias
3962        ), kwargs=kwargs)
3963
3964def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs):
3965    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3966
3967    # Ordered as shapes for input, weight, bias
3968    # and a dict of values of (stride, padding, output_padding, groups, dilation)
3969    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
3970        ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,),
3971         {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}),
3972        ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,),
3973         {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}),
3974        ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,),
3975         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}),
3976        ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None,
3977         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
3978        ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None,
3979         {})
3980    )
3981
3982    for input_shape, weight, bias, kwargs in cases:
3983        # Batched
3984        yield SampleInput(make_arg(input_shape), args=(
3985            make_arg(weight),
3986            make_arg(bias) if bias is not None else bias
3987        ), kwargs=kwargs)
3988        # Unbatched
3989        yield SampleInput(make_arg(input_shape[1:]), args=(
3990            make_arg(weight),
3991            make_arg(bias) if bias is not None else bias
3992        ), kwargs=kwargs)
3993
3994
3995def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
3996    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
3997
3998    # Ordered as shapes for input, weight, bias,
3999    # and a dict of values of (stride, padding, dilation, groups)
4000    cases: Tuple = (
4001        ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
4002        ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
4003        ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
4004        ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}),
4005        # With defaults
4006        ((1, 4, 5), (3, 4, 3), None, {}),
4007    )
4008
4009    for input_shape, weight, bias, kwargs in cases:
4010        # Batched
4011        yield SampleInput(make_arg(input_shape), args=(
4012            make_arg(weight),
4013            make_arg(bias) if bias is not None else bias
4014        ), kwargs=kwargs)
4015        # Unbatched
4016        yield SampleInput(make_arg(input_shape[1:]), args=(
4017            make_arg(weight),
4018            make_arg(bias) if bias is not None else bias
4019        ), kwargs=kwargs)
4020
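# Illustrative sketch (an assumption about usage, not part of the OpInfo
# machinery): each SampleInput yielded above corresponds to one call of the form
# torch.nn.functional.conv1d(sample.input, *sample.args, **sample.kwargs).
def _demo_run_first_conv1d_sample(device='cpu'):
    sample = next(iter(sample_inputs_conv1d(None, device, torch.float32, requires_grad=False)))
    return torch.nn.functional.conv1d(sample.input, *sample.args, **sample.kwargs)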
4021
4022def error_inputs_conv1d(opinfo, device, **kwargs):
4023    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
4024    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
4025    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
4026
4027    # error inputs for different dtypes of input tensor and bias
4028    yield ErrorInput(
4029        SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))),
4030        error_regex="should be the same")
4031
4032    # error inputs for different dtypes of input tensor and bias
4033    yield ErrorInput(
4034        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))),
4035        error_regex="should be the same")
4036
4037    # error inputs for negative strides
4038    yield ErrorInput(
4039        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
4040                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
4041
4042    # error inputs for negative padding
4043    yield ErrorInput(
4044        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
4045                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
4046
4047    # error inputs for negative dilation
4048    yield ErrorInput(
4049        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))),
4050                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
4051
4052    # FIXME: https://github.com/pytorch/pytorch/issues/85656
4053    # error inputs for bias shape not equal to the output channels
4054    # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))),
4055    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
4056
4057    # error inputs for a weight with fewer than 3 dimensions
4058    yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))),
4059                     error_regex="weight should have at least three dimensions")
4060
4061    # error inputs when weight.shape[0] is less than the number of groups (with padding='same')
4062    yield ErrorInput(
4063        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
4064                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
4065
4066    # error inputs when weight.shape[0] is less than the number of groups
4067    yield ErrorInput(
4068        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
4069                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
4070
4071    # error inputs for invalid groups
4072    yield ErrorInput(
4073        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
4074                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")
4075
4076    # error inputs for invalid groups
4077    yield ErrorInput(
4078        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
4079                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")
4080
4081
4082def error_inputs_conv2d(opinfo, device, **kwargs):
4083    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
4084    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
4085    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
4086
4087    # error inputs for different dtypes of input tensor and bias
4088    yield ErrorInput(
4089        SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))),
4090        error_regex="should be the same")
4091
4092    # error inputs for different dtypes of input tensor and bias
4093    yield ErrorInput(
4094        SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))),
4095        error_regex="should be the same")
4096
4097    # error inputs for negative strides
4098    yield ErrorInput(
4099        SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))),
4100                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
4101
4102    # error inputs for negative padding
4103    yield ErrorInput(
4104        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))),
4105                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
4106
4107    # error inputs for negative dilation
4108    yield ErrorInput(
4109        SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))),
4110                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
4111
4112    # FIXME: https://github.com/pytorch/pytorch/issues/85656
4113    # error inputs for bias shape not equal to the output channels
4114    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))),
4115    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
4116
4117    # error inputs for input.ndim != weight.ndim
4118    yield ErrorInput(
4119        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))),
4120                    kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight")
4121
4122    # error inputs when weight.shape[0] is less than the number of groups
4123    yield ErrorInput(
4124        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
4125                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
4126
4127    # error inputs when weight.shape[0] is less than the number of groups (with padding='same')
4128    yield ErrorInput(
4129        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
4130                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")
4131
4132    # error inputs for invalid groups
4133    yield ErrorInput(
4134        SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))),
4135                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")
4136
4137    # error inputs for invalid groups
4138    yield ErrorInput(
4139        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))),
4140                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")
4141
4142
4143def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs):
4144    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4145
4146    # Ordered as shapes for input, weight, bias
4147    # and a dict of values of (stride, padding, groups, dilation)
4148    cases: Tuple = (
4149        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
4150            {'stride': (2, 2), 'padding': 2, 'groups': 1}),
4151        ((2, 4, 8, 8), (2, 2, 3, 3), (2,),
4152            {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}),
4153        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
4154            {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
4155        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
4156            {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
4157        ((1, 2, 4, 3), (4, 2, 3, 4), None,
4158            {'stride': 2, 'padding': 1, 'groups': 1}),
4159        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
4160            {'stride': 2, 'padding': "valid"}),
4161        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
4162            {'stride': 1, 'padding': "same", 'dilation': 3}),
4163        # Below are the group related samples from common_nn.py
4164        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}),
4165        ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}),
4166        ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}),
4167        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}),
4168        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}),
4169        ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}),
4170        ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}),
4171        # With defaults
4172        ((1, 4, 5, 5), (3, 4, 3, 3), None, {}),
4173    )
4174
4175    for input_shape, weight, bias, kwargs in cases:
4176        # Batched
4177        yield SampleInput(make_arg(input_shape), args=(
4178            make_arg(weight),
4179            make_arg(bias) if bias is not None else bias
4180        ), kwargs=kwargs)
4181        # Unbatched
4182        yield SampleInput(make_arg(input_shape[1:]), args=(
4183            make_arg(weight),
4184            make_arg(bias) if bias is not None else bias
4185        ), kwargs=kwargs)
4186
4187
4188def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs):
4189    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4190
4191    # Ordered as shapes for input, weight, bias
4192    # and dict of values of (stride, padding, dilation, groups)
4193    cases: Tuple = (
4194        ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}),
4195        ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}),
4196        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}),
4197        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
4198        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}),
4199        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}),
4200        ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}),
4201        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
4202        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}),
4203        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}),
4204    )
4205
4206    for input_shape, weight, bias, kwargs in cases:
4207        # Batched
4208        yield SampleInput(make_arg(input_shape), args=(
4209            make_arg(weight),
4210            make_arg(bias) if bias is not None else bias
4211        ), kwargs=kwargs)
4212        # Unbatched
4213        yield SampleInput(make_arg(input_shape[1:]), args=(
4214            make_arg(weight),
4215            make_arg(bias) if bias is not None else bias
4216        ), kwargs=kwargs)
4217
4218
4219def error_inputs_conv3d(opinfo, device, **kwargs):
4220    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
4221    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
4222    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)
4223
4224    # error inputs for different dtypes of input tensor and bias
4225    yield ErrorInput(
4226        SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))),
4227        error_regex="should be the same")
4228
4229    # error inputs for different dtypes of input tensor and bias
4230    yield ErrorInput(
4231        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))),
4232        error_regex="should be the same")
4233
4234    # error inputs for negative strides
4235    yield ErrorInput(
4236        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
4237                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")
4238
4239    # error inputs for negative padding
4240    yield ErrorInput(
4241        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
4242                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")
4243
4244    # error inputs for negative dilation
4245    yield ErrorInput(
4246        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
4247                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")
4248
4249    # FIXME: https://github.com/pytorch/pytorch/issues/85656
4250    # error inputs for bias shape not equal to the output channels
4251    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))),
4252    #                  error_regex="expected bias to be 1-dimensional with 1 elements")
4253
4254    # error inputs for input.ndim != weight.ndim
4255    yield ErrorInput(
4256        SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))),
4257                    kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight")
4258
4259    # error inputs when weight.shape[0] is less than the number of groups
4260    yield ErrorInput(
4261        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
4262                    make_arg((2,))), kwargs={'groups': 3}),
4263        error_regex="expected weight to be at least 3 at dimension 0")
4264
4265    # error inputs when weight.shape[0] is less than the number of groups (with padding='same')
4266    yield ErrorInput(
4267        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
4268                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}),
4269        error_regex="expected weight to be at least 3 at dimension 0")
4270
4271    # error inputs for invalid groups
4272    yield ErrorInput(
4273        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
4274                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}),
4275        error_regex="non-positive groups is not supported")
4276
4277    # error inputs for padding='same' not supported by strided convolutions
4278    yield ErrorInput(
4279        SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)),
4280                    make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}),
4281        error_regex="padding='same' is not supported for strided convolutions")
4282
4283
4284def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs):
4285    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4286
4287    # Ordered as input shape, num groups, and kwargs for eps
4288    cases: Tuple[Tuple[int], int, float] = (  # type: ignore[assignment]
4289        ((1, 6, 3), 2, {'eps' : 0.5}),
4290        ((2, 6, 3), 2, {'eps' : -0.5}),
4291        ((1, 3), 1, {'eps' : 1e-5}),
4292        ((0, 2), 1, {'eps' : 1e-5}),
4293        ((S, S, S), 1, {'eps' : 0.5}),
4294    )
4295
4296    # num_channels is inferred from input.shape[1]
4297    for input_shape, num_groups, kwargs in cases:
4298        # weight and bias should have shape (num_channels,)
4299        channels = input_shape[1] if len(input_shape) > 1 else 0
4300        weight_tensor = make_arg(channels)
4301        bias_tensor = make_arg(channels)
4302
4303        # Checking for permutations of weights and biases as `None`
4304        weights = [weight_tensor, None]
4305        biases = [bias_tensor, None]
4306        for weight, bias in itertools.product(weights, biases):
4307            sample_kwargs = {
4308                'weight': weight,
4309                'bias': bias,
4310                **kwargs  # build a fresh dict; rebinding `kwargs` leaked prior weight/bias
4311            }
4312            yield SampleInput(make_arg(input_shape), num_groups, **sample_kwargs)
4313
4314    # Without any optional args
4315    yield SampleInput(make_arg((1, 2)), args=(1,))
4316
4317def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs):
4318    yield from sample_inputs_group_norm(
4319        op_info, device, dtype, requires_grad, **kwargs)
4320
4321    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4322
4323    # Ordered as input shape, num groups, and kwargs for eps
4324    cases: Tuple[Tuple[int], int, float] = (  # type: ignore[assignment]
4325        ((20, 6, 10, 10), 3, {'eps' : 1e-5}),
4326        # equivalent to InstanceNorm (see the numeric check after this function):
4327        # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C)
4328        ((20, 6, 10, 10), 6, {'eps' : 1e-5}),
4329        # equivalent to LayerNorm:
4330        # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False)
4331        ((20, 6, 10, 10), 1, {'eps' : 1e-5}),
4332    )
4333
4334    # num_channels is inferred from input.shape[1]
4335    for input_shape, num_groups, kwargs in cases:
4336        # weight and bias should have shape (num_channels,)
4337        channels = input_shape[1] if len(input_shape) > 1 else 0
4338        input_tensor = make_arg(input_shape)
4339        weight_tensor = make_arg(channels)
4340        bias_tensor = make_arg(channels)
4341
4342        # Checking for permutations of weights and biases as `None`
4343        weights = [weight_tensor, None]
4344        biases = [bias_tensor, None]
4345        for weight, bias in itertools.product(weights, biases):
4346            sample_kwargs = {
4347                'weight': weight,
4348                'bias': bias,
4349                **kwargs  # build a fresh dict; rebinding `kwargs` leaked prior weight/bias
4350            }
4351            yield SampleInput(input_tensor, num_groups, **sample_kwargs)
4352
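# Illustrative numeric check (added for clarity; not used by the generators
# above) of the equivalences noted in the comments: without affine parameters,
# group_norm with one group per channel matches instance_norm, and group_norm
# with a single group matches layer_norm over (C, H, W).
def _demo_group_norm_equivalences():
    x = torch.randn(4, 6, 5, 5)
    torch.testing.assert_close(
        torch.nn.functional.group_norm(x, num_groups=6),
        torch.nn.functional.instance_norm(x),
        rtol=1e-4, atol=1e-5)
    torch.testing.assert_close(
        torch.nn.functional.group_norm(x, num_groups=1),
        torch.nn.functional.layer_norm(x, x.shape[1:]),
        rtol=1e-4, atol=1e-5)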
4353
4354def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
4355    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4356    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
4357
4358    # Ordered as: input shape, kwargs for momentum, eps
4359    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
4360        ((S, S, S), {'momentum': 0.5, 'eps': 0.6}),
4361        ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}),
4362        ((3, 2, 4), {'momentum': -1.2}),
4363        ((3, 2, 4), {'momentum': 0.0}),
4364        ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
4365        ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}),
4366    )
4367
4368    for input_shape, kwargs in cases:
4369        # running_mean, running_var, weight, and bias must all have shape (channels,)
4370        channels = input_shape[1]
4371        weight = make_arg(channels)
4372        bias = make_arg(channels)
4373        running_mean = make_arg_without_requires_grad(channels, low=0)
4374        running_var = make_arg_without_requires_grad(channels, low=0)
4375        new_kwargs = {
4376            'running_mean': running_mean,
4377            'running_var': running_var,
4378            'weight': weight,
4379            'bias': bias,
4380            **kwargs
4381        }
4382
4383        yield SampleInput(
4384            make_arg(input_shape),
4385            args=(),
4386            kwargs=new_kwargs
4387        )
4388
4389    # Checking for permutations of weights and biases as `None`
4390    # instance_norm assumes that if there's a bias, there's a weight
4391    weights = [channels, None]
4392    biases = [None, None]
4393
4394    for weight_channels, bias_channels in zip(weights, biases):
4395        running_mean = make_arg_without_requires_grad(channels, low=0)
4396        running_var = make_arg_without_requires_grad(channels, low=0)
4397        yield SampleInput(
4398            make_arg(input_shape),
4399            args=(),
4400            kwargs={
4401                'running_mean': running_mean,
4402                'running_var': running_var,
4403                'weight': make_arg(weight_channels) if weight_channels is not None else None,
4404                'bias': make_arg(bias_channels) if bias_channels is not None else None
4405            }
4406        )
4407
4408    # Test case for no optional kwargs
4409    yield SampleInput(make_arg((1, 2, 3)), kwargs={})
4410
4411def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs):
4412    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
4413
4414    def make_bool_mask(*shape):
4415        return torch.randint(0, 2, shape, device=device, dtype=torch.bool)
4416
4417    def mask_two_rows(rows, cols):
4418        mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device)
4419        mask_two_rows[rows - 1] = False
4420        mask_two_rows[rows - 3] = False
4421        return mask_two_rows
4422
4423    def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor:
4424        return torch.where(~mask, float('-inf'), 0.0)
4425
4426    def with_requires_grad(tensor):
4427        return tensor.requires_grad_(requires_grad)
4428
4429    def generate_input_from_mask(mask_shape, dim):
4430        mask = make_bool_mask(*mask_shape)
4431        input_tensor = make_arg(mask_shape)
4432        masked_input = input_tensor + convert_to_float_mask(mask)
4433        return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim})
4434
4435    samples = [
4436        # Basic 3D tensor with mask
4437        generate_input_from_mask((2, 3, 4), dim=1),
4438        # 2D tensor with mask, testing different dim
4439        generate_input_from_mask((5, 5), dim=0),
4440        # 4D tensor, testing with a different dim
4441        generate_input_from_mask((2, 3, 4, 5), dim=2),
4442        # Edge case: 1D tensor
4443        generate_input_from_mask((10,), dim=0),
4444        # Edge case: tensor with one dimension of size 1
4445        generate_input_from_mask((1, 5, 5), dim=1),
4446        # Testing with all elements masked
4447        SampleInput(
4448            with_requires_grad(
4449                make_arg((3, 3))
4450                + convert_to_float_mask(
4451                    torch.zeros((3, 3), dtype=torch.bool, device=device)
4452                )
4453            ),
4454            kwargs={"dim": 1},
4455        ),
4456        # Testing with no elements masked
4457        SampleInput(
4458            with_requires_grad(
4459                make_arg((3, 3))
4460                + convert_to_float_mask(
4461                    torch.ones((3, 3), dtype=torch.bool, device=device)
4462                )
4463            ),
4464            kwargs={"dim": 1},
4465        ),
4466        # Testing with two rows masked
4467        SampleInput(
4468            with_requires_grad(
4469                make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3))
4470            ),
4471            kwargs={"dim": 1},
4472        ),
4473    ]
4474    yield from samples
4475
4476def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
4477    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4478
4479    # Ordered as input shape, normalized_shape and a kwarg dict for eps
4480    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
4481        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
4482        ((2, 2, 3), (2, 3), {'eps': -0.5}),
4483        ((1,), (1,), {}),
4484        ((1, 2), (2,), {}),
4485        ((0, 1), (1,), {}),
4486    )
4487
4488    for input_shape, normalized_shape, kwargs in cases:
4489        # Shape of weight and bias should be the same as normalized_shape
4490        weight = make_arg(normalized_shape)
4491        bias = make_arg(normalized_shape)
4492        yield SampleInput(
4493            make_arg(input_shape),
4494            args=(normalized_shape, weight, bias),
4495            kwargs=kwargs
4496        )
4497    # Without any optional args
4498    yield SampleInput(make_arg((1, 2)), args=((2,),))
4499
4500    # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs,
4501    # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400
4502
4503    # With weight and a `None` bias
4504    # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None))
4505
4506    # With `None` weight and bias (tests failing for this, see the link above)
4507    # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,))))
4508
4509
4510def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
4511    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4512
4513    # Ordered as input shape, normalized_shape, eps
4514    cases: Tuple[Tuple[int], Tuple[int], float] = (  # type: ignore[assignment]
4515        ((1, 2, 3), (1, 2, 3), 0.5),
4516        ((2, 2, 3), (2, 3), -0.5),
4517        ((1,), (1,), 1e-5),
4518        ((1, 2), (2,), 1e-5),
4519        ((0, 1), (1,), 1e-5),
4520    )
4521
4522    for input_shape, normalized_shape, eps in cases:
4523        # Shape of weight and bias should be the same as normalized_shape
4524        weight = make_arg(normalized_shape)
4525        bias = make_arg(normalized_shape)
4526        yield SampleInput(
4527            make_arg(input_shape),
4528            args=(normalized_shape, weight, bias, eps),
4529        )
4530        yield SampleInput(
4531            make_arg(input_shape),
4532            args=(normalized_shape, None, bias, eps),
4533        )
4534        yield SampleInput(
4535            make_arg(input_shape),
4536            args=(normalized_shape, weight, None, eps),
4537        )
4538        yield SampleInput(
4539            make_arg(input_shape),
4540            args=(normalized_shape, None, None, eps),
4541        )
4542
4543def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
4544    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4545
4546    # Ordered as input shape, normalized_shape and a kwarg dict for eps
4547    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
4548        ((1, 2, 3), (1, 2, 3), {'eps': 0.5}),
4549        ((2, 2, 3), (2, 3), {'eps': -0.5}),
4550        ((1,), (1,), {}),
4551        ((1, 2), (2,), {}),
4552        ((0, 1), (1,), {}),
4553    )
4554
4555    for input_shape, normalized_shape, kwargs in cases:
4556        # Shape of weight and bias should be the same as normalized_shape
4557        weight = make_arg(normalized_shape)
4558        yield SampleInput(
4559            make_arg(input_shape),
4560            args=(normalized_shape, weight),
4561            kwargs=kwargs
4562        )
4563    # Without any optional args
4564    yield SampleInput(make_arg((1, 2)), args=((2,),))
4565
4566def error_inputs_group_norm(opinfo, device, **kwargs):
4567    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
4568
4569    # check that input has minimum number of dimensions
4570    err_msg1 = "Expected at least 2 dimensions for input tensor but received"
4571    s1 = SampleInput(make_arg(1), args=(1,))
4572    yield ErrorInput(s1, error_regex=err_msg1)
4573
4574    # check that the channels dimension is compatible with number of groups
4575    err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape"
4576    s2 = SampleInput(make_arg((2, 7, 4)), args=(2,))
4577    yield ErrorInput(s2, error_regex=err_msg2)
4578
4579def error_inputs_native_layer_norm(opinfo, device, **kwargs):
4580    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
4581    input_shape = (1, 2, 3)
4582
4583    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
4584    s1 = SampleInput(
4585        make_arg(input_shape), args=((), None, None, 1e-5)
4586    )
4587    yield ErrorInput(s1, error_regex=err_msg1)
4588
4589    normalized_shape = (1, 2, 3)
4590    weight = make_arg((1, 2))
4591    err_msg2 = "Expected weight to be of same shape as normalized_shape"
4592    s2 = SampleInput(
4593        make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5)
4594    )
4595    yield ErrorInput(s2, error_regex=err_msg2)
4596
4597    bias = make_arg((1, 2))
4598    err_msg3 = "Expected bias to be of same shape as normalized_shape"
4599    s3 = SampleInput(
4600        make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5)
4601    )
4602    yield ErrorInput(s3, error_regex=err_msg3)
4603
4604    err_msg4 = "Given normalized_shape="
4605    s4 = SampleInput(
4606        make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5)
4607    )
4608    yield ErrorInput(s4, error_regex=err_msg4)
4609
4610def error_inputs_rms_norm(opinfo, device, **kwargs):
4611    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
4612    input_shape = (1, 2, 3)
4613
4614    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
4615    s1 = SampleInput(
4616        make_arg(input_shape), args=((), None, 1e-5)
4617    )
4618    yield ErrorInput(s1, error_regex=err_msg1)
4619
4620    normalized_shape = (1, 2, 3)
4621    weight = make_arg((1, 2))
4622    err_msg2 = "Expected weight to be of same shape as normalized_shape"
4623    s2 = SampleInput(
4624        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
4625    )
4626    yield ErrorInput(s2, error_regex=err_msg2)
4627
4628
4629    err_msg4 = "Given normalized_shape="
4630    s4 = SampleInput(
4631        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
4632    )
4633    yield ErrorInput(s4, error_regex=err_msg4)
4634
4635
4636def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
4637    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4638
4639    # Ordered as input shape, size and a kwarg dict for alpha, beta, and k
4640    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
4641        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
4642        ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}),
4643        ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}),
4644        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}),
4645        ((1, 6, 3), 2, {'alpha': 3e-05}),
4646        ((1, 6, 3), 2, {'beta': 0.5}),
4647        ((1, 6, 3), 2, {'k': 1.25}),
4648        ((1, 6, 3), 2, {}),
4649        ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
4650        ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
4651        ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
4652    )
4653
4654    for input_shape, size, kwargs in cases:
4655        yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)
4656
4657def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
4658    N = 5
4659    # Make sure we are testing the -3..3 range; the default is -10..10, so this may be unnecessary.
4660    make_arg = partial(make_tensor, device=device, dtype=dtype,
4661                       requires_grad=requires_grad, low=-5, high=5)
4662    return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N))
4663
4664def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
4665    features_options = [[3, 4], [8, 8]]
4666    batch_options: List[List[int]] = [
4667        [],  # no batch
4668        [0],
4669        [8],
4670        [2, 3],
4671    ]
4672    create_tensor = partial(make_tensor, device=device, dtype=dtype,
4673                            requires_grad=requires_grad, low=-2, high=2)
4674
4675    for has_bias, (in_feat, out_feat), batch_shape in \
4676            itertools.product([True, False], features_options, batch_options):
4677        input_tensor = create_tensor(batch_shape + [in_feat])
4678        weight = create_tensor([out_feat, in_feat])
4679        if not has_bias:
4680            yield SampleInput(input_tensor, weight)
4681            continue
4682
4683        bias = create_tensor([out_feat])
4684        yield SampleInput(input_tensor, weight, bias)
4685
4686    # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942
4687    yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2))
4688    yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4))
4689
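# Illustrative sketch (shapes are arbitrary assumptions): F.linear treats all
# leading dimensions of the input as batch dimensions, which is what the
# batch_options above (including the 5D case) exercise.
def _demo_linear_shapes():
    x = torch.randn(2, 3, 5)
    w = torch.randn(4, 5)  # (out_features, in_features)
    b = torch.randn(4)
    assert torch.nn.functional.linear(x, w, b).shape == (2, 3, 4)
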
4690def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
4691    features_options = [[3, 4, 5], [8, 8, 8]]
4692    batch_options: List[List[int]] = [
4693        [],  # no batch
4694        [0],
4695        [8],
4696        [2, 3],
4697    ]
4698    create_tensor = partial(make_tensor, device=device, dtype=dtype,
4699                            requires_grad=requires_grad, low=-2, high=2)
4700
4701    for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \
4702            itertools.product([True, False], features_options, batch_options):
4703        input_tensor1 = create_tensor(batch_shape + [in_feat1])
4704        input_tensor2 = create_tensor(batch_shape + [in_feat2])
4705        weight = create_tensor([out_feat, in_feat1, in_feat2])
4706        if not has_bias:
4707            yield SampleInput(input_tensor1, input_tensor2, weight)
4708            continue
4709        bias = create_tensor([out_feat])
4710        yield SampleInput(input_tensor1, input_tensor2, weight, bias)
4711
4712def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs):
4713    features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]]
4714    batch_options: List[List[int]] = [
4715        [],  # no batch
4716        [0],
4717        [8],
4718        [2, 3],
4719    ]
4720    create_tensor = partial(make_tensor, device=device, dtype=dtype,
4721                            requires_grad=requires_grad, low=-2, high=2)
4722
4723    for features, batch_shape in itertools.product(features_options, batch_options):
4724        ndim = len(features) + len(batch_shape)
4725        for dim in range(ndim):
4726            input_tensor = create_tensor(batch_shape + features)
4727            dim_size = input_tensor.size(dim)
4728            if dim_size > 0 and dim_size % 2 == 0:
4729                yield SampleInput(input_tensor, dim)
4730
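# Illustrative sketch (an assumption added for clarity): F.glu splits the chosen
# dimension in half, which is why only even-sized dims are yielded above.
def _demo_glu_halves_dim():
    t = torch.randn(2, 4, 3)
    assert torch.nn.functional.glu(t, dim=1).shape == (2, 2, 3)
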
4731def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs):
4732    N, C = 2, 3
4733    D = 4
4734    S = 3
4735    L = 5
4736
4737    align_corners_options: Tuple[Any, ...] = (None,)
4738    if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'):
4739        align_corners_options = (True, False, None)
4740    ranks_for_mode = {
4741        'nearest': [1, 2, 3],
4742        'nearest-exact': [1, 2, 3],
4743        'linear': [1],
4744        'bilinear': [2],
4745        'bicubic': [2],
4746        'trilinear': [3],
4747        'area': [1, 2, 3]
4748    }
4749
4750    def shape(size, rank, with_batch_channel=True):
4751        if with_batch_channel:
4752            return tuple([N, C] + ([size] * rank))
4753        return tuple([size] * rank)
4754
4755    if mode in ('bilinear', 'bicubic') and dtype == torch.uint8:
4756        make_arg = partial(
4757            make_tensor,
4758            device=device,
4759            dtype=dtype,
4760            requires_grad=requires_grad,
4761            # pick a more realistic upper bound of 256 (instead of the default 10) for the uint8 dtype
4762            high=256 if dtype == torch.uint8 else None,
4763        )
4764        # provide a few samples closer to typical image-processing usage
4765        rank = 2
4766        for memory_format in [torch.contiguous_format, torch.channels_last]:
4767            yield SampleInput(
4768                make_arg(shape(270, rank), memory_format=memory_format),
4769                shape(130, rank, False),
4770                scale_factor=None,
4771                mode=mode,
4772                align_corners=False,
4773            )
4774
4775    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4776
4777    for align_corners in align_corners_options:
4778        for rank in ranks_for_mode[mode]:
4779            yield SampleInput(
4780                make_arg(shape(D, rank)),
4781                shape(S, rank, False),
4782                scale_factor=None,
4783                mode=mode,
4784                align_corners=align_corners,
4785            )
4786            yield SampleInput(
4787                make_arg(shape(D, rank)),
4788                shape(L, rank, False),
4789                scale_factor=None,
4790                mode=mode,
4791                align_corners=align_corners,
4792            )
4793            for recompute_scale_factor in [False, True]:
4794                for scale_factor in [1.7, 0.6]:
4795                    yield SampleInput(
4796                        make_arg(shape(D, rank)),
4797                        size=None,
4798                        scale_factor=scale_factor,
4799                        mode=mode,
4800                        align_corners=align_corners,
4801                        recompute_scale_factor=recompute_scale_factor,
4802                    )
4803
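# Illustrative sketch (shapes are arbitrary assumptions): the samples above call
# F.interpolate either with an explicit output size or with a scale_factor,
# never with both at once.
def _demo_interpolate_size_vs_scale():
    x = torch.randn(2, 3, 4, 4)
    by_size = torch.nn.functional.interpolate(x, size=(3, 3), mode='bilinear', align_corners=False)
    by_scale = torch.nn.functional.interpolate(x, scale_factor=0.75, mode='bilinear', align_corners=False)
    assert by_size.shape == by_scale.shape == (2, 3, 3, 3)
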
4804def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs):
4805    yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs)
4806
4807    if mode in ('bilinear', 'bicubic'):
4808        make_arg = partial(
4809            make_tensor,
4810            device=device,
4811            dtype=dtype,
4812            requires_grad=requires_grad,
4813            # pick a more realistic upper bound of 256 (instead of the default 10) for the uint8 dtype
4814            high=256 if dtype == torch.uint8 else None,
4815        )
4816        # provide a few samples for more typical image-processing usage
4817        for memory_format in [torch.contiguous_format, torch.channels_last]:
4818            for aa in [True, False]:
4819                yield SampleInput(
4820                    make_arg((2, 3, 345, 456), memory_format=memory_format),
4821                    (270, 270),
4822                    scale_factor=None,
4823                    mode=mode,
4824                    align_corners=False,
4825                    antialias=aa,
4826                )
4827
4828def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
4829    N, C = 2, 3
4830    D = 4
4831    S = 3
4832    L = 5
4833
4834    ranks_for_mode = {
4835        'nearest': [1, 2, 3],
4836        'bilinear': [2],
4837    }
4838
4839    def shape(size, rank, with_batch_channel=True):
4840        if with_batch_channel:
4841            return torch.Size([N, C] + ([size] * rank))
4842        return torch.Size([size] * rank)
4843
4844    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4845
4846    for rank in ranks_for_mode[mode]:
4847        yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False))
4848        yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False))
4849        yield SampleInput(make_arg(shape(D, rank)), scale_factor=1.7)
4850        yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6)
4851
4852def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs):
4853    yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs)
4854
4855    if mode in ('bilinear', ):
4856        make_arg = partial(
4857            make_tensor,
4858            device=device,
4859            dtype=dtype,
4860            requires_grad=requires_grad,
4861            # pick a more realistic upper bound of 256 (instead of the default 10) for the uint8 dtype
4862            high=256 if dtype == torch.uint8 else None,
4863        )
4864        # provide a single sample for more typical image processing usage
4865        for memory_format in [torch.contiguous_format, torch.channels_last]:
4866            yield SampleInput(
4867                make_arg((2, 3, 345, 456), memory_format=memory_format),
4868                (270, 270),
4869            )
4870
4871def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs):
4872    N = 6
4873    C = 3
4874    H = 10
4875    W = 20
4876    S = 3
4877    L = 5
4878
4879    input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad)
4880
4881    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None)
4882    yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None)
4883    yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9])
4884    yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0])
4885
4886    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None)
4887    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9)
4888    yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9)
4889
4890def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs):
4891    N = 5
4892    for _ in range(1, N):
4893        for approximate in ['none', 'tanh']:
4894            yield SampleInput(
4895                make_tensor((N * 2, N * 2), device=device, dtype=dtype,
4896                            requires_grad=requires_grad, low=-3, high=3),
4897                approximate=approximate)
4898
4899
4900def error_inputs_gelu(op, device, **kwargs):
4901    # Tests that gelu errors out when passed an approximation we don't know.
4902    yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}),
4903                     error_regex="approximate argument must be either")
4904
4905
4906def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs):
4907    inputs = []
4908    args_for_reduction_with_dim = (
4909        ((S, S, S), (1,),),
4910        ((S, S, S), (1, True, ),),
4911        ((), (0,),),
4912        ((), (0, True,),),
4913    )
4914    return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device,
4915                                     low=None, high=None,
4916                                     requires_grad=requires_grad),
4917                         *args))
4918            for input_tensor, args in args_for_reduction_with_dim)
4919
4920def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs):
4921    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
4922    yield SampleInput(make_arg((S, S, S)))
4923    yield SampleInput(make_arg(()))
4924
4925def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs):
4926    yield from _generate_reduction_inputs(device, dtype, requires_grad)
4927    # NaN only exists for floating point and complex dtypes
4928    if dtype.is_complex or dtype.is_floating_point:
4929        yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad)
4930        yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad)
4931
4932def sample_inputs_nan_reduction(supports_multiple_dims):
4933    # Generates sample inputs for reduction ops that contain the input tensor
4934    # and dim and keepdim kwargs. If a reduction op needs to test additional
4935    # args/kwargs then create a separate sample_inputs function
4936    def fn(op_info, device, dtype, requires_grad, **kwargs):
4937        for t in _generate_nan_reduction_inputs(device, dtype, requires_grad):
4938            # Add case without dim and keepdim kwargs
4939            yield SampleInput(t.clone().requires_grad_(requires_grad))
4940            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
4941                yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs)
4942
4943    return fn
4944
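# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# materialize a few samples from the closure returned above and feed them to a
# NaN-aware reduction such as torch.nansum.
def _example_nan_reduction_samples():
    sample_fn = sample_inputs_nan_reduction(supports_multiple_dims=True)
    for sample in islice(sample_fn(None, 'cpu', torch.float32, False), 3):
        torch.nansum(sample.input, **sample.kwargs)
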
4945def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs):
4946    test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad))
4947    test_interpolations = ['linear', 'midpoint']
4948
4949    for quantiles in test_quantiles:
4950        for t in _generate_reduction_inputs(device, dtype, requires_grad):
4951            # Add case without dim and keepdim kwargs
4952            input = t.clone().requires_grad_(requires_grad)
4953            yield SampleInput(input, quantiles)
4954            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False):
4955                # The interpolation kwarg is currently only supported when both dim and keepdim are provided
4956                kwargs.setdefault('dim', 0)
4957                kwargs.setdefault('keepdim', False)
4958                for interpolation in test_interpolations:
4959                    kwargs['interpolation'] = interpolation
4960                    input = t.clone().requires_grad_(requires_grad)
4961                    yield SampleInput(input, quantiles, **kwargs)
4962
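# Minimal illustrative sketch (helper name is hypothetical) of the call pattern these
# samples exercise; interpolation is only passed together with dim and keepdim,
# mirroring the kwargs constructed above.
def _example_quantile_call():
    t = torch.randn(3, 5)
    q = torch.tensor([0.25, 0.75])
    torch.quantile(t, 0.5)
    torch.quantile(t, q, dim=0, keepdim=True, interpolation='midpoint')
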
4963def sample_inputs_reduction_count_nonzero(*args, **kwargs):
4964    """Sample inputs for count_nonzero"""
4965    # count_nonzero does not support keepdim yet
4966    for sample in sample_inputs_reduction(*args, **kwargs):
4967        sample.kwargs.pop('keepdim', None)
4968        yield sample
4969
4970def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs):
4971    N = 10
4972    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4973    return (SampleInput(make_arg((N, N))) for _ in range(1, N))
4974
4975def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
4976    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
4977
4978    # Order: input_shape, kernel_size
4979    cases = (((1, 3, 9, 9), 3),
4980             ((1, 3, 9, 9), (4, 4)),
4981             ((1, 3, 9, 9), (6, 6)),
4982             ((2, 3, 9, 9), (3, 3)),
4983             ((1, 1, 4, 4), (2, 2)),
4984             ((1, 2, 6, 6), (4, 4)))
4985
4986    for input_shape, kernel_size in cases:
4987        for return_indices in [False, True]:
4988            # test case passing a single output size
4989            yield SampleInput(
4990                make_arg(input_shape),
4991                kernel_size,
4992                output_size=2,
4993                return_indices=return_indices,
4994            )
4995
4996            # test case passing a tuple output size
4997            yield SampleInput(
4998                make_arg(input_shape),
4999                kernel_size,
5000                output_size=(2, 3),
5001                return_indices=return_indices,
5002            )
5003
5004            # test case passing an output ratio
5005            yield SampleInput(
5006                make_arg(input_shape),
5007                kernel_size,
5008                output_ratio=(0.5, 0.5),
5009                return_indices=return_indices,
5010            )
5011
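# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# each sample above maps to a functional call of this shape, with exactly one of
# output_size / output_ratio set.
def _example_fractional_max_pool2d():
    x = torch.randn(1, 3, 9, 9)
    out = torch.nn.functional.fractional_max_pool2d(x, kernel_size=3, output_size=2)
    out, indices = torch.nn.functional.fractional_max_pool2d(
        x, kernel_size=(4, 4), output_ratio=(0.5, 0.5), return_indices=True)
    return out, indices
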
5012def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
5013    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5014
5015    # Order: input_shape, kernel_size
5016    cases = (((2, 3, 5, 5, 5), (2, 2, 2)),
5017             ((1, 2, 6, 5, 4), 2),
5018             ((1, 2, 5, 6, 5), (2, 3, 2)),
5019             ((1, 2, 6, 6, 6), (2, 3, 2)),
5020             ((1, 1, 7, 6, 7), (2, 3, 4)),
5021             ((1, 1, 4, 5, 4), (2, 2, 1)),
5022             ((1, 1, 8, 7, 6), (4, 3, 2)),
5023             ((0, 1, 4, 5, 4), (2, 2, 1)))
5024
5025    for input_shape, kernel_size in cases:
5026        for return_indices in [False, True]:
5027            # test case passing a single output size
5028            yield SampleInput(
5029                make_arg(input_shape),
5030                kernel_size,
5031                output_size=2,
5032                return_indices=return_indices,
5033            )
5034
5035            # test case passing a tuple output size
5036            yield SampleInput(
5037                make_arg(input_shape),
5038                kernel_size,
5039                output_size=(2, 3, 2),
5040                return_indices=return_indices,
5041            )
5042
5043            # test case passing an output ratio
5044            yield SampleInput(
5045                make_arg(input_shape),
5046                kernel_size,
5047                output_ratio=(0.5, 0.5, 0.5),
5048                return_indices=return_indices,
5049            )
5050
5051def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs):
5052    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5053
5054    # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
5055    cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2),
5056             ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2),
5057             ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2),
5058             ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2),
5059             ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2),
5060             ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None))
5061
5062    for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases:
5063        yield SampleInput(make_arg(input_shape),
5064                          args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override))
5065    # Case with just input_shape and kernel_size
5066    yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3),))
5067
5068def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs):
5069    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5070
5071    # Order: input_shape, kernel_size, kwargs
5072    cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [
5073        ((2, 3, 9), (3,), {}),
5074        ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)),
5075        ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)),
5076        ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)),
5077        ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)),
5078        ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)),
5079        ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)),
5080        ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)),
5081        ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)),
5082    ]
5083
5084    for input_shape, kernel_size, kwargs in cases:
5085        yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
5086
5087def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs):
5088    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5089
5090    # Order: input_shape, kernel_size, kwargs
5091    cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [
5092        ((2, 3, 3, 4, 4), (2, 2, 2), {}),
5093        ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True,
5094                                  count_include_pad=False, divisor_override=2)),
5095        ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True,
5096                                          count_include_pad=True, divisor_override=2)),
5097        ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)),
5098        ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False,
5099                                          count_include_pad=False, divisor_override=2)),
5100        ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False,
5101                                          count_include_pad=True, divisor_override=-2)),
5102        ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True,
5103                                          count_include_pad=True, divisor_override=None)),
5104        ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False,
5105                                          count_include_pad=True, divisor_override=None)),
5106    ]
5107
5108    for input_shape, kernel_size, kwargs in cases:
5109        yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs)
5110
5111def error_inputs_avg_pool1d(op_info, device, **kwargs):
5112    # error inputs when pad is negative
5113    x = torch.rand([0, 1, 49], dtype=torch.float32)
5114    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
5115                     error_regex='pad must be non-negative')
5116
5117    # error inputs when pad > kernel_size / 2
5118    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
5119                     error_regex='pad should be at most half of effective kernel size')
5120
5121def error_inputs_avg_pool2d(op_info, device, **kwargs):
5122    # error inputs when pad is negative
5123    x = torch.rand([0, 1, 49], dtype=torch.float32)
5124    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
5125                     error_regex='pad must be non-negative')
5126    # 2-dimensional kernel
5127    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}),
5128                     error_regex='pad must be non-negative')
5129
5130    # error inputs when pad > kernel_size / 2
5131    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
5132                     error_regex='pad should be at most half of effective kernel size')
5133    # 2-dimensional kernel
5134    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}),
5135                     error_regex='pad should be at most half of effective kernel size')
5136
5137    # error inputs for zero divisor
5138    x = torch.zeros(3, 3, 3)
5139    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}),
5140                     error_regex='divisor must be not zero')
5141
5142def error_inputs_avg_pool3d(op_info, device, **kwargs):
5143    # error inputs when pad is negative
5144    x = torch.rand([0, 1, 49, 50], dtype=torch.float32)
5145    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}),
5146                     error_regex='pad must be non-negative')
5147    # 3-dimensional kernel
5148    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}),
5149                     error_regex='pad must be non-negative')
5150
5151    # error inputs when pad > kernel_size / 2
5152    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}),
5153                     error_regex='pad should be at most half of effective kernel size')
5154    # 3-dimensional kernel
5155    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}),
5156                     error_regex='pad should be at most half of effective kernel size')
5157
5158    # error inputs for zero divisor
5159    x = torch.zeros(3, 3, 3, 3)
5160    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}),
5161                     error_regex='divisor must be not zero')
5162
5163    # error inputs for invalid input dimension
5164    x = torch.rand([0, 1, 49], dtype=torch.float32)
5165    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}),
5166                     error_regex='non-empty 4D or 5D')
5167
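# Illustrative sketch only (helper name is hypothetical): the padding checks the
# ErrorInputs above target surface in eager mode roughly like this.
def _example_avg_pool_padding_errors():
    x = torch.rand(1, 1, 8, 8)
    try:
        torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=1, padding=-1)
    except RuntimeError as e:
        assert "pad must be non-negative" in str(e)
    try:
        torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=1, padding=4)
    except RuntimeError as e:
        assert "pad should be at most half" in str(e)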
5168
5169def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs):
5170    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5171    # test_multiple_devices_to_cuda would fail if we use a different device than given
5172    devices = [device]
5173    if torch.device(device).type == 'cpu':
5174        devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices
5175    memory_formats = [torch.preserve_format, torch.channels_last]
5176
5177    # TODO: can't switch `to.device` overload to use positional arguments
5178    # https://github.com/pytorch/pytorch/issues/84265
5179    # to.device overload
5180    for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
5181        kwargs = {
5182            "memory_format": mem_f,
5183        }
5184        yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs)
5185
5186    # to.dtype overload
5187    for nb, cp, mem_f in product([True, False], [True, False], memory_formats):
5188        kwargs = {
5189            "memory_format": mem_f,
5190        }
5191        yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs)
5192
5193    # to.other overload
5194    for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats):
5195        kwargs = {
5196            "memory_format": mem_f,
5197        }
5198        other = make_arg((S, S, S, S), dtype=torch.float64, device=device)
5199        yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs)
5200
5201
5202def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs):
5203    def get_tensor_input(size):
5204        return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad)
5205
5206    yield SampleInput(get_tensor_input((S, M, S)), 3)
5207    yield SampleInput(get_tensor_input((S, M, S)), 3, 1)
5208    yield SampleInput(get_tensor_input((S, M, S)), 3, -2)
5209    yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True)
5210    yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True)
5211    yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True)
5212    yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True)
5213
5214    yield SampleInput(get_tensor_input(()), 1)
5215    yield SampleInput(get_tensor_input(()), 1, 0)
5216    yield SampleInput(get_tensor_input(()), 1, -1)
5217    yield SampleInput(get_tensor_input(()), 1, 0, True)
5218    yield SampleInput(get_tensor_input(()), 1, -1, True)
5219    yield SampleInput(get_tensor_input(()), 1, 0, True, True)
5220    yield SampleInput(get_tensor_input(()), 1, -1, True, True)
5221
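# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# the positional args above line up with torch.topk(input, k, dim=-1, largest=True, sorted=True).
def _example_topk_consumption():
    x = torch.randn(5, 10, 5)
    values, indices = torch.topk(x, 3, 1, True, True)
    return values, indices
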
5222def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs):
5223    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5224    yield SampleInput(make_arg(S), make_arg(M))
5225
5226def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs):
5227    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5228    sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S))
5229    ps = (2, 4)
5230
5231    for size_x, size_y, p in product(sizes, sizes, ps):
5232        yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p))
5233
5234# TODO: add tests for the nondeterminism of the operation
5235# https://github.com/pytorch/pytorch/issues/53352
5236def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs):
5237    # target.index_select(dim, idx)
5238    select = "index_select" in op_info.name
5239    # target.index_add(dim, idx, source, *, alpha=1)
5240    add = "index_add" in op_info.name
5241    # target.index_copy(dim, idx, source)
5242    copy = "index_copy" in op_info.name
5243    # target.index_fill(dim, idx, value)
5244    fill = "index_fill" in op_info.name
5245
5246    # Extended reference inputs. We generate inputs that exercise atomic adds / writing
5247    # several times to the same location
5248    if reference:
5249        make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad)
5250        make_idx = partial(torch.zeros, device=device, dtype=torch.int64)
5251    else:
5252        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5253        # The indices need to be distinct for copy and add to be deterministic
5254        if copy or add:
5255            make_idx = partial(torch.randperm, device=device, dtype=torch.int64)
5256        else:
5257            def make_idx(n):
5258                return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n)
5259
5260    shapes = [(), (1,), (S, S)]
5261    # extra parameter for add
5262    if add:
5263        if dtype == torch.bool:
5264            alphas = (True, False)
5265        else:
5266            alphas = (-1, 0, 2)
5267    else:
5268        alphas = (None,)
5269
5270    if fill:
5271        # A deliberately unusual value, to help catch errors.
5272        # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`.
5273        values = (make_arg((1,)).item(), make_arg(()))
5274    else:
5275        values = (None,)
5276
5277    for shape, alpha, value in product(shapes, alphas, values):
5278        t = make_arg(shape)
5279        args = []
5280
5281        # dim: use the last dim for matrices and dim 0 otherwise (this also handles the scalar case)
5282        dim = -1 if t.ndim == 2 else 0
5283        args.append(dim)
5284
5285        idx = make_idx(t.shape[dim] if t.ndim != 0 else 1)
5286        args.append(idx)
5287
5288        # source
5289        if copy or add:
5290            args.append(make_arg(shape))
5291        elif fill:
5292            args.append(value)
5293
5294        args = tuple(args)
5295        kwargs = {} if alpha is None else {"alpha": alpha}
5296
5297        yield SampleInput(t, args=args, kwargs=kwargs)
5298
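# Minimal illustrative sketch (helper name is hypothetical): the (dim, idx, source/value)
# tuples built above are consumed by the index_* variants roughly as follows.
def _example_index_variants():
    t = torch.zeros(5, 5)
    idx = torch.randperm(5)
    src = torch.ones(5, 5)
    t.index_select(-1, idx)
    t.index_copy(-1, idx, src)
    t.index_add(-1, idx, src, alpha=2)
    t.index_fill(-1, idx, 3.14)
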
5299def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs):
5300    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5301
5302    def make_idx(n, m):
5303        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m)
5304
5305    shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))]
5306    include_selfs = (True, False)
5307    reduce = op_info.variant_test_name
5308    assert reduce in ('prod', 'mean', 'amin', 'amax')
5309
5310    for shape, include_self in product(shapes, include_selfs):
5311        self_shape, src_shape = shape
5312        # dim: use dim 1 when the input has at least 2 dims, else dim 0 (this also handles the scalar case)
5313        dim = 1 if len(self_shape) >= 2 else 0
5314        idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1,
5315                       self_shape[dim] if len(self_shape) != 0 else 1)
5316        args = (dim, idx, make_arg(src_shape), reduce)
5317        yield SampleInput(make_arg(self_shape),
5318                          args=args,
5319                          kwargs={'include_self' : include_self})
5320
5321    # Sample inputs to test edge cases for backward
5322    if requires_grad and reduce == 'prod':
5323        # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
5324        # This sample tests gradients for the following cases
5325        # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0]))
5326        # (b) 2 zeros reduced (1 from src and 1 from self) (self[1, 0], self[1, 1])
5327        # (c) no zeros reduced (self[2, 1], self[2, 2])
5328        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
5329        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
5330        input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
5331        src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad)
5332        idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device)
5333
5334        yield SampleInput(input,
5335                          args=(0, idx, src, reduce),
5336                          kwargs={'include_self': True})
5337
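# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# the args above map onto the Tensor.index_reduce_(dim, index, source, reduce, *, include_self) overload.
def _example_index_reduce():
    self_t = torch.ones(5, 4)
    src = torch.randn(5, 3)
    idx = torch.randint(0, 4, (3,))
    self_t.index_reduce_(1, idx, src, 'amax', include_self=False)
    return self_t
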
5338def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs):
5339    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5340
5341    def make_idx(n, m, dim, d):
5342        view_shape = [1] * dim
5343        view_shape[d] = n
5344        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape)
5345
5346    cases = [
5347        ((S, S), S, M),
5348        ((S, S), M, S),
5349        ((S, S, S), S, M),
5350    ]
5351
5352    fill_value = make_tensor([], dtype=dtype, device="cpu").item()
5353
5354    for c in cases:
5355        self_shape, high, idx_size = c
5356        dim = len(self_shape)
5357        indices = [make_idx(idx_size, high, dim, d) for d in range(dim)]
5358        masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None]
5359        mask = functools.reduce(torch.logical_and, masks)
5360        yield SampleInput(make_arg(self_shape), mask, indices, fill_value)
5361
5362        masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None]
5363        mask = functools.reduce(torch.logical_and, masks)
5364        yield SampleInput(make_arg(self_shape), mask, indices, fill_value)
5365
5366def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs):
5367    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5368
5369    def make_idx(n, m, dim, d):
5370        view_shape = [1] * dim
5371        view_shape[d] = n
5372        return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape)
5373
5374    cases = [
5375        ((S, S), S, (M, M)),
5376        ((S, S), M, (S, S + 1)),
5377        ((S, S, S), S, (M, M - 1, M + 1)),
5378    ]
5379
5380    fill_value = make_tensor([], dtype=dtype, device="cpu").item()
5381
5382    for c in cases:
5383        self_shape, high, idx_sizes = c
5384        dim = len(self_shape)
5385        indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)]
5386        masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None]
5387        mask = functools.reduce(torch.logical_and, masks)
5388        values = make_arg(idx_sizes)
5389        yield SampleInput(make_arg(self_shape), mask, indices, values)
5390
5391        masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None]
5392        mask = functools.reduce(torch.logical_and, masks)
5393        yield SampleInput(make_arg(self_shape), mask, indices, values)
5394
5395
5396def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs):
5397    args = (
5398        ((S, S, S), (),),
5399        ((S, S, S), (1, ),),
5400        ((S, S, S), (1, True, ),),
5401        ((), (),),
5402        ((), (0,),),
5403        ((), (0, True,),),
5404        # Large input to exercise the non-fused mode kernel on CUDA
5405        ((3000,), ()),
5406    )
5407    make_arg = partial(make_tensor, dtype=dtype, device=device,
5408                       requires_grad=requires_grad, low=None, high=None)
5409    return (SampleInput(make_arg(input_tensor), *args)
5410            for input_tensor, args in args)
5411
5412# TODO: add tests for the nondeterminism of the operation
5413# https://github.com/pytorch/pytorch/issues/53352
5414def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs):
5415    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
5416    make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
5417
5418    S = 3
5419
5420    # Generic inputs
5421    idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S]
5422    idx_list = [idx, -idx - 1]
5423    for idx, acc in product(idx_list, (True, False)):
5424        yield SampleInput(input=make_arg((S, S)),
5425                          args=(idx.clone(),
5426                                make_arg((S,)),
5427                                acc))
5428
5429    # Scalar cases
5430    scalar_sizes = [(), (1,)]
5431    tgt_gen = (make_arg(size) for size in scalar_sizes)
5432    idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
5433    src_gen = (make_arg(size) for size in scalar_sizes)
5434    for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)):
5435        yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
5436                          args=(idx.clone(),
5437                                src.clone().requires_grad_(requires_grad),
5438                                acc))
5439
5440    # Empty cases
5441    tgt_sizes = [(0,), (), (1,), (3, 2)]
5442    tgt_gen = (make_arg(size) for size in tgt_sizes)
5443    idx = make_idx((0,), high=1)
5444    src = make_arg((0,))
5445    for tgt, acc in product(tgt_gen, (True, False)):
5446        yield SampleInput(input=tgt.clone().requires_grad_(requires_grad),
5447                          args=(idx.clone(),
5448                                src.clone().requires_grad_(requires_grad),
5449                                acc))
5450
5451def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs):
5452    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
5453    make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False)
5454
5455    S = 3
5456
5457    # Generic inputs: take S elements out of S * S
5458    index = make_idx((S,), high=(S * S))
5459    for idx in (index, -index - 1):
5460        yield SampleInput(input=make_arg((S, S)), args=(idx,))
5461
5462    # Scalar cases
5463    scalar_sizes = [(), (1,)]
5464    src_gen = (make_arg(size) for size in scalar_sizes)
5465    idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
5466    for src, idx in product(src_gen, idx_gen):
5467        yield SampleInput(input=src.clone().requires_grad_(requires_grad),
5468                          args=(idx.clone(),))
5469
5470    # Empty cases
5471    src_sizes = [(0,), (), (1,), (3, 2)]
5472    src_gen = (make_arg(size) for size in src_sizes)
5473
5474    idx = make_idx((0,), high=1)
5475    for src in src_gen:
5476        yield SampleInput(input=src.clone().requires_grad_(requires_grad),
5477                          args=(idx.clone(),))
5478
5479def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
5480    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
5481    yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0])
5482    yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0])
5483
5484def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs):
5485    yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs)
5486
5487    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5488
5489    # shape, source, destination
5490    args = (
5491        # empty inputs
5492        ((), (), ()),
5493        # int inputs, negative
5494        ((3, 5, 7, 2), -2, 1),
5495        # swap bounds
5496        ((3, 5, 7, 2), (-1, 0), (0, -1)),
5497        # non-sequential, negative
5498        ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)),
5499        # idempotence, negative
5500        ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)),
5501        # reverse, sequential, positive
5502        ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)),
5503        # reverse, non-sequential
5504        ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)),
5505        # reverse, sequential, negative
5506        ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)),
5507    )
5508
5509    for shape, source, destination in args:
5510        yield SampleInput(make_arg(shape), args=(source, destination))
5511
5512def error_movedim_moveaxis(op_info, device, **kwargs):
5513    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
5514
5515    # source length < destination length
5516    yield ErrorInput(
5517        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))),
5518        error_regex=(r"movedim: Invalid source or destination dims: source "
5519                     r"\(\[3, -3\] dims\) should contain the same number of "
5520                     r"dims as destination \(\[1, 0, -1\] dims\)"),
5521    )
5522
5523    # source length > destination length
5524    yield ErrorInput(
5525        SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))),
5526        error_regex=(r"movedim: Invalid source or destination dims: source "
5527                     r"\(\[3, -3, 4\] dims\) should contain the same number of "
5528                     r"dims as destination \(\[1, 0\] dims\)"),
5529    )
5530
5531    # repeated source dim, with negative indices
5532    yield ErrorInput(
5533        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))),
5534        error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)",
5535    )
5536
5537    # repeated destination dim, with negative indices
5538    yield ErrorInput(
5539        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))),
5540        error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)",
5541    )
5542
5543    # repeated dim (both), with negative indices
5544    yield ErrorInput(
5545        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))),
5546        error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)",
5547    )
5548
5549    # out of bounds source inputs, with negative indices
5550    yield ErrorInput(
5551        SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))),
5552        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
5553        error_type=IndexError,
5554    )
5555
5556    # out of bounds destination inputs, with negative indices
5557    yield ErrorInput(
5558        SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))),
5559        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
5560        error_type=IndexError,
5561    )
5562
5563    # out of bounds source input, int
5564    yield ErrorInput(
5565        SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)),
5566        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
5567        error_type=IndexError,
5568    )
5569
5570    # out of bounds destination input, int
5571    yield ErrorInput(
5572        SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)),
5573        error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)",
5574        error_type=IndexError,
5575    )
5576
5577def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs):
5578    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
5579    rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),)
5580    shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1))
5581
5582    if requires_grad:
5583        # Tests for variant_consistency_jit, grad, gradgrad
5584        # are slower. Use smaller bags of `rep_dims` and `shapes`
5585        # in this case.
5586        rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1))  # type: ignore[assignment]
5587        shapes = ((), (0,), (2,), (3, 2))  # type: ignore[assignment]
5588
5589    is_repeat_op = op_info.name in ['repeat', '_refs.repeat']
5590    for rep_dim, shape in product(rep_dims, shapes):
5591        # `Tensor.repeat` errors when `len(rep_dims) < t.dim()`,
5592        # so we filter such combinations.
5593        if is_repeat_op and len(rep_dim) < len(shape):
5594            continue
5595        yield SampleInput(make_arg(shape), rep_dim)
5596
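# Minimal illustrative sketch (helper name is hypothetical) of the asymmetry the filter
# above guards against: Tensor.repeat needs at least as many repeat dims as tensor dims,
# while torch.tile left-pads the repeat spec with ones.
def _example_repeat_vs_tile():
    t = torch.randn(3, 2)
    torch.tile(t, (2,))  # treated as (1, 2) -> shape (3, 4)
    try:
        t.repeat(2)      # fewer repeat dims than tensor dims -> error
    except RuntimeError:
        pass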
5597
5598def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
5599    shapes_and_args = (
5600        ((S, S, S), 1, 2, 2),
5601        ((S, S, S), -1, 2, 2),
5602        ((S, S, S), 1, 0, 0),
5603        ((S, S, S), -1, 0, 0),
5604        ((S, S, S), 2, 1, 2),
5605    )
5606
5607    for shape, dim, start, length in shapes_and_args:
5608        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
5609                             requires_grad=requires_grad)
5610        yield SampleInput(tensor, dim, start, length)
5611        # narrow also accepts a 0-dim Tensor for the start argument
5612        if is_narrow:
5613            yield SampleInput(tensor, dim, torch.tensor(start), length)
5614
5615def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs):
5616    yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs)
5617
5618    shapes_and_args = (
5619        # 1-dim
5620        ((M,), 0, 0, 0),    # 0 elems from the left
5621        ((M,), -1, -1, 0),  # 0 elems from the right
5622        ((M,), 0, 5, 3),    # 3 elems from the left
5623        ((M,), 0, -5, 2),   # 2 elems from the right
5624        ((M,), -1, 0, M),   # M elems from the left
5625        ((M,), 0, -M, M),   # M elems from the right
5626
5627        # 2-dim
5628        ((M, S), 1, 0, 0),    # dim 1, 0 elems from the left
5629        ((S, M), -2, -1, 0),  # dim 0, 0 elems from the right
5630        ((L, S), 1, 2, 3),    # dim 1, 3 elems from the left
5631        ((L, S), -1, 3, 2),   # dim 1, 2 elems from the left
5632        ((M, L), 0, 0, M),    # dim 0, M elems from the left
5633        ((M, L), -1, -L, L),  # dim 1, L elems from the right
5634
5635        # 3-dim
5636        ((L, M, S), 2, 0, 0),    # dim 2, 0 elems from the left
5637        ((M, S, L), -1, -1, 0),  # dim 2, 0 elems from the right
5638        ((S, L, M), 2, 0, M),    # dim 2, M elems from the left
5639        ((L, S, M), -1, -M, M),  # dim 2, M elems from the right
5640        ((S, L, M), 1, 0, 0),    # dim 1, 0 elems from the left
5641        ((S, L, M), 0, 2, 1),    # dim 0, 1 elem from the left
5642        ((M, S, M), -1, -5, 4),  # dim 2, 4 elems from the right
5643    )
5644
5645    for shape, dim, start, length in shapes_and_args:
5646        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
5647                             requires_grad=requires_grad)
5648        yield SampleInput(tensor, dim, start, length)
5649        # narrow also accepts a 0-dim Tensor for the start argument
5650        if is_narrow:
5651            yield SampleInput(tensor, dim, torch.tensor(start), length)
5652
5653def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref):
5654    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
5655
5656    # 0-dim
5657    yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1),
5658                     error_type=RuntimeError,
5659                     error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.")
5660
5661    # out of bounds dim
5662    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
5663        # narrow_copy_dense_cpu_out
5664        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
5665                         error_type=RuntimeError,
5666                         error_regex=r"Expected dim < static_cast<int64_t>\(self_sizes.size\(\)\) to be true, but got false\.")
5667    else:
5668        yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0),
5669                         error_type=IndexError,
5670                         error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)")
5671    # out of bounds dim (negative)
5672    yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0),
5673                     error_type=IndexError,
5674                     error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)")
5675
5676    # out of bounds start
5677    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0),
5678                     error_type=IndexError,
5679                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)")
5680    # out of bounds start (negative)
5681    yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0),
5682                     error_type=IndexError,
5683                     error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)")
5684
5685    # out of bounds length
5686    yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1),
5687                     error_type=RuntimeError,
5688                     error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.")
5689    # out of bounds length (negative)
5690    if not is_narrow and not is_ref and torch.device(device).type == 'cpu':
5691        # narrow_copy_dense_cpu_out
5692        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
5693                         error_type=RuntimeError,
5694                         error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.")
5695    else:
5696        yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1),
5697                         error_type=RuntimeError,
5698                         error_regex=r"narrow\(\): length must be non-negative\.")
5699
5700    # Test Tensor overload that was added for XLA. Start must be an 0-dim
5701    # integral Tensor. narrow_copy doesn't have this overload.
5702    # https://github.com/pytorch/pytorch/issues/31558
5703    if is_narrow:
5704        # *1-dim* integral Tensor
5705        yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2),
5706                         error_type=RuntimeError,
5707                         error_regex=r"start must be an 0-dim integral Tensor\.")
5708
5709        # 0-dim *bool* Tensor (bools are not allowed)
5710        yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3),
5711                         error_type=RuntimeError,
5712                         error_regex=r"start must be an 0-dim integral Tensor\.")
5713
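# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# narrow accepts a 0-dim integral Tensor as `start` (the XLA-motivated overload
# referenced above); narrow_copy does not.
def _example_narrow_tensor_start():
    x = torch.arange(12).reshape(3, 4)
    a = torch.narrow(x, 1, 1, 2)
    b = torch.narrow(x, 1, torch.tensor(1), 2)
    assert torch.equal(a, b)
    return a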
5714
5715def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
5716    y_shape_x_shape_and_kwargs = [
5717        ((2, 3), (2, 3), {}),
5718        ((2, 3), (2, 3), {'dim': 1}),
5719        ((6,), (6,), {}),
5720        ((6,), None, {}),
5721        # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad
5722        # See Issue #61619
5723        # ((6,0), (6,0), {}),
5724        ((2, 3), (1, 3), {}),
5725        ((3, 3), (3, 3), {}),
5726        ((3, 3), (3, 3), {'dim': -2}),
5727        ((5,), None, {'dx': 2.0}),
5728        ((2, 2), None, {'dx': 3.0})
5729    ]
5730    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None,
5731                       requires_grad=requires_grad)
5732    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
5733        y_tensor = make_arg(y_shape)
5734        if x_shape is not None:
5735            x_tensor = make_arg(x_shape)
5736            yield SampleInput(y_tensor, x_tensor, **kwarg)
5737        else:
5738            yield SampleInput(y_tensor, **kwarg)
5739
5740def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
5741
5742    y_shape_x_shape_and_kwargs = [
5743        ((2, 3), (2, 3), {}),
5744        ((2, 3), (2, 3), {'dim': 1}),
5745        ((6,), (6,), {}),
5746        ((6,), None, {}),
5747        # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad
5748        # See Issue #61619
5749        # ((6,0), (6,0), {}),
5750        ((2, 3), (1, 3), {}),
5751        ((3, 3), (3, 3), {}),
5752        ((3, 3), (3, 3), {'dim': -2}),
5753        ((5,), None, {'dx': 2.0}),
5754        ((2, 2), None, {'dx': 3.0})
5755    ]
5756    make_arg = partial(make_tensor, device=device, dtype=dtype,
5757                       requires_grad=requires_grad, low=None, high=None)
5758    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
5759        y_tensor = make_arg(y_shape)
5760        if x_shape is not None:
5761            x_tensor = make_arg(x_shape)
5762            yield SampleInput(y_tensor, x_tensor, **kwarg)
5763        else:
5764            yield SampleInput(y_tensor, **kwarg)
5765
5766def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
5767    shapes_and_axes = [
5768        ((3, 4, 5), 0),
5769        ((3, 4, 5), 1),
5770        ((3, 4, 5), 3),
5771        ((3, 4, 5), -1),
5772        ((3, 4, 5), -3),
5773        ((), 0),
5774        ((), -1),
5775        ((1,), 0),
5776        ((1,), -1),
5777    ]
5778
5779    for shape, axis in shapes_and_axes:
5780        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
5781                             requires_grad=requires_grad)
5782        yield SampleInput(tensor, axis)
5783
5784
5785def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs):
5786    shapes = ((0, 1, 5, 5), (2, 3, 5, 5))
5787    kernel_sizes = (2, (2, 2), (2, 3))
5788    dilations = (1, 2, (1, 2))
5789    paddings = (0, 1, (1, 2))
5790    strides = (1, 2, (1, 2))
5791
5792    cases = product(shapes, kernel_sizes, dilations, paddings, strides)
5793    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5794    for shape, kernel_size, dilation, padding, stride in cases:
5795        tensor = make_arg(shape)
5796        yield SampleInput(tensor, kernel_size, dilation, padding, stride)
5797
5798    # With default args
5799    yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3))
5800
5801
5802def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
5803    shapes_and_args = (
5804        ((S, 1, S, 1), ()),
5805        ((1, 1, 1, 1), ()),
5806        ((1, 1, 1, 1), (0,)),
5807        ((S, 1, S, 1), (1,)),
5808        ((S, 1, S, 1), (-1,)),
5809        ((S, 1, S, 1), (2,)),
5810        ((S, 1, S, 1), (-2,)),
5811        ((), (0, )),
5812    )
5813
5814    for shape, args in shapes_and_args:
5815        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
5816                             requires_grad=requires_grad)
5817
5818        yield SampleInput(tensor, args=args)
5819
5820
5821def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs):
5822    shapes_and_args = (
5823        ((1, 1, 1, 1), ()),
5824        ((S, 1, S, 1), (1,)),
5825        ((S, 1, S, 1), (-1,)),
5826        ((S, 1, S, 1), (1, 3)),
5827        ((S, 1, S, 1), (1, 2,)),
5828        ((), (0,)),
5829    )
5830
5831    for shape, dims in shapes_and_args:
5832        tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None,
5833                             requires_grad=requires_grad)
5834
5835        yield SampleInput(tensor, dims)
5836
5837
5838def _squeeze_ref(x, axis=None):
5839    # NumPy doesn't allow squeezing scalars
5840    if x.ndim == 0:
5841        return x
5842
5843    if isinstance(axis, Sequence):
5844        # NumPy doesn't allow squeezing dimensions whose size is not 1, so drop them
5845        axis = tuple(a for a in axis if x.shape[a] == 1)
5846
5847    if isinstance(axis, int) and x.shape[axis] != 1:
5848        return x
5849
5850    return np.squeeze(x, axis)
5851
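# Minimal illustrative sketch (helper name is hypothetical) of the NumPy/PyTorch mismatch
# _squeeze_ref papers over: np.squeeze raises on a non-size-1 axis, torch.squeeze is a no-op.
def _example_squeeze_mismatch():
    x = np.zeros((3, 1, 2))
    t = torch.zeros(3, 1, 2)
    assert torch.squeeze(t, 0).shape == (3, 1, 2)  # no-op on a dim of size 3
    assert _squeeze_ref(x, 0).shape == (3, 1, 2)   # mirrors the torch behavior
    # np.squeeze(x, 0) would raise ValueError here
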
5852def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
5853    assert mode in ('constant', 'reflect', 'replicate', 'circular')
5854    if mode in ['reflect', 'replicate']:
5855        cases: tuple = (  # ignore
5856            ((1, 3), (1, 2)),
5857            ((1, 3), (0, 1)),
5858            ((0, 3, 3), (1, 2)),
5859            ((0, 3, 3), (0, 1)),
5860            ((1, 3, 3), (1, 2)),
5861            ((1, 3, 3), (0, 1)),
5862            ((1, 3, 3), (0, 2, 0, 1)),
5863            ((0, 3, 3, 3), (0, 2, 0, 1)),
5864            ((3, 3, 5, 5), (0, 2, 0, 1)),
5865            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
5866            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
5867            ((1, 3, 4, 4), (-1, 1, -2, 1)),
5868        )
5869    elif mode == 'constant':
5870        cases = (
5871            ((1, 3), (1, 2)),
5872            ((1, 3), (0, 1)),
5873            ((1, 3), (0, 2, 0, 1)),
5874            ((0, 3, 3), (1, 2)),
5875            ((0, 3, 3), (0, 1)),
5876            ((0, 3, 3), (0, 2, 0, 1)),
5877            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
5878            ((1, 3, 3), (1, 2)),
5879            ((1, 3, 3), (0, 1)),
5880            ((1, 3, 3), (0, 2, 0, 1)),
5881            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
5882            ((0, 3, 3, 3), (1, 2)),
5883            ((0, 3, 3, 3), (0, 1)),
5884            ((0, 3, 3, 3), (0, 2, 0, 1)),
5885            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
5886            ((3, 3, 5, 5), (1, 2)),
5887            ((3, 3, 5, 5), (0, 1)),
5888            ((3, 3, 5, 5), (0, 2, 0, 1)),
5889            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
5890            ((1, 3, 3, 3, 3), (1, 2)),
5891            ((1, 3, 3, 3, 3), (0, 1)),
5892            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
5893            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
5894            ((1, 3, 4, 4), (-1, 1, -2, 1)),
5895        )
5896    else:  # mode == 'circular'
5897        if dtype == torch.bool:
5898            # test_dtypes fails on ASAN for this case with the following
5899            # runtime error: load of value 190, which is not a valid value for type 'bool'
5900            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
5901            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
5902            cases = (
5903                ((2, 3, 3), (1, 2)),
5904                ((1, 3, 3), (1, 2)),
5905            )
5906        else:
5907            cases = (
5908                ((0, 3, 3), (1, 2)),
5909                ((0, 3, 3), (0, 1)),
5910                ((1, 3, 3), (1, 2)),
5911                ((1, 3, 3), (0, 1)),
5912                ((0, 3, 3, 3), (0, 2, 0, 1)),
5913                ((3, 3, 5, 5), (0, 2, 0, 1)),
5914                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
5915                ((1, 3, 4, 4), (-1, 1, -2, 1)),
5916            )
5917
5918    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5919
5920    if mode == 'constant':
5921        # Default args
5922        yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))
5923
5924    if mode in ['reflect', 'replicate', 'circular']:
5925        for shape, pad in cases:
5926            yield SampleInput(make_inp(shape), args=(pad, mode))
5927    else:  # mode == 'constant'
5928        for pad_value in (1., 2.):
5929            for shape, pad in cases:
5930                yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))
5931
5932def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
5933    cases: tuple = (
5934        ((5, 3, 4, 4), (-4, 5, 0, 0)),
5935        ((6, 2, 4, 4), (0, 0, 2, -4)),
5936        ((5, 6, 4, 4), (5, -4, -4, 3)),
5937        ((4, 2, 5, 5), (-2, -1, 4, 6)),
5938        ((2, 6, 5, 5), (8, -1, -1, -3)),
5939        ((8, 1, 5, 5), (-2, -1, -1, -3)),
5940    )
5941    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5942
5943    for shape, pad in cases:
5944        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))
5945
5946def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
5947    # Inherit sample inputs from nn.pad, but transform them to fit
5948    # constant_pad_nd's interface
5949    nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args,
5950                                      mode='constant', **kwargs)
5951
5952    # NOTE: primTorch is more strict about the type of the fill value argument
5953    # So we must cast it to the correct dtype
5954    from torch._prims_common import dtype_to_type
5955    scalar_type = dtype_to_type(dtype)
5956
5957    def drop_mode_argument(input, pad, mode=None, value=None):
5958        if value is None:
5959            return SampleInput(input, args=(pad,))
5960        else:
5961            return SampleInput(input, args=(pad, scalar_type(value)))
5962
5963    for sample in nn_samples:
5964        yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs)
5965
5966def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs):
5967    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
5968
5969    yield SampleInput(make_input(()), repeats=2)
5970    yield SampleInput(make_input((2, 3, 4)), repeats=2)
5971    yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1)
5972    yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1)
5973
5974
5975def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs):
5976    def mt(shape, **kwargs):
5977        return make_tensor(shape, device=device, dtype=dtype,
5978                           requires_grad=requires_grad, **kwargs)
5979
5980    yield SampleInput(mt(100), n_fft=10, return_complex=True)
5981    yield SampleInput(mt(100), n_fft=10, return_complex=False)
5982    if dtype.is_complex:
5983        yield SampleInput(mt(100), n_fft=10)
5984
5985    for center in [False, True]:
5986        yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True)
5987        yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4,
5988                          center=center, return_complex=True)
5989
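    # NOTE: `center` below is the leftover value from the loop above (i.e. True).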
5990    window = mt(16, low=.5, high=2.0)
5991    yield SampleInput(
5992        mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
5993    yield SampleInput(
5994        mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center))
5995    if not dtype.is_complex:
5996        yield SampleInput(
5997            mt((10, 100)), n_fft=16, window=window, onesided=False,
5998            return_complex=True)
5999
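# Minimal illustrative sketch (not part of the suite; helper name is hypothetical):
# the samples above exercise torch.stft calls of this general shape; the window length
# must match n_fft unless win_length is given.
def _example_stft_call():
    signal = torch.randn(2, 100)
    window = torch.hann_window(16)
    return torch.stft(signal, n_fft=16, hop_length=4, window=window,
                      center=True, return_complex=True)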
6000
6001def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs):
6002    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6003
6004    def mt(shape, **kwargs):
6005        real_shape = shape if dtype.is_complex else shape + (2,)
6006        return make_arg(real_shape, **kwargs)
6007
6008    yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10))
6009    yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False))
6010    yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True))
6011
6012    for center in [False, True]:
6013        yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center))
6014        yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center))
6015
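    # NOTE: `center` below is the leftover value from the loop above (i.e. True).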
6016    window = make_arg(10, low=.5, high=2.0)
6017    yield SampleInput(mt((10, 10, 6)), kwargs=dict(
6018        n_fft=10, window=window, center=center, return_complex=dtype.is_complex))
6019    yield SampleInput(mt((10, 10, 10)), kwargs=dict(
6020        n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True))
6021
6022    real_window = window if not dtype.is_complex else window.real
6023    yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center))
6024
6025def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs):
6026    # create a helper function wrapping `make_tensor`
6027    make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1)
6028
6029    batches = [(), (0, ), (2, ), (2, 1)]
6030    ns = [5, 2, 0]
6031    tf = [True, False]
6032    for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf):
6033        input = make_input((*batch, m, n))
6034        reflectors, tau = torch.geqrf(input)
6035        reflectors.requires_grad_(requires_grad)
6036        tau.requires_grad_(requires_grad)
6037        other_matrix_shape = (m, n) if left else (n, m)
6038        other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad)
6039        yield SampleInput(reflectors, tau, other, left=left, transpose=transpose)
6040
6041
6042def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs):
6043    cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse(
6044        op_info, device, dtype, requires_grad=False
6045    )
6046
6047    for sample in cholesky_inverse_samples:
6048        psd_matrix = sample.input
6049        sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
6050        sample.args = (psd_matrix.requires_grad_(requires_grad),)
6051        yield sample
6052
6053
6054def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
6055    make_arg = partial(make_fullrank_matrices_with_distinct_singular_values,
6056                       dtype=dtype, device=device, requires_grad=requires_grad)
6057
6058    # not needed once OpInfo tests support Iterables
6059    batch_shapes = ((), (3,), (3, 3))
6060    for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
6061        shape = batch_shape + (S + size_delta, S)
6062        input = make_arg(*shape)
6063        yield SampleInput(input, args=(True, get_infos))
6064
6065
6066def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
6067    def out_fn(output):
6068        return output[1], output[2]
6069
6070    for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs):
6071        lu_data, pivots = torch.linalg.lu_factor(lu_sample.input)
6072        lu_data.requires_grad_(requires_grad)
6073        yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn)
6074
6075
6076def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
6077    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6078
6079    args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2)))
6080
6081    for arg in args:
6082        yield SampleInput(make_arg((0, 0, 0)), args=arg)
6083        yield SampleInput(make_arg((S, S, S)), args=arg)
6084
6085    # Scalar tensor
6086    yield SampleInput(make_arg(()), args=(10, ))
6087
6088def error_inputs_roll(op_info, device, **kwargs):
6089    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
6090    err_msg1 = "`shifts` required"
6091    s1 = SampleInput(make_arg((S,)), ())
6092    yield ErrorInput(s1, error_regex=err_msg1)
6093
6094    err_msg2 = ("shifts and dimensions must align")
6095    s2 = SampleInput(make_arg((S, S)), (2, 1), 0)
6096    yield ErrorInput(s2, error_regex=err_msg2)
6097
6098    err_msg3 = ("out of range")
6099    s3 = SampleInput(make_arg((S, )), 0, 2)
6100    yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError)
6101
6102    err_msg4 = ("Dimension specified as 0")
6103    s4 = SampleInput(make_arg(()), 0, 0)
6104    yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError)
6105
6106def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs):
6107    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6108
6109    args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)])
6110
6111    yield SampleInput(make_arg((S, S, S)))
6112    for arg in args:
6113        yield SampleInput(make_arg((S, S, S)), args=arg)
6114
6115
6116def error_inputs_rot90(op_info, device, **kwargs):
6117    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
6118    err_msg1 = "expected total rotation dims"
6119    s1 = SampleInput(make_arg((S, S)), dims=(0,))
6120    yield ErrorInput(s1, error_regex=err_msg1)
6121
6122    err_msg2 = "expected total dims >= 2"
6123    s2 = SampleInput(make_arg((S,)))
6124    yield ErrorInput(s2, error_regex=err_msg2)
6125
6126    err_msg3 = "expected rotation dims to be different"
6127    s3 = SampleInput(make_arg((S, S)), dims=(1, 1))
6128    yield ErrorInput(s3, error_regex=err_msg3)
6129
6130
6131def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs):
6132    tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype,
6133                        requires_grad=requires_grad)
6134    tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype,
6135                        requires_grad=requires_grad)
6136
6137    yield SampleInput(tensor_nd())
6138    yield SampleInput(tensor_nd(), dim=1)
6139    yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True)
6140    yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True)
6141    yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False)
6142
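    # `correction` generalizes `unbiased`: the divisor is N - correction, so
    # correction=1 matches unbiased=True and correction=0 matches unbiased=False.
    # The samples below also cover fractional and None corrections.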
6143    yield SampleInput(tensor_nd(), dim=(1,), correction=1.3)
6144    yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2)
6145    yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True)
6146    yield SampleInput(tensor_nd(), dim=None, correction=None)
6147    yield SampleInput(tensor_nd(), correction=0, keepdim=True)
6148    yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3)
6149
6150
6151def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs):
6152    make_arg = partial(make_tensor, device=device, dtype=dtype,
6153                       requires_grad=requires_grad)
6154
6155    # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
6156    yield SampleInput(make_arg((S, S)), True)
6157    yield SampleInput(make_arg((S,)), False)
6158
6159
6160def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs):
6161    shapes = [(2,), (1, 2), (3, 2), (2, 3)]
6162    for shape in shapes:
6163        yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
6164
6165
6166def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs):
6167    return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad))
6168
6169
6170def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs):
6171    for t in _generate_correlation_inputs(device, dtype, requires_grad):
6172        yield SampleInput(t)
6173        num_observations = t.numel() if t.ndimension() < 2 else t.size(1)
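        # torch.cov treats each row as a variable and each column as an
        # observation; a 1-D input is a single variable, hence the numel() fallback.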
6174        fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10)
6175        aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad)
6176        for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]):
6177            yield SampleInput(t.clone().requires_grad_(requires_grad),
6178                              correction=correction, fweights=fw, aweights=aw)
6179
6180
6181def error_inputs_cov(op_info, device, **kwargs):
6182    a = torch.rand(S, device=device)
6183    yield ErrorInput(
6184        SampleInput(torch.rand(S, S, S, device=device)),
6185        error_regex="expected input to have two or fewer dimensions")
6186    yield ErrorInput(
6187        SampleInput(a, fweights=torch.rand(S, S, device=device)),
6188        error_regex="expected fweights to have one or fewer dimensions")
6189    yield ErrorInput(
6190        SampleInput(a, aweights=torch.rand(S, S, device=device)),
6191        error_regex="expected aweights to have one or fewer dimensions")
6192    yield ErrorInput(
6193        SampleInput(a, fweights=torch.rand(S, device=device)),
6194        error_regex="expected fweights to have integral dtype")
6195    yield ErrorInput(
6196        SampleInput(a, aweights=torch.tensor([1, 1], device=device)),
6197        error_regex="expected aweights to have floating point dtype")
6198    yield ErrorInput(
6199        SampleInput(a, fweights=torch.tensor([1], device=device)),
6200        error_regex="expected fweights to have the same numel")
6201    yield ErrorInput(
6202        SampleInput(a, aweights=torch.rand(1, device=device)),
6203        error_regex="expected aweights to have the same numel")
6204    yield ErrorInput(
6205        SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4, -5], device=device)),
6206        error_regex="fweights cannot be negative")
6207    yield ErrorInput(
6208        SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)),
6209        error_regex="aweights cannot be negative")
6210
6211
6212def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs):
6213    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6214
6215    cases = [((1, 2, 3, 4), (0, 2, 3, 1)),
6216             ((1, 2, 3, 4), (0, -2, -1, 1)),
6217             ((), ()),
6218             ((1, 2, 3, 4), (2, 1, 3, 0))]
6219
6220    for shape, args in cases:
6221        yield SampleInput(make_arg(shape), args=(args,))
6222
6223def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs):
6224    yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs)
6225
6226    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6227
6228    cases = (
6229        ((), ()),
6230        ((1,), (0,)),
6231        ((2, 2), (1, 0)),
6232        ((2, 2), (0, 1)),
6233        ((2, 0, 1), (0, 2, 1)),
6234        ((3, 4, 2), (2, 1, 0)),
6235        ((3, 4, 2), (1, 0, 2)),
6236        ((3, 4, 2), (0, 1, 2)),
6237    )
6238
6239    # Adds tricky permutations and permutations with noncontiguity
6240    for shape, permutation in cases:
6241        for p in itertools.permutations(permutation):
6242            a = make_arg(shape).permute(p)
6243            yield SampleInput(a, args=(permutation,))
6244
6245            a = make_arg(shape, noncontiguous=True).permute(p)
6246            yield SampleInput(a, args=(permutation,))
6247
6248def error_inputs_softshrink(op, device, **kwargs):
6249    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}),
6250                     error_regex="lambda must be greater or equal to 0, but found to be -0.5")
6251
6252def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs):
6253    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6254
6255    # These extra samples pass lambd explicitly; sample_inputs_elementwise_unary
6256    # below only covers the default value.
6257    for lbda in (0., 0.5):
6258        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
6259
6260    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
6261
6262def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs):
6263    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6264
6265    # These extra samples pass lambd explicitly; sample_inputs_elementwise_unary
6266    # below only covers the default value.
6267    # Note that unlike softshrink, lambd is allowed to be negative for hardshrink
6268    for lbda in (-0.5, 0., 0.5):
6269        yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda})
6270
6271    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
6272
6273
6274def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs):
6275    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6276
6277    # These extra samples pass min_val and max_val explicitly;
6278    # sample_inputs_elementwise_unary below only covers the default values.
6279    for max_val, min_val in ((0.5, -0.5), (0., 0.)):
6280        yield SampleInput(make_arg(S, S), kwargs={"min_val": min_val, "max_val": max_val})
6281
6282    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
6283
6284def error_inputs_hardtanh(op_info, device, **kwargs):
6285    # Tests that hardtanh errors out when passed min_val > max_val.
6286    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}),
6287                     error_type=ValueError, error_regex="min_val cannot be greater than max_val")
6288
6289def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs):
6290    def c(t):
6291        return t.clone().requires_grad_(requires_grad)
6292
6293    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6294    x = make_arg((3,))
6295    y = make_arg((4,))
6296    A = make_arg((2, 3,))
6297    B = make_arg((1, 3,))
6298    C = make_arg((1, 2, 3,))
6299    D = make_arg((1, 3, 4,))
6300    E = make_arg((4, 4,))
6301    H = make_arg((3, 3,))
6302    I = make_arg((1, 3, 1,))
6303
6304    # Vector operations
6305    yield SampleInput([c(x)], 'i->')                      # sum
6306    yield SampleInput([c(x), c(y)], 'i,j->ij')            # outer
6307
6308    # Matrix operations
6309    yield SampleInput([c(A)], "ij->i")                    # col sum
6310    yield SampleInput([c(A), c(B)], "ij,kj->ik")          # matmul
6311    yield SampleInput([c(A), c(E)], "ij,Ab->ijAb")        # matrix outer product
6312
6313    # Tensor operations
6314    yield SampleInput([c(C), c(D)], "aij,ajk->aik")       # batch matmul
6315    yield SampleInput([c(D), c(E)], "aij,jk->aik")        # tensor matrix contraction
6316    yield SampleInput([c(C), c(B)], "ijk,ik->j")          # non contiguous
6317
6318    # Test diagonals
6319    yield SampleInput([c(I)], 'iji->j')                   # non-contiguous trace
6320
6321    # Test ellipsis
6322    yield SampleInput([c(H)], "i...->...")
6323    yield SampleInput([c(C), c(x)], '...ik, ...j -> ij')
6324
6325
6326def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs):
6327    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
6328    sizes = ((S, M, S), (S, 0, M))
6329    all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ())
6330
6331    for size, dims in product(sizes, all_dims):
6332        yield SampleInput(make_arg(size), kwargs={"dims": dims})
6333
6334def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs):
6335    shapes = [
6336        (S, M, S),
6337        (S, 0, M),
6338    ]
6339    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
6340    return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes)
6341
6342def error_inputs_fliplr(op, device, **kwargs):
6343    yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)),
6344                     error_regex="Input must be >= 2-d.")
6345
6346def error_inputs_flipud(op, device, **kwargs):
6347    yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)),
6348                     error_regex="Input must be >= 1-d.")
6349
6350def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs):
6351    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
6352    make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False)
6353    shape = (S, M, S)
6354
6355    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape)))
6356    yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:])))
6357    yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),))
6358    yield SampleInput(make_arg(shape), args=(None, make_arg(shape)))
6359    yield SampleInput(make_arg(shape), args=(make_arg(shape), None))
6360    # test type promotion
6361    yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None))
6362    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape)))
6363
6364def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs):
6365    yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs)
6366
6367    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6368    make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad)
6369    supported_dtypes = op.supported_dtypes(device)
6370
6371    # broadcasting and noncontiguous cases
6372    cases = (
6373        ((4, 4), (4, 4), (4, 4)),
6374        ((4, 4), (1, 4, 4), (4, 4)),
6375        ((4, 4), (1, 4, 4), (4, 1, 4)),
6376        ((4, 4, 1), (1, 4, 4), (4, 4)),
6377        ((4, 1), (1, 4, 4), (1, 4)),
6378        ((4, 4), (), (4, 4)),
6379        ((4, 4), (), ()),
6380        ((), (4, 4), (1, 4, 4)),
6381    )
6382
6383    for a, b, c in cases:
6384        yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c)))
6385        yield SampleInput(make_arg(a, noncontiguous=True),
6386                          args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1)))
6387
6388    # scalar cases
6389    if supports_scalars:
6390        cases = [
6391            ((), 1, 2,),
6392            ((), 1., 2),
6393            ((4, 4), 1., 2,),
6394            ((3, 4), make_scalar_tensor(), make_scalar_tensor()),
6395        ]
6396
6397        if torch.complex64 in supported_dtypes:
6398            cases.extend([
6399                ((3, 1, 4), complex(1, 2), 3.),
6400            ])
6401
6402        for a, b, c in cases:
6403            yield SampleInput(make_arg(a), args=(b, c))
6404
6405    # type promotion cases
6406    # int x float
6407    if torch.float in supported_dtypes and torch.long in supported_dtypes:
6408        a = make_arg((), dtype=torch.long)
6409        b = make_arg((1, 4), dtype=torch.float)
6410        c = make_arg((3, 4))
6411
6412        cases = (
6413            (a, b, c),
6414            (c, a, b),
6415        )
6416
6417        for a, b, c in cases:
6418            yield SampleInput(a, args=(b, c))
6419
6420    # NaN propagation
6421    if dtype.is_floating_point or dtype.is_complex:
6422        nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan'))
6423
6424        a = make_arg((12,))
6425        a[4] = nan
6426        a[7] = nan
6427        b = make_arg((12,))
6428        b[1] = nan
6429        b[7] = nan
6430        c = make_arg((12,))
6431        c[9] = nan
6432
6433        yield SampleInput(a, args=(b, c))
6434
6435
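# NumPy reference implementations for the clamp family of ops. When both bounds
# are given, the max bound is applied last (the outer np.minimum), so min > max
# resolves to max, matching torch.clamp's documented behavior.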
6436def _clamp_min_numpy(a, min=None):
6437    return np.maximum(a, min)
6438
6439
6440def _clamp_max_numpy(a, max=None):
6441    return np.minimum(a, max)
6442
6443
6444def _clamp_numpy(a, min=None, max=None):
6445    if min is None:
6446        return np.minimum(a, max)
6447    if max is None:
6448        return np.maximum(a, min)
6449
6450    return np.minimum(max, np.maximum(a, min))
6451
6452
6453def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs):
6454    def make_arg(shape):
6455        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
6456        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)
6457
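    # cumprod's gradient needs special handling when the input contains zeros,
    # so prod_zeros plants zeros at fixed positions to exercise that case.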
6458    def prod_zeros(dim_select):
6459        assert len(dim_select) == 2
6460        result = make_arg(3 * (S,))
6461        result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
6462        result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
6463        result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
6464        return result
6465
6466    for dim in range(3):
6467        yield SampleInput(make_arg((S, S, S)), args=(dim,))
6468    # Scalar tensors and empty tensor
6469    for size in [(), (1,), (0,)]:
6470        yield SampleInput(make_arg(size), args=(0,))
6471
6472    yield SampleInput(prod_zeros([0, 1]), args=(1,))
6473    yield SampleInput(prod_zeros([0, 2]), args=(1,))
6474    yield SampleInput(prod_zeros([1, 2]), args=(1,))
6475
6476    # test dtype kwarg
6477    yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype})
6478
6479def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs):
6480    yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad))
6481
6482def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs):
6483    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6484    sizes = ((S, S), ())
6485    return (SampleInput(make_arg(size)) for size in sizes)
6486
6487def error_inputs_complex(op_info, device, is_ref=False, **kwargs):
6488    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
6489
6490    if is_ref:
6491        error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32"
6492        error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument"
6493        error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead"
6494    else:
6495        error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int"
6496        error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument"
6497        error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'"
6498
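    # The ref (is_ref=True) and the ATen op spell dtype names differently
    # (e.g. "torch.float32" vs "Float"), hence the two sets of regexes above.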
6499    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)),
6500                     error_type=RuntimeError, error_regex=error_float)
6501
6502    yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)),
6503                     error_type=RuntimeError, error_regex=error_dtype)
6504
6505    yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64),
6506                                 out=make_arg(M, S, dtype=torch.complex64)),
6507                     error_type=RuntimeError, error_regex=error_out)
6508
6509def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
6510    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6511    shape = (S, S)
6512    yield SampleInput(make_arg(shape), make_arg(shape))
6513
6514def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs):
6515    def make_arg(shape):
6516        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
6517        return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad)
6518
6519    def prod_single_zero():
6520        result = make_arg(2 * (S,))
6521        result[0, 1] = 0
6522        return result
6523
6524    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
6525        # use only the tensor input and drop the dim argument
6526        yield SampleInput(sample.input.clone().requires_grad_(requires_grad))
6527        yield sample
6528
6529    # Generates samples with keepdim = True
6530    for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
6531        sample.kwargs['keepdim'] = True
6532        yield sample
6533
6534    yield SampleInput(prod_single_zero())
6535    yield SampleInput(make_arg((3, 3, 3)), args=(1,))
6536    yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True})
6537
6538    yield SampleInput(make_arg((3, 0)), args=(1,))
6539    yield SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True})
6540    yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad))
6541
6542    # test zero scalar tensor
6543    zero = make_arg(())
6544    zero.zero_()
6545    yield SampleInput(zero.clone().requires_grad_(requires_grad))
6546    yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,))
6547    yield SampleInput(zero.clone().requires_grad_(requires_grad),
6548                      args=(0,),
6549                      kwargs={'keepdim': True})
6550
6551def error_inputs_neg(op_info, device, **kwargs):
6552    si = SampleInput(torch.tensor((False, True), device=device))
6553    msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
6554           " If you are trying to invert a mask, use the `\\~` or"
6555           " `logical_not\\(\\)` operator instead.")
6556    yield ErrorInput(si, error_regex=msg)
6557
6558def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs):
6559    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
6560    yield SampleInput(make_arg(M))
6561
6562    tensors = (
6563        make_arg((M, M)),
6564        make_arg((3, 5)),
6565        make_arg((5, 3)),
6566    )
6567
6568    args = ((), (2,), (-2,), (1,), (2,))
6569
6570    for tensor, arg in product(tensors, args):
6571        yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg)
6572
6573def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs):
6574    yield from sample_inputs_diagonal_diag_embed(
6575        op_info, device, dtype, requires_grad, **kwargs)
6576
6577    make_arg = partial(
6578        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6579
6580    shapes1d = ((0,), (1,))
6581    shapes2d = ((L, M),)
6582    shapes3d = ((L, M, S),)
6583
6584    kwargs1d = {}
6585
6586    kwargs2d = (
6587        # dim1 > dim2 is allowed
6588        dict(dim1=1, dim2=0),
6589        # negative dims are allowed
6590        dict(dim1=-2, dim2=-1),
6591        # one dim negative and the other nonnegative is allowed
6592        dict(dim1=-1, dim2=0),
6593        # out of bounds offset should return an empty tensor in diagonal and
6594        # offset the diagonal in diag_embed
6595        dict(offset=100),
6596    )
6597
6598    kwargs3d = kwargs2d + (
6599        # make sure we can use non-sequential dims
6600        dict(offset=-1, dim1=0, dim2=2),
6601    )
6602
6603    samples1d = product(shapes1d, kwargs1d)
6604    samples2d = product(shapes2d, kwargs2d)
6605    samples3d = product(shapes3d, kwargs3d)
6606
6607    for shape, kwargs in chain(samples1d, samples2d, samples3d):
6608        if 'diagonal' in op_info.name:
6609            # these are error inputs for diagonal
6610            if shape in ((0,), (1,)):
6611                continue
6612        yield SampleInput(input=make_arg(shape), kwargs=kwargs)
6613
6614
6615def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
6616    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
6617
6618    # Shapes for 2D Tensors
6619    shapes_2d = ((M, M), (3, 5), (5, 3))
6620
6621    # Shapes for 3D Tensors
6622    shapes_3d = ((M, M, M),)
6623
6624    args_2d = ((), (2,), (-2,), (1,))
6625    args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))
6626
6627    for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
6628        input_ = make_arg(input_shape)
6629        # We can programmatically figure out the right shape for src:
6630        # It should be the same size as input.diagonal(other_args...)
6631        if not isinstance(arg, tuple):
6632            arg_tuple = (arg,)
6633        else:
6634            arg_tuple = arg
6635        src_shape = input_.diagonal(*arg_tuple).size()
6636        src = make_arg(src_shape)
6637        yield SampleInput(input_, args=(src, *arg_tuple))
6638
6639
6640def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs):
6641    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6642
6643    yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense())
6644    yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense())
6645
6646def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs):
6647    batch_size, num_classes = shape = (2, 3)
6648    reductions = ("mean", "sum", "none")
6649
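    # Shapes with extra trailing dimensions exercise the K-dimensional case of
    # cross_entropy, where the input has shape (N, C, d1, d2, ...).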
6650    input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [
6651        (shape, {}),
6652        ((*shape, 1), {}),
6653        ((*shape, 1, 2), {}),
6654        ((*shape, 1, 2, 3), {}),
6655        *[(shape, dict(reduction=reduction)) for reduction in reductions],
6656        *[
6657            (
6658                shape,
6659                dict(
6660                    weight=make_tensor((num_classes,), device=device, dtype=dtype),
6661                    reduction=reduction,
6662                ),
6663            )
6664            for reduction in reductions
6665        ],
6666        (shape, dict(ignore_index=1)),
6667    ]
6668
6669    for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)):
6670        input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad)
6671
6672        if probabilities_target:
6673            # ignore_index is not supported for probabilities target
6674            if "ignore_index" in kwargs:
6675                continue
6676
6677            target = make_tensor(
6678                input_shape,
6679                low=0,
6680                high=1,
6681                device=device,
6682                dtype=dtype,
6683                requires_grad=requires_grad,
6684            )
6685        else:
6686            target = make_tensor(
6687                (batch_size, *input_shape[2:]),
6688                low=0,
6689                high=num_classes,
6690                device=device,
6691                dtype=torch.long,
6692            )
6693
6694            if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]):
6695                # make sure at least one item in target is not ignored
6696                target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0]
6697
6698        yield SampleInput(input, target, **kwargs)
6699
6700
6701def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
6702    low, high = op_info.domain
6703
6704    # Note: the operator is very sensitive at points near the
6705    # start and end of its domain, and produces NaN for float16
6706    # if domain_eps is 1e-5.
6707    if dtype.is_floating_point or dtype.is_complex:
6708        domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2
6709
6710        low = low + domain_eps
6711        high = high - domain_eps
6712
6713    make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
6714
6715    yield SampleInput(make_arg((S, S, S)))
6716    yield SampleInput(make_arg((S, S, S)), 0.2)
6717    yield SampleInput(make_arg(()))
6718    yield SampleInput(make_arg(()), 0.2)
6719
6720def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs):
6721    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6722    # isin has two code paths, chosen by whether
6723    # elements.numel() < 10 * pow(test_elements.numel(), 0.145);
6724    # the two differently-sized samples below exercise both paths.
6725    yield SampleInput(make_arg((L,)), args=(make_arg((S,)),))
6726    yield SampleInput(make_arg((S,)), args=(make_arg((L,)),))
6727
6728def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs):
6729    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6730
6731    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))))
6732    yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S))))
6733    yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S))))
6734    yield SampleInput(make_arg((S,)),
6735                      args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))),
6736                      broadcasts_input=True)
6737
6738def error_inputs_masked_scatter(op_info, device, **kwargs):
6739    make_arg = partial(make_tensor, device=device, dtype=torch.float)
6740    for mask_dtype in [torch.float, torch.uint8]:
6741        yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype),
6742                                                           make_arg(3, 4))),
6743                         error_regex=r"masked_scatter_ only supports boolean masks")
6744
6745def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
6746    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6747
6748    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10))
6749    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(())))
6750    yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10))
6751    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10))
6752    yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(())))
6753    yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10))
6754
6755    yield SampleInput(make_arg((S,)),
6756                      args=(torch.randn(S, S, device=device) > 0, make_arg(())),
6757                      broadcasts_input=True)
6758    yield SampleInput(make_arg((S,)),
6759                      args=(torch.randn(S, S, device=device) > 0, 10),
6760                      broadcasts_input=True)
6761
6762    if torch.device(device).type == 'cuda':
6763        # `self` and `mask` on CUDA but `value` is a CPU scalar tensor.
6764        yield SampleInput(make_arg((S, S)),
6765                          args=(torch.randn(S, S, device=device) > 0,
6766                                make_tensor((), device="cpu", dtype=dtype)))
6767
6768def error_inputs_masked_fill(op_info, device, **kwargs):
6769    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
6770    # `value` is not a 0-D tensor.
6771    yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))),
6772                     error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension")
6773    # downcasting complex value (scalar overload)
6774    yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)),
6775                     error_regex=r"value cannot be converted to type .* without overflow")
6776    # downcasting complex value (tensor overload)
6777    yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device),
6778                                 args=(make_arg(()) > 0, torch.tensor(1j, device=device))),
6779                     error_regex=r"value cannot be converted to type .* without overflow")
6780
6781    if torch.device(device).type == 'cuda':
6782        # `self` and `mask` on CPU but `value` is a CUDA scalar tensor.
6783        yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'),
6784                                     args=(torch.randn(S, S, device='cpu') > 0,
6785                                           torch.randn((), device='cuda'))),
6786                         error_regex=r"to be on same device")
6787
6788
6789def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs):
6790    make_arg = partial(
6791        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
6792
6793    yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0)
6794
6795    yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0)
6796    yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0)
6797
6798    yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0)
6799
6800    yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool))
6801
6802    yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool))
6803
6804    yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0)
6805
6806def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs):
6807    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6808    yield SampleInput(make_arg((S, S)))
6809    yield SampleInput(make_arg((S, S, S)))
6810
6811def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs):
6812    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None,
6813                       high=None, requires_grad=requires_grad)
6814    test_cases = (((L,), (L,)),
6815                  ((S, M), (M,)),
6816                  ((M,), (M, S)),
6817                  ((S, M), (M, S)),
6818                  ((S, 0), (0, M)),
6819                  ((S, S, M), (M,)),
6820                  ((S, S, M), (M, S)),
6821                  ((S, S, 0), (0, S)),
6822                  ((M,), (S, M, S)),
6823                  ((S, M), (S, M, S)),
6824                  ((0, 0), (S, 0, 0)),
6825                  ((S, S, M, M), (S, S, M, S)),
6826                  ((S, S, M, M), (M,)),
6827                  ((M,), (S, S, M, S)),
6828                  ((S, S, S), (1, S, S))
6829                  )
6830    for lhs_shape, rhs_shape in test_cases:
6831        lhs = make_arg(lhs_shape)
6832        rhs = make_arg(rhs_shape)
6833        if not is_rmatmul:
6834            yield SampleInput(lhs, rhs)
6835        else:
6836            yield SampleInput(rhs, lhs)
6837
6838
6839def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
6840                           requires_grad: bool,
6841                           *, variant: str, **kwargs) -> List[SampleInput]:
6842    if variant == 'variadic':
6843        def make_inputs(
6844                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
6845                                                            List[torch.Tensor]],
6846                                                      Tuple[torch.Tensor, ...]]:
6847            return tensors
6848    elif variant == 'list':
6849        def make_inputs(
6850                tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor,
6851                                                            List[torch.Tensor]],
6852                                                      Tuple[torch.Tensor, ...]]:
6853            return [tensors]
6854    else:
6855        raise ValueError(
6856            'Unsupported variant, must be one of {"variadic", "list"}. '
6857            f'Got "{variant}".')
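    # 'variadic' passes the tensors as separate positional arguments, while
    # 'list' wraps them in a single list argument.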
6858
6859    SCALAR = torch.Size([])
6860    VECTOR = torch.Size([3])
6861    test_cases: List[List[torch.Size]] = [
6862        [SCALAR],
6863        [VECTOR],
6864        [VECTOR, SCALAR],
6865        [VECTOR, SCALAR, VECTOR],
6866        [VECTOR, SCALAR, VECTOR, SCALAR],
6867    ]
6868
6869    for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}):
6870        args = make_inputs(
6871            [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad)
6872             for shape in shapes])
6873        yield SampleInput(*args, indexing=indexing)
6874
6875
6876def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs):
6877    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6878    tensor_shapes = ((S, S), ())
6879    ns = (1, 2, 3, 4, 5)
6880
6881    # Since the accepted lower bound for the input
6882    # to mvlgamma depends on the `p` argument,
6883    # the following function computes the lower bound
6884    # which we pass to `make_tensor`.
6885    def compute_min_val(p):
6886        return (p - 1.) / 2
6887
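    # For example, compute_min_val(1) == 0.0 and compute_min_val(5) == 2.0; the
    # loop below bumps the bound by 1 for integral dtypes and by 2 * eps for
    # floating point dtypes so the samples stay strictly inside the domain.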
6888    for shape, n in product(tensor_shapes, ns):
6889        min_val = compute_min_val(n)
6890        if not dtype.is_floating_point:
6891            # Round-up minimum value for integral dtypes
6892            min_val += 1
6893        else:
6894            min_val += 2 * torch.finfo(dtype).eps
6895        yield SampleInput(make_arg(shape, low=min_val), args=(n,))
6896
6897
6898# Since `mvlgamma` has multiple OpInfo entries, the additional entries
6899# share a number of common skips. The following function is a helper
6900# that collects them.
6901def skips_mvlgamma(skip_redundant=False):
6902    skips = (
6903        # outside domain values are hard error for mvlgamma op.
6904        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'),
6905        DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
6906                     'test_reference_numerics_extremal'),
6907        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
6908                     'test_reference_numerics_large',
6909                     dtypes=(torch.float16, torch.int8)),
6910        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
6911                     'test_reference_numerics_small',
6912                     dtypes=(torch.int8,)),
6913    )
6914    if skip_redundant:
6915        # Redundant tests
6916        skips = skips + (  # type: ignore[assignment]
6917            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
6918            DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
6919            DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
6920            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
6921        )
6922    return skips
6923
6924
6925# To test reference numerics against multiple values of the argument `p`,
6926# we create multiple OpInfo entries, each corresponding to a different value of p.
6927# We run the op tests from test_ops.py only for `p=1` to avoid redundant testing.
6928def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs):
6929    return UnaryUfuncInfo('mvlgamma',
6930                          ref=reference_mvlgamma if TEST_SCIPY else None,
6931                          aliases=('special.multigammaln',),
6932                          variant_test_name=variant_test_name,
6933                          domain=domain,
6934                          decorators=(precisionOverride({torch.float16: 5e-2}),),
6935                          dtypes=all_types_and(torch.half, torch.bfloat16),
6936                          dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
6937                          sample_inputs_func=sample_inputs_mvlgamma,
6938                          supports_forward_ad=True,
6939                          supports_fwgrad_bwgrad=True,
6940                          promotes_int_to_float=True,
6941                          skips=skips,
6942                          sample_kwargs=sample_kwargs)
6943
6944
6945def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):
6946    def _make_tensor_helper(shape, low=None, high=None):
6947        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
6948
6949    yield SampleInput(_make_tensor_helper((S, S, S)), 0)
6950    yield SampleInput(_make_tensor_helper((S, S, S)), 1)
6951    yield SampleInput(_make_tensor_helper(()), 0)
6952
6953    if supports_dtype_kwargs:
6954        # NOTE: if `dtype` is not the same as the input's dtype, the inplace variants fail with
6955        # `provided dtype must match the dtype of self tensor in cumsum`
6956        yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype)
6957
6958
6959def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
6960    test_cases = (
6961        ((), (0, 1, 1)),
6962        ((S, S, S, S), (0, 3, 1)),
6963        ((S, S, S, S), (1, 3, 1)),
6964        ((S, S, S, S), (2, 3, 1)),
6965        ((S, S, S, S), (3, 3, 1)),
6966        ((S, S, S, S), (0, 3, 2)),
6967        ((S, S, S, S), (1, 3, 2)),
6968        ((S, S, S, S), (2, 3, 2)),
6969        ((S, S, S, S), (3, 3, 2)),
6970        ((S, S, S, S), (0, 4, 1)),
6971        ((S, S, S, S), (1, 4, 1)),
6972        ((S, S, S, S), (2, 4, 1)),
6973        ((S, S, S, S), (3, 4, 1)),
6974        ((M,), (0, 3, 1)),
6975        ((M,), (0, 3, 2)),
6976        ((M,), (0, 3, 3)),
6977        ((1000,), (0, 3, 11)),
6978        ((1000,), (0, 2, 27)),
6979        ((10, 10), (0, 1, 2)),
6980        ((10, 10), (1, 2, 3)),
6981        ((10, 10), (1, 2, 2)),
6982        ((S, S, S), (2, 3, 2)),
6983    )
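    # Each argument triple above is (dimension, size, step), matching the
    # signature of Tensor.unfold.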
6984
6985    for shape, arguments in test_cases:
6986        yield SampleInput(make_tensor(shape, dtype=dtype, device=device,
6987                                      low=None, high=None,
6988                                      requires_grad=requires_grad),
6989                          *arguments)
6990
6991def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs):
6992    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
6993
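    # With list_args=True the samples pass a list of section sizes (optionally
    # followed by a dim); otherwise a single int split size (optionally with a
    # dim) is used.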
6994    if list_args:
6995        cases = (
6996            ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
6997            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),),
6998            ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),)
6999        )
7000    else:
7001        cases = (  # type: ignore[assignment]
7002            ((S, S, S), (2,)),
7003            ((S, S, S), (S, 1)),
7004        )
7005
7006    for shape, args in cases:
7007        yield SampleInput(make_arg(shape), args=args)
7008
7009
7010def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
7011    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
7012
7013    cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)),
7014             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)),
7015             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)),
7016             ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)),
7017             )
7018
7019    for shape, args in cases:
7020        yield SampleInput(make_arg(shape), args=args)
7021
7022
7023def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs):
7024    def apply_grad(t):
7025        if dtype in floating_types_and(torch.float16, torch.bfloat16):
7026            t.requires_grad_(requires_grad)
7027
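    # randperm produces all-distinct values, so the sort order is unambiguous
    # and gradcheck is not affected by ties.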
7028    def large_1d_unique(dtype, device):
7029        res = torch.randperm(L * L * L, dtype=torch.int64, device=device)
7030        res = res.to(dtype)
7031        apply_grad(res)
7032        return res
7033
7034    # Test case for large tensor.
7035    yield SampleInput(large_1d_unique(dtype, device))
7036
7037    yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device,
7038                                  low=None, high=None,
7039                                  requires_grad=requires_grad))
7040
7041def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs):
7042    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7043
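    # torch.lerp computes start + weight * (end - start); the weight may be a
    # Python scalar or a broadcastable tensor, and both forms are covered below.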
7044    # no broadcast
7045    yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4)
7046    # broadcast rhs
7047    yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4)
7048    # scalar tensor
7049    yield SampleInput(make_arg(()), make_arg(()), 0.4)
7050    # broadcast rhs scalar-tensor
7051    yield SampleInput(make_arg((S, S)), make_arg(()), 0.4)
7052    # broadcast rhs with weight tensor
7053    yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S)))
7054    # broadcast rhs and weight tensor
7055    yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,)))
7056    # broadcast lhs
7057    yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
7058    # scalar broadcast_lhs
7059    yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
7060    # broadcast all
7061    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True)
7062    # tensor broadcast all
7063    yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata(
7064        broadcasts_input=True)
7065    # no broadcast with weight tensor
7066    yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S)))
7067    # broadcast lhs with weight tensor
7068    yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata(
7069        broadcasts_input=True)
7070    # broadcast lhs and weight tensor
7071    yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata(
7072        broadcasts_input=True)
7073    # broadcast lhs and weight tensor variant
7074    yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata(
7075        broadcasts_input=True)
7076
7077    if dtype.is_complex:
7078        # no broadcast
7079        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j)
7080        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j)
7081        # broadcast rhs
7082        yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j)
7083        yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j)
7084        # scalar tensor
7085        yield SampleInput(make_arg(()), make_arg(()), 0.4j)
7086        yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j)
7087        # broadcast rhs scalar-tensor
7088        yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j)
7089        yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j)
7090
7091def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs):
7092    cases = (
7093        ((2, 2, 2), (2, 2, 2), (2)),
7094        ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])),
7095    )
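    # `dims` may be an int n (contract the last n dims of the first tensor with
    # the first n dims of the second) or a pair of explicit dim lists.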
7096    for first_shape, second_shape, dims in cases:
7097        yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device,
7098                                      requires_grad=requires_grad),
7099                          make_tensor(second_shape, dtype=dtype, device=device,
7100                                      requires_grad=requires_grad),
7101                          dims=dims)
7102
7103def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs):
7104    make_arg = partial(
7105        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None)
7106    test_cases = (
7107        ((S, S), (M, L)),
7108    )
7109
7110    for input_shape, other_shape in test_cases:
7111        input = make_arg(input_shape)
7112        other = make_arg(other_shape)
7113        yield SampleInput(input, other)
7114
7115def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs):
7116    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7117    yield SampleInput(make_arg(S), make_arg(S))
7118    yield SampleInput(make_arg(), make_arg(S, S))
7119
7120def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
7121    def _tensor(shape, dtype=dtype, low=None, high=None):
7122        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
7123
7124    def _gather(shape, index_dim, max_indices):
7125        return gather_variable(shape, index_dim, max_indices, device=device)
7126
7127    zero = torch.tensor(0, dtype=torch.long, device=device)
7128    test_cases = (
7129        (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
7130        (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
7131        (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
7132        (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),
7133        (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
7134        (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))),
7135        (_tensor(()), (0, zero.clone().detach(), _tensor(()))),
7136        (_tensor(()), (0, zero.clone().detach(), 2.5)),
7137    )
7138
7139    for tensor, args in test_cases:
7140        yield SampleInput(tensor, *args)
7141
7142        if not requires_grad:
7143            yield SampleInput(tensor.clone().detach(), *args, reduce='add')
7144
7145            if dtype.is_floating_point:
7146                yield SampleInput(tensor.clone().detach(), *args, reduce='multiply')
7147
7148def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs):
7149    def _tensor(shape, dtype=dtype, low=None, high=None):
7150        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
7151
7152    def _gather(shape, index_dim, max_indices):
7153        return gather_variable(shape, index_dim, max_indices, device=device)
7154
7155    zero = torch.tensor(0, dtype=torch.long, device=device)
7156    yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S)))
7157    yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S)))
7158    yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S)))
7159    yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))
7160    yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
7161    yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))
7162    yield SampleInput(_tensor(()), 0, zero.clone().detach(), _tensor(()))
7163
7164def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs):
7165    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7166    gather = partial(gather_variable, device=device)
7167
7168    zero = torch.tensor(0, dtype=torch.long, device=device)
7169    test_cases = (
7170        ((M, S), 0, gather((S, S), 1, M), (S, S)),
7171        ((M, S), 1, gather((S, S), 0, S), (S, S)),
7172        ((M, S), -1, gather((S, S), 0, S), (S, S)),
7173        ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)),
7174        ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)),
7175        ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)),
7176        ((), 0, zero.clone().detach(), ()),
7177    )
7178
7179    reduce = op_info.variant_test_name
7180    for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True]):
7181        yield SampleInput(make_arg(inp_shape),
7182                          args=(dim, index, make_arg(src_shape), reduce),
7183                          kwargs={'include_self': include_self})
7184
7185
7186    # Sample inputs to test edge cases for backward
7187    # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
7188    if requires_grad and reduce == 'prod':
7189        # This sample tests gradients for the following cases
7190        # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0]))
7191        # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]))
7192        # (c) no zeros reduced (self([2, 1]))
7193        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
7194        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
7195        input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
7196        src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad)
7197        idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device)
7198
7199        yield SampleInput(input,
7200                          args=(1, idx, src, reduce),
7201                          kwargs={'include_self': True})
7202
7203def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
7204    def _tensor(shape, dtype=dtype, low=None, high=None):
7205        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
7206
7207    zero = torch.tensor(0, dtype=torch.long, device=device)
7208    test_cases = (
7209        # inp_shape, dim, lengths, unsafe
7210        ((S,), 0, [0, 1, 2, 2], False),
7211        ((S,), 0, [0, 1, 2, 2], True),
7212        ((S,), 0, [2, 0, 3, 0], False),
7213        ((S, S), 0, [0, 1, 2, 2], False),
7214        # test when lengths do not sum to dim size
7215        ((M, S, S), 0, [1, 2, 0, 6, 0], True),
7216        # test for higher dimensions
7217        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
7218        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
7219        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
7220        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
7221    )
7222
7223    reductions = ["max", "mean", "min", "sum", "prod"]
7224    for args, reduce, initial in product(test_cases, reductions, [1, 2]):
7225        inp_shape, dim, lengths, unsafe = args
7226        lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
7227        sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
7228        if mode == 'lengths':
7229            sample_input_kwargs['lengths'] = lengths_t
7230        elif mode == 'offsets':
7231            zeros_shape = list(lengths_t.shape)
7232            zeros_shape[dim] = 1
7233            offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
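            # e.g. a 1-D lengths tensor [0, 1, 2, 2] turns into offsets [0, 0, 1, 3, 5]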
7234            sample_input_kwargs['offsets'] = offsets_t
7235        else:
7236            raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.")
7237        yield SampleInput(_tensor(inp_shape),
7238                          args=(reduce,),
7239                          kwargs=sample_input_kwargs)
7240
7241
7242def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
7243    make_arg = partial(make_tensor, dtype=dtype, device=device,
7244                       low=None, high=None, requires_grad=requires_grad)
7245    yield SampleInput(make_arg((S, S, S)))
7246    yield SampleInput(make_arg(()))
7247    yield SampleInput(make_arg((S, S, S), noncontiguous=True))
7248
7249def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs):
7250    make_arg = partial(make_tensor, dtype=dtype, device=device,
7251                       low=None, high=None, requires_grad=requires_grad)
7252    yield SampleInput(
7253        torch.tensor(
7254            [[3, 8, 13], [0, 5, 10]],
7255            device=device,
7256            dtype=dtype),
7257        (4, 5))
7258    yield SampleInput(
7259        torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype),
7260        (4, 2**30))
7261    yield SampleInput(
7262        torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype),
7263        (2**30, 4))
7264    yield SampleInput(
7265        torch.tensor(2, device=device, dtype=dtype),
7266        (2, 2))
7267    max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1
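    # max_val is the largest value representable by `dtype`, e.g. 2**31 - 1 for
    # torch.int32 and 2**8 - 1 for torch.uint8.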
7268    yield SampleInput(
7269        torch.tensor(max_val - 1, device=device, dtype=dtype),
7270        (1, max_val))
7271    yield SampleInput(
7272        torch.tensor([22, 41, 37], device=device, dtype=dtype),
7273        (7, 6))
7274    yield SampleInput(
7275        torch.tensor(min(1621, max_val), device=device, dtype=dtype),
7276        (6, 7, 8, 9))
7277    yield SampleInput(
7278        torch.tensor([], device=device, dtype=dtype),
7279        (10, 3, 5))
7280    yield SampleInput(
7281        torch.tensor(
7282            [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]],
7283            device=device,
7284            dtype=dtype),
7285        (5, 8))
7286    yield SampleInput(
7287        torch.tensor(
7288            [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]],
7289            device=device,
7290            dtype=dtype),
7291        (5, 8, 10))
7292    yield SampleInput(
7293        torch.tensor(0, device=device, dtype=dtype),
7294        ())
7295
7296    a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]])
7297    b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]])
7298    _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True)
7299    yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape)
7300    yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape)
7301
7302    a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]])
7303    b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]])
7304    _, i1, i2 = np.intersect1d(a, b, return_indices=True)
7305    yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape)
7306    yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape)
7307
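# For orientation, a minimal hand-checked example of the op exercised above; the flat
# indices [3, 8, 13] in a (4, 5) shape all fall in column 3:
#   torch.unravel_index(torch.tensor([3, 8, 13]), (4, 5))
#   # -> (tensor([0, 1, 2]), tensor([3, 3, 3]))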
7308
7309def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs):
7310    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7311    cases = (((M, M), ()),
7312             ((M, M), (2,),),
7313             ((M, S), ()),
7314             ((M, S), (-1,)),
7315             ((M, M), (2,),),
7316             ((S, M, S), ()),
7317             ((S, M, S), (2,)),
7318             ((3, 3, S, S), ()),)
7319
7320    for shape, args in cases:
7321        yield SampleInput(make_arg(shape), args=args)
7322
7323def error_inputs_tril_triu(opinfo, device, **kwargs):
7324    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
7325
7326    # error inputs for input.ndim <= 2
7327    yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions")
7328
7329def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs):
7330    # (row, col, offset)
7331    args_list = ((0, 0),
7332                 (20, 0),
7333                 (0, 20),
7334                 (20, 21, 0),
7335                 (20, 21, 7),
7336                 (20, 21, -7),
7337                 # Large test cases below are deliberately commented out to speed up CI
7338                 # tests and to avoid OOM error. When modifying implementations of
7339                 # tril_indices and triu_indices, please enable these tests and make sure
7340                 # they pass.
7341                 # (2, 68435455, 3),
7342                 # (5000, 5000),
7343                 # (5000, 5000, 1234),
7344                 # (5000, 5000, -1233),
7345                 )
7346    for args in args_list:
7347        yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device})
7348
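# A small reference call for the indices ops sampled above (hand-checked, illustrative only):
#   torch.tril_indices(3, 3)
#   # -> tensor([[0, 1, 1, 2, 2, 2],
#   #            [0, 0, 1, 0, 1, 2]])
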
7349def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs):
7350    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7351
7352    yield SampleInput(make_arg((S, M, S)))
7353    yield SampleInput(make_arg(()))
7354
7355def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs):
7356    # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format
7357    # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous
7358    yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs)
7359
7360    shapes = (
7361        (3, 5, 6),
7362        (1, 1, 3, 5, 6),
7363        (1, 1, 3, 5, 6, 1, 1),
7364        (1, 0, 3, 5, 0, 2),
7365        (1, 0, 3, 5, 0, 0, 1, 1, 2),
7366        (),
7367    )
7368
7369    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7370    for shape in shapes:
7371        yield SampleInput(make_arg(shape))
7372        yield SampleInput(make_arg(shape).transpose(0, -1))
7373        yield SampleInput(make_arg(shape, noncontiguous=True))
7374        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))
7375
7376        yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format})
7377        yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})
7378        yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format})
7379        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format})
7380
7381    # shape, strides, offset
7382    strided_cases = (
7383        ((5, 6, 2), (1, 1, 7), 2),
7384        ((5, 5, 4), (1, 1, 7), 2),
7385        ((5, 5, 2), (4, 5, 7), 3),
7386        ((5, 5, 2), (5, 5, 7), 3),
7387        ((5, 5, 2), (5, 5, 5), 3),
7388        ((9, 5, 2), (0, 1, 7), 3),
7389    )
7390
7391    for shape, strides, offset in strided_cases:
7392        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset))
7393        yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format})
7394
7395    # channels last 2D
7396    yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last})
7397    a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2)
7398    yield SampleInput(a, kwargs={'memory_format': torch.channels_last})
7399
7400    # channels last 3D
7401    yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d})
7402    a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3)
7403    yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d})
7404
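# Sketch of why the permuted inputs above are interesting for clone/contiguous
# (shapes are hypothetical and hand-checked):
#   t = torch.empty(2, 3, 4, 5).permute(0, 3, 1, 2)    # logically (2, 5, 3, 4), non-contiguous
#   t.clone()                                           # preserve_format: keeps the permuted layout
#   t.clone(memory_format=torch.contiguous_format).is_contiguous()                               # True
#   t.clone(memory_format=torch.channels_last).is_contiguous(memory_format=torch.channels_last)  # True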
7405
7406def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs):
7407    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7408
7409    # list of tuples (shape, shape) defining the shapes of the input and output tensors
7410    sample_shapes = [
7411        ((), ()),
7412        ((S,), (1,)),
7413        ((S, S), (1, 1)),
7414        ((S, S), (1, S)),
7415        ((S, S), (S, S)),
7416        ((S, S, S), (S, 1, S)),
7417    ]
7418
7419    for input_shape, output_shape in sample_shapes:
7420        yield SampleInput(make_arg(input_shape), args=(output_shape,))
7421        if output_shape == ():
7422            continue
7423        yield SampleInput(make_arg(input_shape), args=(list(output_shape),))
7424        yield SampleInput(make_arg(input_shape), args=(*output_shape,))
7425
7426
7427def error_inputs_sum_to_size(op_info, device, **kwargs):
7428    shape = (M, S, M)
7429    err_msg = "is not expandable to size"
7430    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M))
7431    yield ErrorInput(si, error_regex=err_msg)
7432
7433    shape = (M + 1, S, S, M)
7434    err_msg = "is not expandable to size"
7435    si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1))
7436    yield ErrorInput(si, error_regex=err_msg)
7437
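# sum_to_size reduces dimensions that broadcasting would have expanded; a tiny
# hand-checked illustration of the success and failure modes tested above:
#   torch.ones(2, 3, 4).sum_to_size(2, 1, 4).shape   # torch.Size([2, 1, 4]); dim 1 summed out
#   torch.ones(2, 3, 4).sum_to_size(5, 4)            # raises: (5, 4) is not expandable to (2, 3, 4)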
7438
7439def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs):
7440    make_arg = partial(make_tensor, dtype=dtype, device=device)
7441    cases = (((S, S, S), (S * S, S)),
7442             ((), ()),
7443             ((), (1, 1, 1)),
7444             )
7445
7446    for shape, args_or_shape in cases:
7447        # Update `args` based on operator
7448        if op_info.name == 'resize_':
7449            # resize_ takes shape/tuple of ints,
7450            args = (args_or_shape, )
7451        elif op_info.name == 'resize_as_':
7452            # resize_as_ takes another tensor
7453            args = (make_arg(shape, requires_grad=False), )  # type:ignore[assignment]
7454        else:
7455            raise ValueError("sample_inputs_resize_ops is being used with an incorrect operator")
7456
7457        yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)
7458
7459def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs):
7460    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7461
7462    cases = (
7463        # a, b, is_tensor_supported
7464        ((S, S, S), (S * S, S), True),
7465        ((S * S, S), (S, S, S), True),
7466        ((S * S, S), (S, -1, S), False),  # neg index
7467        ((S * S * 2, S), (S, -1), False),  # neg index
7468        ((S,), (S,), True),
7469        ((), (), False),  # empty
7470        ((), (1,), True),
7471    )
7472
7473    for a, b, is_tensor_supported in cases:
7474        # skip unsupported cases
7475        if kwargs.get("tensor_arg") and not is_tensor_supported:
7476            continue
7477
7478        # convert to tensor
7479        if kwargs.get("tensor_arg"):
7480            b = make_arg(b, requires_grad=False)
7481
7482        yield SampleInput(make_arg(a), args=(b,))
7483
7484def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs):
7485    yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs)
7486
7487    cases = (
7488        # a, b, is_tensor_supported
7489        ((125,), (25, 5), True),
7490        ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True),
7491        ((16, 32), (2, 4, 1, 4, 4, 1, 4), True),
7492        ((16, 12), (12, 16), True),
7493        ((1, 16, 12), (12, 16), True),
7494        ((1, 5, 1, 5), (25, 1), True),
7495        ((2, 4, 2), (4, 4), True),
7496        ((1, 4), (1, 1, 2, 1, 2), True),
7497        ((3, 5, 7), (7, 5, 3), True),
7498        ((1,), (), False),  # empty
7499        ((5, 0, 2, 3), (5, 0, 2, 3), True),
7500        ((2, 1, 0, 3, 1), (5, 0), True),
7501        ((1,), (), False),  # empty
7502        ((4, 5, 6), (4, 5, 6, 1, 1, 1), True),
7503        ((), (1, 1, 1, 1), False),  # empty
7504    )
7505
7506    irreversible_cases = (
7507        ((), (-1,), False),  # neg index, empty
7508        ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False),  # neg index
7509    )
7510
7511    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7512    for a, b, is_tensor_supported in cases:
7513        # skip unsupported cases
7514        if kwargs.get("tensor_arg") and not is_tensor_supported:
7515            continue
7516
7517        if kwargs.get("tensor_arg"):
7518            # convert to tensor
7519            yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),))
7520            yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),))
7521        else:
7522            yield SampleInput(make_arg(a), args=(b,))
7523            yield SampleInput(make_arg(b), args=(a,))
7524
7525    for a, b, is_tensor_supported in irreversible_cases:
7526        # skip unsupported cases
7527        if kwargs.get("tensor_arg") and not is_tensor_supported:
7528            continue
7529
7530        # convert to tensor
7531        if kwargs.get("tensor_arg"):
7532            b = make_arg(b, requires_grad=False)
7533
7534        yield SampleInput(make_arg(a), args=(b,))
7535
7536def error_inputs_view_reshape(op, device, **kwargs):
7537
7538    cases = (
7539        # a, b, is_tensor_supported
7540        # Reshape to different numel
7541        ((2,), (), False),  # empty
7542        ((1, 3, 0), (), False),  # empty
7543        ((4, 3), (4, 2), True),
7544        ((1, 3, 5), (5, 2, 2), True),
7545        # No valid inference
7546        ((1, 3, 5), (5, -1, 2), False),  # neg index
7547        # Two inferred shapes
7548        ((1, 3, 5), (5, -1, -1), False),  # neg index
7549        ((1), (0, -1), False),  # neg index
7550        ((0, 5), (0, -1), False),  # neg index
7551    )
7552
7553    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
7554    for a, b, is_tensor_supported in cases:
7555        # skip unsupported cases
7556        if kwargs.get("tensor_arg") and not is_tensor_supported:
7557            continue
7558
7559        if b == (5, -1, -1):
7560            error_regex = "only one dimension can be inferred"
7561        elif a == (0, 5):
7562            error_regex = (r"cannot reshape tensor of 0 elements into shape "
7563                           r"\[0, -1\] because the unspecified dimension size "
7564                           r"-1 can be any value and is ambiguous")
7565        else:
7566            # to avoid having issues with a regex
7567            shape = ', '.join(map(str, b))
7568            size = a if type(a) is int else functools.reduce(operator.mul, a, 1)
7569            error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}"
7570
7571        # convert to tensor
7572        if kwargs.get("tensor_arg"):
7573            b = make_arg(b, requires_grad=False)
7574
7575        yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception,
7576                         error_regex=error_regex)
7577
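# Reminder of the shape-inference rule the -1 cases above exercise (values hypothetical):
#   torch.arange(12).reshape(3, -1).shape    # torch.Size([3, 4]); -1 inferred as 12 // 3
#   torch.arange(12).reshape(5, -1)          # raises: shape '[5, -1]' is invalid for input of size 12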
7578
7579def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
7580    input_list = []
7581    shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
7582    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7583    for shape in shapes:
7584        yield SampleInput(make_tensor_partial(shape))
7585    yield SampleInput([make_tensor_partial(shape) for shape in shapes])
7586
7587def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs):
7588    cases: Tuple[tuple, tuple] = (  # type: ignore[assignment]
7589        ((S, 2, 1), (S, 3, 1)),
7590        ((S), (S, 5)), ((), (1, S))
7591    )
7592    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7593    for shape1, shape2 in cases:
7594        yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)])
7595
7596def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs):
7597    shapes = ((S, S, S), (S, S), (S, ), (),)
7598    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7599    for shape in shapes:
7600        yield SampleInput(make_tensor_partial(shape))
7601        if len(shape) > 1:
7602            yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1)
7603
7604def reference_inputs_flatten(op, device, dtype, requires_grad, **kwargs):
7605    yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs)
7606
7607    # shape x start_dim x end_dim
7608    cases = (
7609        ((5, 4, 0, 1, 3, 7), 1, 3),
7610        ((5, 4, 0, 1, 3, 7), 4, 5),
7611        ((5, 4, 1, 1, 3, 7), 2, 3),
7612        ((), 0, -1),
7613        ((1,), 0, -1),
7614        ((3, 7, 5), 1, 2),
7615        ((4, 5), 1, 1),
7616        ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2),
7617        ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1),
7618        ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1),
7619        ((2, 4, 2), 0, 1),
7620        ((4, 2, 2), 1, 2),
7621        ((0, 3, 4, 5), 1, 3),
7622    )
7623
7624    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7625    for shape, start, end in cases:
7626        yield SampleInput(make_arg(shape), args=(start, end,))
7627        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,))
7628        yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,))
7629
7630def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
7631    # in_shape, dim, sizes
7632    args = (((8,), 0, (8,)),
7633            ((8,), 0, (4, 2)),
7634            ((8,), -1, (2, 2, 2)),
7635            ((8,), -1, (-1, 2)),
7636            ((3, 6, 2), 1, (2, 3)),
7637            ((3, 6, 2), -2, (2, 3)),
7638            ((3, 6, 2), -2, (-1, 3)),
7639            ((3, 2, 12), 2, (3, 2, 2)),
7640            ((4, 0), 0, (2, 2)),
7641            ((4, 0), 1, (2, 0, 0, 0)),
7642            )
7643    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7644    for in_shape, dim, sizes in args:
7645        yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes))
7646
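# unflatten splits one dimension into several; minimal hand-checked examples mirroring
# the (dim, sizes) pairs above:
#   torch.arange(8).unflatten(0, (4, 2)).shape     # torch.Size([4, 2])
#   torch.arange(8).unflatten(-1, (-1, 2)).shape   # torch.Size([4, 2]); -1 is inferred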
7647
7648def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
7649    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7650
7651    cases = (((S, S, S), (1, 2)),
7652             ((S, S, S), (-1, 2)),
7653             ((S, S, S), (-1, -1)),
7654             ((S, S, S), (1, -1)),
7655             ((S,), (0, 2))
7656             )
7657
7658    for shape, args in cases:
7659        yield SampleInput(make_arg(shape), args=args)
7660
7661
7662def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs):
7663    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7664
7665    cases = (((S, S, S), (S, S), (1, 2)),
7666             ((S, S, S), (S, S), (-1, 2)),
7667             ((S, S, S), (S, S), (-1, -1)),
7668             ((S, S, S), (S, S), (1, -1)),
7669             ((S,), (), (0, 2))
7670             )
7671
7672    for input_shape, src_shape, args in cases:
7673        input_ = make_arg(input_shape)
7674        src = make_arg(src_shape)
7675        yield SampleInput(input_, args=(src, *args))
7676
7677
7678def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
7679    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7680
7681    cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)),
7682             ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)),
7683             ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)),
7684             ((L, L, L), (L, L, L,), (1, 0, L, 1)),
7685             ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)),
7686             ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)),
7687             ((L, L, L), (L, L, L,), (2, 0, L, 1)),
7688             ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)),
7689             ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)),
7690             )
7691
7692    for input_shape, src_shape, args in cases:
7693        input_ = make_arg(input_shape)
7694        src = make_arg(src_shape)
7695        yield SampleInput(input_, args=(src, *args))
7696
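# slice_scatter writes `src` into a slice of `input` and returns a new tensor; a small
# hand-checked example matching the (dim, start, end, step) argument order above:
#   x, src = torch.zeros(4, 4), torch.ones(2, 4)
#   torch.slice_scatter(x, src, 0, 1, 3, 1)   # rows 1 and 2 become ones; x itself is unchanged
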
7697def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs):
7698    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7699
7700    cases = (((S, 1, 1), (S, S, S)),
7701             ((S, 1, S), (S, S, S)),
7702             ((S, 1, S), (-1, S, -1)),
7703             ((S, 1, S), (-1, S, S)),
7704             ((S, 1), (S, S, S)),
7705             ((1,), (S, S, S)),
7706             ((1, S), (1, 1, S)),
7707             ((), ()),
7708             ((), (1, 3, 2)),
7709             )
7710
7711    for case in cases:
7712        shape, args = case
7713        yield SampleInput(make_arg(shape), args=(args,))
7714
7715def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs):
7716    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7717
7718    shapes = ((),
7719              (2, 3))
7720    memory_format_options = [None, torch.contiguous_format]
7721
7722    for shape, memory_format in itertools.product(shapes, memory_format_options):
7723        yield SampleInput(make_arg(shape),
7724                          kwargs={'memory_format': memory_format} if memory_format else {})
7725    yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last})
7726
7727def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs):
7728    make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad)
7729
7730    shapes = ((),
7731              (2, 3))
7732    memory_format_options = [None, torch.contiguous_format]
7733
7734    for shape, memory_format in itertools.product(shapes, memory_format_options):
7735        yield SampleInput(make_arg(shape),
7736                          kwargs={'memory_format': memory_format} if memory_format else {})
7737    yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last})
7738
7739def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs):
7740    make_arg = partial(make_tensor, dtype=dtype, device=device)
7741
7742    cases = (((S, 1, 1), (S, S, S)),
7743             ((), ()),
7744             ((), (1, 1)),
7745             )
7746
7747    for shape, shape_other in cases:
7748        yield SampleInput(make_arg(shape, requires_grad=requires_grad),
7749                          args=(make_arg(shape_other, requires_grad=False),))
7750
7751
7752def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
7753    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7754
7755    def make_bool_mask(shape):
7756        # Make sure at least one element is nonzero,
7757        # except for empty tensor
7758        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
7759
7760        if mask_t.numel() == 0:
7761            return mask_t
7762        elif mask_t.numel() == 1:
7763            mask_t.fill_(True)
7764            return mask_t
7765
7766        if mask_t.sum() == 0:
7767            def random_index(shape):
7768                return tuple(random.randrange(0, max_idx) for max_idx in shape)
7769
7770            mask_t[random_index(mask_t.shape)] = True
7771            return mask_t
7772
7773        return mask_t
7774
7775    cases = (((M, M), (M, M), (M, M), False),
7776             ((M, 1, M), (M, M), (M, M, 1), True),
7777             ((), (), (), False),
7778             ((M, 1, M), (), (M, M, 1), True),
7779             ((), (M, M), (), True),
7780             ((), (2), (1, 1), True),
7781             )
7782
7783    for shape, mask_shape, other_shape, broadcasts_input in cases:
7784        yield SampleInput(make_arg(shape),
7785                          args=(make_bool_mask(mask_shape), make_arg(other_shape)),
7786                          broadcasts_input=broadcasts_input)
7787
7788# TODO: add reference inputs for where(condition) signature
7789def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
7790    yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs)
7791
7792    make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad)
7793    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7794
7795    # noncontiguous
7796    c = make_cond((10, 3), noncontiguous=True)
7797    a = make_arg((10, 1), noncontiguous=True)
7798    b = make_arg((3, 10, 3)).transpose(0, -1)
7799
7800    # NOTE that the OpInfo for where takes samples of the form a, cond, b
7801    yield SampleInput(a, args=(c, b))
7802
7803    # type promoting
7804    other_dtype = torch.double if dtype is not torch.double else torch.long
7805    c = make_cond((10, 3), noncontiguous=True)
7806    a = make_arg((10, 1), dtype=torch.long)
7807    b = make_arg((10, 1))
7808
7809    yield SampleInput(a, args=(c, b))
7810
7811    # two python scalars
7812    c = make_cond((10, 3), noncontiguous=True)
7813    a = make_arg((1,)).item()
7814    b = make_arg((1,)).item()
7815
7816    yield SampleInput(a, args=(c, b))
7817
7818    # NaN propagation
7819    if dtype.is_floating_point or dtype.is_complex:
7820        if dtype.is_floating_point:
7821            nan = float('nan')
7822        else:
7823            # dtype.is_complex
7824            nan = complex(float('nan'), float('nan'))
7825        c = make_cond((1, 10, 3))
7826        a = make_arg((10, 3), noncontiguous=True)
7827        a[2, 1] = nan
7828        b = make_arg((1, 3))
7829        b[0, 2] = nan
7830
7831        yield SampleInput(a, args=(c, b))
7832
7833    # Python scalars type promotion
7834    for scalar in (0, 0.0, 2j, False):
7835        yield SampleInput(scalar, args=(c, b))
7836        yield SampleInput(a, args=(c, scalar))
7837
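# The Python-scalar samples above rely on where() accepting scalars and promoting types;
# a minimal hand-checked illustration:
#   torch.where(torch.tensor([True, False]), 1.0, torch.tensor([2, 3]))
#   # -> tensor([1., 3.])  (integer tensor promoted to floating point)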
7838
7839def error_inputs_where(op_info, device, **kwargs):
7840    shape = (S,)
7841    err_msg = "Expected all tensors to be on the same device"
7842    for devices in product(('cpu', device), repeat=3):
7843        if len(set(devices)) == 2:
7844            si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
7845                             args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
7846                             make_tensor(shape, device=devices[2], dtype=torch.float32)))
7847            yield ErrorInput(si, error_regex=err_msg)
7848
7849def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
7850    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7851
7852    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
7853
7854    inputs = []
7855    for shape in sizes:
7856        # construct input without any non-zero elements
7857        zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
7858        inputs.append(zeros)
7859
7860        # construct input with mixed zero and non-zero elements
7861        mixed = make_arg(shape).requires_grad_(False)
7862        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
7863        mixed[mask_t] = 0
7864        inputs.append(mixed)
7865
7866    for input_t, as_tuple in product(inputs, [False, True]):
7867        yield SampleInput(input_t.clone().requires_grad_(requires_grad),
7868                          kwargs=dict(as_tuple=as_tuple))
7869
7870def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs):
7871    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7872
7873    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))
7874
7875    inputs = []
7876    for shape in sizes:
7877        # construct input without any non-zero elements
7878        zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
7879        inputs.append(zeros)
7880
7881        # construct input with mixed zero and non-zero elements
7882        mixed = make_arg(shape).requires_grad_(False)
7883        mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False)
7884        mixed[mask_t] = 0
7885        inputs.append(mixed)
7886
7887    nonzero_sizes = [0, 1, XS, S, M]
7888
7889    for input_t, nonzero_size in product(inputs, nonzero_sizes):
7890        yield SampleInput(input_t.clone().requires_grad_(requires_grad),
7891                          kwargs=dict(size=nonzero_size))
7892
7893def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
7894    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7895
7896    cases = (((S, S, S), (2,)),
7897             ((S, S, S), (S, 1)),
7898             ((S, S, S), (S, -1)))
7899
7900    for case in cases:
7901        shape, args = case
7902        yield SampleInput(make_arg(shape), args=args)
7903
7904def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs):
7905    yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs)
7906
7907    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
7908
7909    # shape x chunks x dim
7910    cases = (
7911        ((13, 9, 11), 17, -1),
7912        ((13, 9, 11), 11, -1),
7913        ((13,), 12, -1),
7914        ((15,), 12, -1),
7915        ((15,), 7, 0),
7916        ((15,), 9, 0),
7917        ((3, 7), 9, 1),
7918        ((3, 7), 9, 0),
7919        ((3, 7), 2, 0),
7920        ((3, 7), 3, 0),
7921        ((3, 7), 1, 0),
7922        ((3, 7), 1, 1),
7923        ((4, 4), 2, 0),
7924    )
7925
7926    for shape, chunks, dim in cases:
7927        yield SampleInput(make_arg(shape), args=(chunks, dim))
7928
7929def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs):
7930    def _tensor(shape, dtype=dtype, low=None, high=None):
7931        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
7932
7933    test_cases = [
7934        ((S, S, S), (2,)),
7935        ((S, S, S), (2, 1,)),
7936        ((S, S, S), (2, -1,)),
7937        ((S, S, S), (2, 1, True,)),
7938        ((S, S, S), (2, -1, True,)),
7939        ((S,), (2, 0,)),
7940        ((S,), (2, 0, True,)),
7941        ((), (1,)),
7942        ((), (1, 0,)),
7943        ((), (1, 0, True)),
7944    ]
7945
7946    yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases)
7947
7948def error_inputs_kthvalue(op_info, device, **kwargs):
7949    # tests overlapping output fails
7950    t = make_tensor(10, dtype=torch.float32, device=device)
7951    indices = torch.empty((), device=device, dtype=torch.long)
7952    yield ErrorInput(SampleInput(t, 5, out=(t, indices)),
7953                     error_regex="unsupported operation")
7954
7955    k_out_of_range_err = "selected number k out of range for dimension"
7956    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0),
7957                     error_regex=k_out_of_range_err)
7958    yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3),
7959                     error_regex=k_out_of_range_err)
7960    yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3),
7961                     error_regex=k_out_of_range_err)
7962
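# For reference, kthvalue returns the k-th *smallest* value along a dim, plus its index
# (hand-checked, illustrative only):
#   torch.kthvalue(torch.tensor([3., 1., 2.]), 2)
#   # -> values=tensor(2.), indices=tensor(2)
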
7963def sample_inputs_dropout(op_info, device, dtype, requires_grad, *,
7964                          train=None, valid_input_dim=None, **kwargs):
7965    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
7966
7967    if valid_input_dim:
7968        cases = ((S,) * i for i in valid_input_dim)
7969    else:
7970        cases = ((S, S), (S,), ())
7971    p_vals = [0.0, 0.5, 1.0]
7972    # This is to handle the special case for feature_alpha_dropout, which has different
7973    # supported dtypes depending on the `train` parameter
7974    training_vals = [train] if train is not None else [True, False]
7975
7976    for case, p, training in product(cases, p_vals, training_vals):
7977        yield SampleInput(make_arg(case), p=p, training=training)
7978    yield SampleInput(make_arg(case))
7979
7980def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs):
7981    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
7982    make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False)
7983
7984    cases = ((S, S, S, S), (S,), ())
7985    scale_vals = [0.0, 1.0, 2.0]
7986
7987    for case, scale in product(cases, scale_vals):
7988        yield SampleInput(make_arg(case), make_mask(case), scale)
7989
7990def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs):
7991    def make_input(shape):
7992        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
7993
7994    def make_long_input(shape, *, low, high, noncontiguous=False):
7995        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high,
7996                           noncontiguous=noncontiguous)
7997
7998    def make_per_sample_weight(flag, idx):
7999        # a tensor of float / double weights, or None
8000        # to indicate all weights should be taken to be 1
8001        if flag:
8002            return make_input(idx.shape)
8003        return None
8004
8005    offsets = torch.tensor([0, 3], device=device, dtype=torch.long)
8006    for generate_per_sample_weight in (True, False):
8007        for mode in ('sum', 'mean', 'max'):
8008            # per_sample_weights is only supported for mode='sum' (got mode='****')
8009            if generate_per_sample_weight and mode in ('mean', 'max'):
8010                continue
8011
8012            # 1-D index tensor
8013            idx = make_long_input((S,), low=0, high=M)
8014            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8015            yield SampleInput(make_input((M, S)), args=(idx,),
8016                              kwargs={'offsets': offsets, 'mode': mode,
8017                                      'per_sample_weights': per_sample_weights})
8018
8019            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
8020            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8021            yield SampleInput(make_input((M, S)), args=(idx,),
8022                              kwargs={'offsets': offsets, 'mode': mode,
8023                                      'per_sample_weights': per_sample_weights})
8024
8025            # bag with zero length
8026            idx = make_long_input((S,), low=0, high=M, noncontiguous=True)
8027            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8028            yield SampleInput(make_input((M, S)), args=(idx,),
8029                              kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long),
8030                                      'mode': mode,
8031                                      'per_sample_weights': per_sample_weights})
8032
8033            # 2-D index tensor
8034            idx = make_long_input((S, S), low=0, high=M)
8035            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8036            yield SampleInput(make_input((M, S)), args=(idx,),
8037                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})
8038
8039            idx = make_long_input((S, S), low=0, high=M, noncontiguous=True)
8040            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8041            yield SampleInput(make_input((M, S)), args=(idx,),
8042                              kwargs={'mode': mode, 'per_sample_weights': per_sample_weights})
8043
8044            # The gradient vector at `padding_idx` is not updated.
8045            # Negative padding_idx
8046            idx = make_long_input((6,), low=0, high=S)
8047            idx[0] = 4
8048            idx[4] = 4
8049            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8050            yield SampleInput(make_input((S, S)), args=(idx,),
8051                              kwargs={'padding_idx': -1, 'offsets': offsets,
8052                                      'mode': mode, 'per_sample_weights': per_sample_weights},)
8053
8054            idx = make_long_input((3, 3), low=0, high=S)
8055            # Positive padding_idx
8056            idx[0, 0] = 2
8057            idx[1, 1] = 2
8058            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8059            yield SampleInput(make_input((S, S)), args=(idx,),
8060                              kwargs={'padding_idx': 2, 'mode': mode,
8061                                      'per_sample_weights': per_sample_weights},)
8062
8063            idx = make_long_input((6, ), low=0, high=S)
8064            weights = make_input((S, S))
8065            offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long)
8066            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8067            yield SampleInput(weights, args=(idx,),
8068                              kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},)
8069
8070            if not requires_grad:
8071                # Following inputs return different gradient from the numerical gradient.
8072                # This is expected and relevant tests are present in `test_nn.py`.
8073
8074                # Due to inplace renorming of weight, the numerical gradient doesn't match the
8075                # analytical gradient.
8076                idx = make_long_input((2, 2), low=0, high=S)
8077                weights = make_input((S, S)) * 2
8078                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8079                yield SampleInput(weights, args=(idx,),
8080                                  kwargs={'max_norm': 1., 'mode': mode,
8081                                          'per_sample_weights': per_sample_weights},)
8082
8083                idx = make_long_input((6, ), low=0, high=S)
8084                weights = make_input((S, S)) * 2
8085                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8086                yield SampleInput(weights, args=(idx,),
8087                                  kwargs={'max_norm': 1., 'norm_type': 1.0,
8088                                          'mode': mode, 'offsets': offsets,
8089                                          'per_sample_weights': per_sample_weights},)
8090
8091                if mode != 'max':
8092                    # Scale the gradient based on the inverse frequency of a particular index.
8093                    # Note : max mode does not support sparse weights
8094                    idx = make_long_input((2, 2), low=0, high=S)
8095                    idx[0, 0] = 1
8096                    idx[0, 1] = 1
8097                    weights = make_input((S, S))
8098                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8099                    yield SampleInput(weights, args=(idx,),
8100                                      kwargs={'scale_grad_by_freq': True, 'mode': mode,
8101                                              'per_sample_weights': per_sample_weights},)
8102
8103                    # gradcheck not implemented for sparse tensors.
8104                    # Note : max mode does not support sparse weights
8105                    idx = make_long_input((6, ), low=0, high=S)
8106                    weights = make_input((S, S))
8107                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8108                    yield SampleInput(weights, args=(idx,),
8109                                      kwargs={'sparse': True, 'offsets': offsets,
8110                                              'mode': mode, 'per_sample_weights': per_sample_weights})
8111
8112                    idx = make_long_input((6, ), low=0, high=S)
8113                    idx[0] = 1  # freq more than 1
8114                    idx[1] = 1  # freq more than 1
8115                    idx[3] = 0  # padding_idx
8116                    weights = make_input((S, S)) * 2
8117                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
8118                    yield SampleInput(weights, args=(idx,),
8119                                      kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0,
8120                                              'max_norm': 1., 'offsets': offsets,
8121                                              'mode': mode, 'per_sample_weights': per_sample_weights})
8122
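# A minimal hand-checked embedding_bag call showing how `offsets` splits a 1-D index
# tensor into bags (the identity-matrix weight is purely for illustration):
#   torch.nn.functional.embedding_bag(torch.tensor([0, 1, 2, 1]), torch.eye(3),
#                                     offsets=torch.tensor([0, 2]), mode='sum')
#   # -> tensor([[1., 1., 0.],
#   #            [0., 1., 1.]])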
8123
8124def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs):
8125    def make_input(shape):
8126        return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
8127
8128    def make_long_input(shape, *, low, high):
8129        return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high)
8130
8131    # 0-D index tensor
8132    idx = make_long_input((), low=0, high=M)
8133    yield SampleInput(make_input((M, S)), args=(idx,),)
8134
8135    # 1-D index tensor
8136    idx = make_long_input((S,), low=0, high=M)
8137    yield SampleInput(make_input((M, S)), args=(idx,),)
8138
8139    # 2-D index tensor
8140    idx = make_long_input((S, S), low=0, high=M)
8141    yield SampleInput(make_input((M, S)), args=(idx,),)
8142
8143    if not requires_grad:
8144        # Following inputs return different gradient from the numerical gradient.
8145        # This is expected and relevant tests are present in `test_nn.py`.
8146
8147        # The gradient vector at `padding_idx` is not updated.
8148        idx = make_long_input((2, 2), low=0, high=S)
8149        idx[0, 0] = 2
8150        idx[1, 1] = 2
8151        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},)
8152
8153        idx = make_long_input((2, 2), low=0, high=S)
8154        idx[0, 0] = 4
8155        idx[1, 1] = 4
8156        yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},)
8157
8158        # Due to inplace renorming of weight, the numerical gradient doesn't match the
8159        # analytical gradient.
8160        idx = make_long_input((2, 2), low=0, high=S)
8161        weights = make_input((S, S)) * 2
8162        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},)
8163
8164        idx = make_long_input((2, 2), low=0, high=S)
8165        weights = make_input((S, S)) * 2
8166        yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},)
8167
8168        # Scale the gradient based on the inverse frequency of a particular index.
8169        idx = make_long_input((2, 2), low=0, high=S)
8170        idx[0, 0] = 1
8171        idx[0, 1] = 1
8172        weights = make_input((S, S))
8173        yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},)
8174
8175        # gradcheck not implemented for sparse tensors.
8176        idx = make_long_input((2, 2), low=0, high=S)
8177        weights = make_input((S, S))
8178        yield SampleInput(weights, args=(idx,), kwargs={'sparse': True})
8179
8180        idx = make_long_input((3, 3), low=0, high=S)
8181        idx[0, 0] = 1  # freq more than 1
8182        idx[0, 1] = 1  # freq more than 1
8183        idx[1, 0] = 0  # padding_idx
8184        weights = make_input((S, S)) * 2
8185        yield SampleInput(weights, args=(idx,),
8186                          kwargs={'sparse': True, 'scale_grad_by_freq': True,
8187                                  'padding_idx': 0, 'max_norm': 1.})
8188
8189
8190def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs):
8191    def make_input(shape, *, low, high):
8192        return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad)
8193
8194    shapes = ((), (S,), (L, M, S))
8195    num_classess = (-1, 10)
8196
8197    return (
8198        SampleInput(
8199            make_input(
8200                shape,
8201                low=0,
8202                high=10 if num_classes == -1 else num_classes // 2,
8203            ),
8204            kwargs=dict(num_classes=num_classes),
8205        )
8206        for shape, num_classes in itertools.product(shapes, num_classess)
8207    )
8208
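# For orientation, one_hot with an explicit class count (hand-checked, illustrative only):
#   torch.nn.functional.one_hot(torch.tensor([0, 2]), num_classes=3)
#   # -> tensor([[1, 0, 0],
#   #            [0, 0, 1]])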
8209
8210def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs):
8211    rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad)
8212    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8213
8214    # Although most losses also support the reduce and size_average combination instead of reduction, the former is
8215    # deprecated since 0.4.1 and thus is not tested
8216    shapes_and_kwargs = (
8217        ((), None),
8218        ((S,), dict(reduction="mean")),
8219        ((S,), dict(reduction="sum")),
8220        ((S,), dict(reduction="none")),
8221        ((S, S), None),
8222        ((S, S, S), None),
8223    )
8224
8225    for shape, kwargs in shapes_and_kwargs:
8226        yield SampleInput(_make_tensor(shape),
8227                          args=(_make_tensor(shape, requires_grad=rhs_requires_grad),),
8228                          kwargs=kwargs)
8229
8230def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
8231    # We get better tests if we change the range of the values to something like [-2,2]
8232    # because for grid (second tensor argument) the "useful" range is [-1,1] and this way
8233    # you get a better combination of out-of-range and in-range test cases
8234    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
8235                           low=-2, high=2)
8236
8237    batch_size = 2
8238    num_channels = 3
8239    modes = ("bilinear", "nearest")
8240    align_cornerss = (False, True)
8241    padding_modes = ("zeros", "border", "reflection")
8242
8243    for dim in (2, 3):
8244
8245        modes_ = (*modes, "bicubic") if dim == 2 else modes
8246
8247        for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss):
8248            yield SampleInput(
8249                _make_tensor((batch_size, num_channels, *[S] * dim)),
8250                _make_tensor((batch_size, *[S] * dim, dim)),
8251                mode=mode,
8252                padding_mode=padding_mode,
8253                align_corners=align_corners,
8254            )
8255
8256def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs):
8257
8258    batch_size = 2
8259    num_channels = 3
8260    height = 345
8261    width = 456
8262    modes = ("bilinear", "nearest", "bicubic")
8263    align_cornerss = (False, True)
8264    padding_modes = ('zeros', 'border', 'reflection')
8265
8266    # Create an affine transformation matrix
8267    a = torch.deg2rad(torch.tensor(45.0))
8268    ca, sa = torch.cos(a), torch.sin(a)  # cosine and sine of the rotation angle
8269    s1, s2 = 1.23, 1.34  # scales
8270
8271    theta = torch.tensor([[
8272        [ca / s1, sa, 0.0],
8273        [-sa, ca / s2, 0.0],
8274    ]], dtype=dtype, device=device)
8275    theta = theta.expand(batch_size, 2, 3).contiguous()
8276
8277    x = torch.arange(batch_size * num_channels * height * width, device=device)
8278    x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8)
8279    x = x.to(dtype=dtype)
8280    x.requires_grad_(requires_grad)
8281
8282    for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss):
8283        grid = torch.nn.functional.affine_grid(
8284            theta, size=(batch_size, num_channels, height, width), align_corners=align_corners
8285        )
8286        yield SampleInput(
8287            x,
8288            grid,
8289            mode,
8290            padding_mode,
8291            align_corners,
8292        )
8293
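# Sanity-check sketch of the affine_grid/grid_sample pairing used above: an identity
# theta should reproduce the input (hypothetical sizes, hand-checked):
#   theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])
#   grid = torch.nn.functional.affine_grid(theta, (1, 1, 2, 2), align_corners=False)
#   torch.nn.functional.grid_sample(torch.arange(4.).view(1, 1, 2, 2), grid, align_corners=False)
#   # -> the input reproduced (bilinear sampling at the original pixel centers)
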
8294def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs):
8295    # We get better tests if we change the range of the values to something like [-2,2]
8296    # because for grid (second tensor argument) the "useful" range is [-1,1] and this way
8297    # you get a better combination of out-of-range and in-range test cases
8298    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad,
8299                           low=-2, high=2)
8300
8301    batch_size = 2
8302    num_channels = 3
8303    modes = (0, 1, 2)
8304    align_cornerss = (False, True)
8305    padding_modes = (0, 1, 2)
8306
8307    for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss):
8308        yield SampleInput(
8309            _make_tensor((batch_size, num_channels, S, L)),
8310            _make_tensor((batch_size, M + 3, M, 2)),
8311            mode,
8312            padding_mode,
8313            align_corners,
8314        )
8315
8316def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
8317    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8318
8319    def make_target(shape):
8320        shape = () if len(shape) == 1 else (shape[0], )
8321        t = torch.randint(0, 2, shape, device=device, dtype=torch.long)
8322        # Label with -1 or 1
8323        t = t * 2 - 1
8324        target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad)
8325        return target
8326
8327    shapes = ((S, S), (S,))
8328    reductions = ('none', 'mean', 'sum')
8329    for s, r in product(shapes, reductions):
8330        yield SampleInput(
8331            make_input(s),
8332            args=(make_input(s), make_target(s)),
8333            kwargs=dict(reduction=r, margin=random.uniform(-1, 1))
8334        )
8335
8336def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs):
8337    input_length = 50
8338    batch = 16
8339    num_char = 20
8340    target_length = 30
8341
8342    def make_log_probs(s):
8343        t = make_tensor(s, device=device, dtype=dtype)
8344        log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad)
8345        return log_probs
8346
8347    reductions = ('none', 'mean', 'sum')
8348    zero_inf = (True, False)
8349    lengths_type = (list, torch.Tensor)
8350    for r, z, lt in product(reductions, zero_inf, lengths_type):
8351        log_probs = make_log_probs((input_length, batch, num_char))
8352        targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device)
8353        input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device)
8354        target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device)
8355
8356        # Don't generate int[] types if reduction == "mean", since this results in non-composite-compliant calls
8357        # to ctc_loss.IntList: a tensor needs to be created from the target lengths.
8358        # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor,
8359        # e.g. via std::copy. Similarly, symbolic/real tracing with fx will also not work.
8360        if lt is list and r in ["none", "sum"]:
8361            input_lengths = input_lengths.tolist()
8362            target_lengths = target_lengths.tolist()
8363
8364        yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), kwargs=dict(reduction=r, zero_infinity=z))
8365
8366def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
8367    shape = (2, 3)
8368    num_classes = shape[1]
8369    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8370    # FIXME: Derivative wrt. weight not implemented
8371    make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False)
8372
8373    def make_target(shape, zeros=False):
8374        s = (shape[0], *shape[2:]) if len(shape) > 1 else ()
8375        if zeros:
8376            return torch.zeros(s, device=device, dtype=torch.long)
8377        else:
8378            return make_tensor(s,
8379                               low=0,
8380                               high=shape[1] if len(shape) > 1 else shape[0],
8381                               device=device,
8382                               dtype=torch.long)
8383
8384
8385    def gen_shape_kwargs():
8386        # Batched, non-batched and 2d
8387        shapes = (shape, (num_classes,), shape + (2, 2))
8388        reductions = ('none', 'mean', 'sum')
8389        for reduction, s in product(reductions, shapes):
8390            yield make_input(s), make_target(s), dict(reduction=reduction)
8391            yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction)
8392            yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction)
8393            yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction)
8394            t = make_target(s)
8395            ignore = num_classes // 2
8396            # If "mean", nll returns NaN, so it's not differentiable at those points
8397            if t.eq(ignore).all() and reduction == "mean":
8398                t.fill_(0)
8399            yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction)
8400            yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight())
8401            # Test ignoring all the targets
8402            # If "mean", nll returns NaN, so it's not differentiable at those points
8403            if reduction != "mean":
8404                yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction)
8405
8406    for input, target, kwargs in gen_shape_kwargs():
8407        yield SampleInput(input, args=(target,), kwargs=kwargs)
8408
8409    target = torch.tensor([-1, 2], device=device, dtype=torch.long)
8410    yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1})
8411
8412
8413def sample_inputs_binary_cross_entropy_with_logits(
8414    op_info, device, dtype, requires_grad, **kwargs
8415):
8416    make = partial(make_tensor, device=device, dtype=dtype)
8417    make_prob = partial(make, low=0, high=1)
8418    reductions = ("mean", "sum", "none")
8419
8420    def make_weight_shape_kwargs():
8421        kwargs = []
8422        for shape in ((1,), (1, S), (S), (S, S)):
8423            kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions])
8424        return kwargs
8425
8426    shapes_and_kwargs = [
8427        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
8428        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
8429        *make_weight_shape_kwargs(),
8430        *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions],
8431        *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions],
8432    ]
8433
8434    for shape, kwargs in shapes_and_kwargs:
8435        yield SampleInput(
8436            make(shape, requires_grad=requires_grad),
8437            args=(make_prob(shape, requires_grad=requires_grad),),
8438            kwargs=kwargs,
8439        )
8440
8441def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs):
8442    yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad))
8443    mask = torch.tensor([[0, 1, 0, 1, 0],
8444                         [1, 1, 1, 1, 0],
8445                         [0, 0, 0, 1, 0],
8446                         [1, 0, 1, 1, 0],
8447                         [1, 0, 0, 1, 0]], dtype=torch.bool, device=device)
8448    t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad)
8449    t[mask] = 0
8450    yield SampleInput(t)
8451
8452    t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True)
8453    t[mask] = 0
8454    yield SampleInput(t)
8455
8456    t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad)
8457    yield SampleInput(t)
8458
8459    yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad))
8460    yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad))
8461
8462def _generate_sample_shape_reduction():
8463    shapes = ((S,), (S, S), (S, S, S))
8464    reductions = ('none', 'mean', 'sum')
8465    yield from product(shapes, reductions)
8466
8467def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
8468    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8469    # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0
8470    make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad)
8471
8472    def gen_shape(shape):
8473        yield shape
8474        # Broadcast
8475        yield (*shape[:-1], 1)
8476        yield shape[:-1]
8477
8478    def gen_shape_kwargs():
8479        for s, r in _generate_sample_shape_reduction():
8480            for t_s, v_s in product(gen_shape(s), gen_shape(s)):
8481                yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r)
8482                yield (
8483                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
8484                    dict(full=True, reduction=r)
8485                )
8486                yield (
8487                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
8488                    dict(eps=random.uniform(1e-6, 1e-3), reduction=r)
8489                )
8490                yield (
8491                    _make_tensor(s), _make_tensor(t_s), make_var(v_s),
8492                    dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r)
8493                )
8494
8495    for input, target, var, kwargs in gen_shape_kwargs():
8496        yield SampleInput(input, args=(target, var, ), kwargs=kwargs)
8497
8498def error_inputs_gaussian_nll_loss(op_info, device, **kwargs):
8499    _make = partial(make_tensor, device=device, dtype=torch.float32)
8500
8501    # invalid reduction value
8502    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"),
8503                     error_type=ValueError, error_regex="abc is not valid")
8504
8505    # var is of incorrect shape
8506    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)),
8507                     error_type=ValueError, error_regex="var is of incorrect size")
8508
8509    # target is of incorrect shape
8510    yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)),
8511                     error_type=RuntimeError,
8512                     error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) "
8513                                  r"at non-singleton dimension 2"))
8514
8515def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
8516    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8517
8518    for s, r in _generate_sample_shape_reduction():
8519        yield _make_tensor(s), _make_tensor(s), dict(reduction=r)
8520
8521def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs):
8522    for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
8523        # target should contain either 1 or -1 as per docs
8524        mask = torch.rand_like(target) > 0.5
8525        target[mask] = 1
8526        target[~mask] = -1
8527        d['margin'] = random.uniform(-9, 9)
8528        yield SampleInput(input, args=(target, ), kwargs=d)
8529
8530    # scalar input and target.
8531    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8532    yield SampleInput(_make_tensor(()), args=(_make_tensor(()), ))
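# For reference (a rough sketch): the samples above are consumed as
#   torch.nn.functional.hinge_embedding_loss(input, target, margin=1.0, reduction='mean')
# where, per the docs, the per-element loss is
#   l_i = x_i                      if y_i == 1
#   l_i = max(0, margin - x_i)     if y_i == -1
# which is why the targets are forced to contain only 1 and -1.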
8533
8534def error_inputs_hinge_embedding_loss(op, device, **kwargs):
8535    make_input = partial(make_tensor, device=device, dtype=torch.float32)
8536    # invalid reduction value
8537    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
8538                     error_type=ValueError, error_regex='is not a valid value')
8539
8540def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs):
8541    yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs)
8542    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8543
8544    for reduction in ('sum', 'mean', 'none'):
8545        if dtype.is_floating_point:  # the op only supports ints and floats; NaN/Inf cases only apply to floats
8546            # NaN propagation
8547            inp = make_input((10, ))
8548            inp[2] = float('nan')
8549            target = make_input((10, ))
8550            # target should contain either 1 or -1 as per docs
8551            mask = torch.rand_like(target) > 0.5
8552            target[mask] = -1
8553            target[~mask] = 1
8554            yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
8555
8556            # Inf Handling
8557            inp = make_input((10, ))
8558            inp[4] = float('inf')
8559            target = make_input((10, ))
8560            mask = torch.rand_like(target) > 0.5
8561            target[mask] = -1
8562            target[~mask] = 1
8563            yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
8564
8565        # Broadcasting
8566        inp = make_input((5, 5))
8567        target = make_input((1, 5))
8568        mask = torch.rand_like(target) > 0.5
8569        target[mask] = -1
8570        target[~mask] = 1
8571        yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction})
8572
8573def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs):
8574    for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs):
8575        d['delta'] = random.uniform(1e-3, 9)
8576        yield SampleInput(input, args=(target, ), kwargs=d)
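# For reference, torch.nn.functional.huber_loss(input, target, reduction='mean', delta=1.0)
# computes, per the docs,
#   l = 0.5 * (x - y)**2                   if |x - y| < delta
#   l = delta * (|x - y| - 0.5 * delta)    otherwise
# which is why the sampler above only draws strictly positive deltas.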
8577
8578def error_inputs_huber_loss(op, device, **kwargs):
8579    make_input = partial(make_tensor, device=device, dtype=torch.float32)
8580    # invalid reduction value
8581    err = 'is not a valid value for reduction'
8582    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
8583                     error_type=ValueError, error_regex=err)
8584    # delta <= 0
8585    for delta in (0, -1):
8586        err = 'huber_loss does not support non-positive values for delta.'
8587        yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}),
8588                         error_type=RuntimeError, error_regex=err)
8589
8590def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs):
8591    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8592
8593    def gen_shape_kwargs():
8594        for s, r in _generate_sample_shape_reduction():
8595            for li in (True, False):
8596                for f in (True, False):
8597                    i1 = _make_tensor(s)
8598                    i2 = _make_tensor(s)
8599                    # For Poisson NLL loss, the target is
8600                    # assumed to come from a Poisson
8601                    # distribution, whose samples are
8602                    # always non-negative
8603                    t1 = _make_tensor(s, low=0)
8604                    t2 = _make_tensor(s, low=0)
8605
8606                    if not li:
8607                        i1.abs_()
8608                        i2.abs_()
8609                    t1.abs_()
8610                    t2.abs_()
8611
8612                    yield (
8613                        i1, t1,
8614                        dict(log_input=li, full=f, reduction=r)
8615                    )
8616                    yield (
8617                        i2, t2,
8618                        dict(log_input=li, full=f,
8619                             eps=random.uniform(1e-8, 1e-3),
8620                             reduction=r)
8621                    )
8622
8623    for input, target, kwargs in gen_shape_kwargs():
8624        yield SampleInput(input, args=(target, ), kwargs=kwargs)
8625
8626    # test INT_TO_FLOAT promotion
8627    if dtype.is_complex:
8628        for d in (torch.bool, torch.int64):
8629            yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),))
8630            yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),))
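# For reference, torch.nn.functional.poisson_nll_loss(input, target, log_input=True, full=False,
#                                                     eps=1e-8, reduction='mean')
# computes, per the docs, exp(input) - target * input when log_input=True, and
# input - target * log(input + eps) when log_input=False (hence the .abs_() on the raw input above);
# full=True additionally adds the Stirling approximation term for the target.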
8631
8632def error_inputs_poisson_nll_loss(op_info, device, **kwargs):
8633    make = partial(make_tensor, device=device, dtype=torch.float32)
8634
8635    # invalid reduction value
8636    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
8637                     kwargs={'reduction': 'abc'}),
8638                     error_type=ValueError,
8639                     error_regex='abc is not a valid value for reduction')
8640    # invalid input shapes
8641    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
8642                     error_regex=(r'(Attempting to broadcast a dimension of length|'
8643                                  r'The size of tensor a \(5\) must match the '
8644                                  r'size of tensor b \(4\) at non-singleton '
8645                                  r'dimension 1)'))
8646
8647def error_inputs_soft_margin_loss(op_info, device, **kwargs):
8648    make = partial(make_tensor, device=device, dtype=torch.float32)
8649
8650    # invalid reduction value
8651    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
8652                     kwargs={'reduction': 'abc'}),
8653                     error_type=ValueError,
8654                     error_regex='abc is not a valid value for reduction')
8655    # invalid input shapes
8656    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
8657                     error_regex=(r'(Attempting to broadcast a dimension of length|'
8658                                  r'The size of tensor a \(4\) must match the '
8659                                  r'size of tensor b \(5\) at non-singleton '
8660                                  r'dimension 1)'))
8661
8662def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
8663    make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)
8664
8665    kwargss = (
8666        *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)],
8667        dict(swap=True),
8668        *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")],
8669    )
8670
8671    for kwargs in kwargss:
8672        input = make()
8673        args = (make(), make())
8674        if with_distance:
8675            kwargs["distance_function"] = torch.nn.PairwiseDistance()
8676        yield SampleInput(input, args=args, kwargs=kwargs)
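# For reference (a rough sketch): without with_distance the samples above target
#   torch.nn.functional.triplet_margin_loss(anchor, positive, negative, margin=1.0, swap=False,
#                                           reduction='mean')
# and with with_distance=True they target
#   torch.nn.functional.triplet_margin_with_distance_loss(anchor, positive, negative,
#       distance_function=torch.nn.PairwiseDistance(), margin=1.0, reduction='mean')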
8677
8678def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
8679    make_input = partial(make_tensor, device=device, dtype=torch.float32)
8680
8681    samples = (
8682        # input, args, kwargs, error_type, error_regex
8683        # invalid reduction
8684        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
8685         dict(reduction="abc"),
8686         ValueError, "abc is not a valid value for reduction"),
8687
8688        # invalid margin
8689        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
8690         dict(margin=-1.0),
8691         ValueError, "margin must be greater than 0, got -1.0"),
8692
8693        # shape mismatch
8694        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
8695         {},
8696         RuntimeError,
8697         (r'(Attempting to broadcast a dimension of length|'
8698          r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
8699          r"at non-singleton dimension 1)")),
8700        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
8701         {},
8702         RuntimeError,
8703         (r'(Attempting to broadcast a dimension of length|'
8704          r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
8705          r"at non-singleton dimension 1)")),
8706        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
8707         {},
8708         RuntimeError,
8709         (r'(Attempting to broadcast a dimension of length|'
8710          r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
8711          r"at non-singleton dimension 1)")),
8712
8713        # different dimensions
8714        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
8715         {},
8716         RuntimeError,
8717         (r"The anchor, positive, and negative tensors are expected to have "
8718          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
8719          r"and negative 2D inputs")),
8720        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
8721         {},
8722         RuntimeError,
8723         (r"The anchor, positive, and negative tensors are expected to have "
8724          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
8725          r"and negative 2D inputs")),
8726        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
8727         {},
8728         RuntimeError,
8729         (r"The anchor, positive, and negative tensors are expected to have "
8730          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
8731          r"and negative 1D inputs")),
8732    )
8733
8734    for input, args, kwargs, error_type, error_regex in samples:
8735        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
8736                         error_type=error_type, error_regex=error_regex)
8737
8738def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
8739    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
8740    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
8741    make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
8742    M, N, K = 15, 32, 16
8743    samples = []
8744    # two e4m3
8745    mat1 = make_mat_e4m3((M, K))
8746    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
8747    scale1 = make_scale((1,))
8748    scale2 = make_scale((1,))
8749    samples.append(SampleInput(mat1, mat2, scale1, scale2))
8750    # mat1 e4m3 mat2 e5m2
8751    mat1 = make_mat_e4m3((M, K))
8752    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
8753    scale1 = make_scale((1,))
8754    scale2 = make_scale((1,))
8755    samples.append(SampleInput(mat1, mat2, scale1, scale2))
8756    # mat1 e5m2 mat2 e4m3
8757    mat1 = make_mat_e5m2((M, K))
8758    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
8759    scale1 = make_scale((1,))
8760    scale2 = make_scale((1,))
8761    samples.append(SampleInput(mat1, mat2, scale1, scale2))
8762
8763    yield from samples
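# Note on the .t().contiguous().t() round trip above (a sketch): it keeps the (K, N) shape but
# produces a column-major tensor, which the fp8 _scaled_mm kernels expect for the second operand:
#   b = make_mat_e4m3((K, N))
#   b_cm = b.t().contiguous().t()   # still (K, N), but b_cm.stride() == (1, K)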
8764
8765def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
8766    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8767    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
8768    num_heads_q_gqa, num_heads_kv_gqa = 32, 8
8769
8770    dim_3_q_shape = (batch, seq_q, head_dim)
8771    dim_3_kv_shape = (batch, seq_kv, head_dim)
8772    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
8773    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
8774
8775    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
8776
8777    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
8778    samples = []
8779    gqa_options = [False] if TEST_WITH_ROCM else [True, False]  # TODO: GQA support
8780    if TEST_WITH_ROCM and dtype == torch.float32:
8781        causal_options = [False]  # FIXME: Large errors with causal+fp32
8782    else:
8783        causal_options = [True, False]
8784    for qkv_shape, is_causal, dropout_p, enable_gqa in product(
8785            qkv_shapes, causal_options, [0.0, 0.5], gqa_options):
8786        shape_q, shape_kv = qkv_shape
8787        samples.append(SampleInput(
8788            make(shape_q),
8789            make(shape_kv),
8790            make(shape_kv),
8791            is_causal=is_causal,
8792            dropout_p=dropout_p
8793        ))
8794
8795    # Add non-standard shapes
8796    diff_v_head_dim = SampleInput(
8797        make((batch, num_heads, seq_q, head_dim)),
8798        make((batch, num_heads, seq_kv, head_dim)),
8799        make((batch, num_heads, seq_kv, head_dim + 8)),
8800        is_causal=is_causal,
8801        dropout_p=dropout_p
8802    )
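    # NB: diff_v_head_dim above is constructed but never appended to `samples`, so it is not
    # currently yielded; it also reuses is_causal/dropout_p from the last loop iteration.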
8803
8804    # Add an attn_mask
8805    samples.append(
8806        SampleInput(
8807            make((batch, num_heads, seq_q, head_dim)),
8808            make((batch, num_heads, seq_kv, head_dim)),
8809            make((batch, num_heads, seq_kv, head_dim)),
8810            attn_mask=make((seq_q, seq_kv)),
8811            is_causal=False,
8812            dropout_p=0.0)
8813    )
8814
8815    if not TEST_WITH_ROCM:
8816        samples.append(
8817            SampleInput(
8818                make((batch, num_heads_q_gqa, seq_q, head_dim)),
8819                make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
8820                make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
8821                enable_gqa=True
8822            )
8823        )
8824
8825    yield from samples
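# Note on the grouped-query attention (GQA) sample above: with enable_gqa=True the query tensor
# carries num_heads_q_gqa=32 heads while key/value carry num_heads_kv_gqa=8, so groups of
# 32 / 8 = 4 query heads attend to the same key/value head.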
8826
8827
8828def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
8829    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8830    batch, num_heads, head_dim = 4, 4, 8
8831    seq_q = 11
8832    seq_kv = 32
8833
8834    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
8835    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
8836
8837    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
8838    samples = []
8839    mask_types = [1, 2]  # UpperLeft, LowerRight
8840    scales = [None, 1.0]
8841
8842    for qkv_shape, is_causal, dropout_p, mask_type, scale in product(
8843            qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales):
8844        shape_q, shape_kv = qkv_shape
8845        samples.append(SampleInput(
8846            make(shape_q).transpose(1, 2),
8847            make(shape_kv).transpose(1, 2),
8848            make(shape_kv).transpose(1, 2),
8849            bias=None,
8850            cu_seqlens_q=None,
8851            cu_seqlens_k=None,
8852            max_seqlen_q=None,
8853            max_seqlen_k=None,
8854            dropout_p=dropout_p,
8855            custom_mask_type=mask_type,
8856            compute_log_sumexp=requires_grad,
8857            scale=scale,
8858            seqlen_k=None
8859        ))
8860
8861    # Add non-standard shapes
8862    diff_v_head_dim = SampleInput(
8863        make((batch, seq_q, num_heads, head_dim)),
8864        make((batch, seq_kv, num_heads, head_dim)),
8865        make((batch, seq_kv, num_heads, head_dim + 8)),
8866        bias=None,
8867        cu_seqlens_q=None,
8868        cu_seqlens_k=None,
8869        max_seqlen_q=None,
8870        max_seqlen_k=None,
8871        dropout_p=dropout_p,
8872        custom_mask_type=0,  # No Mask
8873        compute_log_sumexp=requires_grad,
8874        scale=None,
8875        seqlen_k=None
8876    )
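    # NB: as in sample_inputs_scaled_dot_product_attention, this diff_v_head_dim SampleInput is
    # constructed but never appended to `samples`, so it is not currently yielded.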
8877
8878    # Add an attn_mask
8879    samples.append(
8880        SampleInput(
8881            make((batch, seq_q, num_heads, head_dim)),
8882            make((batch, seq_kv, num_heads, head_dim)),
8883            make((batch, seq_kv, num_heads, head_dim)),
8884            bias=make(batch, num_heads, seq_q, seq_kv),
8885            cu_seqlens_q=None,
8886            cu_seqlens_k=None,
8887            max_seqlen_q=None,
8888            max_seqlen_k=None,
8889            dropout_p=dropout_p,
8890            custom_mask_type=0,  # No Mask
8891            compute_log_sumexp=requires_grad,
8892            scale=None,
8893            seqlen_k=None
8894        )
8895    )
8896
8897    # jagged (with query/keys offsets)
8898    cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device)
8899    cu_seqlens_k[-1] = 62
8900    cu_seqlens_k[0] = 0
8901    samples.append(
8902        SampleInput(
8903            make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0),
8904            make((64, 64)).view(-1, 8, 8).unsqueeze(0),
8905            make((64, 64)).view(-1, 8, 8).unsqueeze(0),
8906            bias=None,
8907            cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device),
8908            cu_seqlens_k=cu_seqlens_k,
8909            max_seqlen_q=2,
8910            max_seqlen_k=2,
8911            dropout_p=0.0,
8912            custom_mask_type=0,  # No Mask
8913            compute_log_sumexp=requires_grad,
8914            scale=None,
8915            seqlen_k=None,
8916        )
8917    )
8918
8919    yield from samples
8920
8921def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs):
8922    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8923    batch, num_heads, head_dim = 4, 4, 8
8924    seq_q = 11
8925    seq_kv = 32
8926
8927    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
8928    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
8929
8930    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
8931    samples = []
8932    scales = [None, 1.0]
8933
8934    for qkv_shape, is_causal, dropout_p, scale in product(
8935            qkv_shapes, [True, False], [0.0, 0.5], scales):
8936        shape_q, shape_kv = qkv_shape
8937        samples.append(SampleInput(
8938            make(shape_q).transpose(1, 2),
8939            make(shape_kv).transpose(1, 2),
8940            make(shape_kv).transpose(1, 2),
8941            cum_seq_q=None,
8942            cum_seq_k=None,
8943            max_q=seq_q,
8944            max_k=seq_kv,
8945            dropout_p=dropout_p,
8946            is_causal=is_causal,
8947            return_debug_mask=False,
8948            scale=scale,
8949        ))
8950
8951    yield from samples
8952
8953def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
8954    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8955
8956    shape = (3,)
8957    batched_shape = (2, *shape)
8958    shapes_and_kwargs = [
8959        (shape, None),
8960        (batched_shape, None),
8961        (shape, dict(keepdim=True)),
8962        (batched_shape, dict(keepdim=True)),
8963        (shape, dict(p=5.0)),
8964        (shape, dict(p=-1.0)),
8965        (shape, dict(eps=1.0)),
8966    ]
8967
8968    return (
8969        SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs
8970    )
8971
8972def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
8973    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8974    yield from (
8975        SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor)
8976        for upscale_factor in (1, 3)
8977    )
8978    yield from (
8979        SampleInput(make_arg(shape), upscale_factor=1)
8980        for shape in [
8981            (1, 0, 1, 1),
8982            (1, 1, 0, 1),
8983            (1, 1, 1, 0),
8984        ]
8985    )
8986
8987def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs):
8988    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
8989    yield from (
8990        SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor)
8991        for downscale_factor in (1, 3)
8992    )
8993    yield from (
8994        SampleInput(make_arg(shape), downscale_factor=1)
8995        for shape in [
8996            (1, 0, 1, 1),
8997            (1, 1, 0, 1),
8998            (1, 1, 1, 0),
8999        ]
9000    )
9001
9002def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs):
9003    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
9004
9005    shapes_groups = [
9006        ((1, 4, 10, 10), 2),
9007        ((2, 6, 8, 8), 3),
9008        ((2, 8, 5, 5), 4),
9009    ]
9010
9011    yield from (
9012        SampleInput(make_arg(shape), args=(groups,))
9013        for shape, groups in shapes_groups
9014    )
9015
9016def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
9017    make = partial(make_tensor, device=device, dtype=dtype)
9018    # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps;
9019    # otherwise the perturbation calculation can push tensor values below zero, triggering
9020    # a device-side hardware assertion
9021    make_prob = partial(make, low=1e-6, high=1)
9022
9023    reductions = ("mean", "sum", "none")
9024
9025    shapes_and_kwargs = [
9026        *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
9027        *[((S, S), dict(reduction=reduction)) for reduction in reductions],
9028        *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
9029    ]
9030
9031    if logits:
9032        shapes_and_kwargs.extend(
9033            [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
9034        )
9035
9036    for shape, kwargs in shapes_and_kwargs:
9037        yield SampleInput(
9038            (make if logits else make_prob)(shape, requires_grad=requires_grad),
9039            args=(make_prob(shape, requires_grad=requires_grad),),
9040            kwargs=kwargs,
9041        )
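# For reference (a rough sketch): the shapes/kwargs above exercise calls of the form
#   torch.nn.functional.binary_cross_entropy(probs, target, weight=w, reduction='mean')
#   torch.nn.functional.binary_cross_entropy_with_logits(logits, target, weight=w,
#                                                        pos_weight=pw, reduction='mean')
# with probabilities kept in (1e-6, 1) so gradgradcheck's perturbations stay in-domain.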
9042
9043def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):
9044    sample_shapes = [(), (S,), (S, S, S)]
9045    atols = [1e-2, 1e-16]
9046    rtols = [1e-1, 0.5]
9047    eps = 1e-8
9048    for s, rtol, atol in product(sample_shapes, rtols, atols):
9049        # close sample
9050        t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
9051        close = (t + atol).detach().requires_grad_(requires_grad)
9052        yield SampleInput(t, close, rtol=rtol, atol=atol)
9053
9054        # random sample
9055        a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
9056        b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)
9057        yield SampleInput(a, b, rtol=rtol, atol=atol)
9058
9059
9060def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
9061    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
9062
9063    # test COMPLEX_TO_FLOAT promotion
9064    if dtype.is_complex:
9065        make = partial(make_tensor, (), device=device, requires_grad=requires_grad)
9066        yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),))
9067        yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),))
9068
9069def error_inputs_l1_loss(op_info, device, **kwargs):
9070    make = partial(make_tensor, device=device, dtype=torch.float32)
9071
9072    # invalid reduction value
9073    yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
9074                     kwargs={'reduction': 'abc'}),
9075                     error_type=ValueError,
9076                     error_regex='abc is not a valid value for reduction')
9077    # invalid input shapes
9078    yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
9079                     error_regex=(r'(Attempting to broadcast a dimension of length|'
9080                                  r'The size of tensor a \(4\) must match the '
9081                                  r'size of tensor b \(5\) at non-singleton '
9082                                  r'dimension 1)')
9083                     )
9084
9085def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
9086    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)
9087
9088    make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad)
9089
9090    # This test case always triggers the smooth condition, since the absolute difference of input and target
9091    # is smaller than beta
9092    yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5))
9093    yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0))
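# For reference, per the docs torch.nn.functional.smooth_l1_loss computes
#   l = 0.5 * (x - y)**2 / beta    if |x - y| < beta
#   l = |x - y| - 0.5 * beta       otherwise
# With the low/high bounds above, |input - target| < 4 < beta=5, so the quadratic branch is always
# taken; beta=0 degenerates to plain L1 loss.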
9094
9095def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs):
9096    # kl_div works with inputs in [0, 1] (aka the pdf of a probability measure)
9097    # Then log [0, 1] = (-inf, 0], so this is the log space
9098    make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad)
9099
9100    def make_log(shape):
9101        out = torch.nn.functional.log_softmax(make_arg(shape), -1)
9102        out.requires_grad_(requires_grad)
9103        return out
9104
9105    def make_prob(shape):
9106        out = torch.nn.functional.softmax(make_arg(shape), -1)
9107        out.requires_grad_(requires_grad)
9108        return out
9109
9110    shapes = ((2,), (2, 3))
9111    reductions = ("none", "mean", "batchmean", "sum")
9112    for shape, reduction, log_target in product(shapes, reductions, (True, False)):
9113        input = make_log(shape)
9114        target = make_log(shape) if log_target else make_prob(shape)
9115        yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target))
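# For reference, per the docs torch.nn.functional.kl_div(input, target, reduction, log_target)
# expects `input` in log-space and computes pointwise
#   target * (target.log() - input)    when log_target=False
#   target.exp() * (target - input)    when log_target=True
# 'batchmean' sums the result and divides by the batch size.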
9116
9117def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
9118    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
9119
9120    yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2))
9121    yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf")))
9122
9123def reference_pdist(input, p=2):
9124    pdist = scipy.spatial.distance.pdist
9125    if p == 0:
9126        output = pdist(input, "hamming") * input.shape[1]
9127    elif p == float("inf"):
9128        output = pdist(input, lambda x, y: np.abs(x - y).max())
9129    else:
9130        output = pdist(input, "minkowski", p=p)
9131    return output.astype(input.dtype)
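# Notes on the scipy mapping above: scipy's 'hamming' metric returns the fraction of differing
# coordinates, so multiplying by input.shape[1] converts it back into a count of mismatching
# coordinates for the p=0 reference, and the lambda reproduces the Chebyshev distance used for p=inf.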
9132
9133def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
9134    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
9135
9136    yield SampleInput(make_input(()))
9137    yield SampleInput(make_input((2,)))
9138    yield SampleInput(make_input((2, 2)))
9139    yield SampleInput(make_input((2,)), offset=1)
9140    yield SampleInput(make_input((2,)), offset=-1)
9141
9142def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
9143    unpool_name_to_pool_method_dict = {
9144        'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d,
9145        'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d,
9146        'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d
9147    }
9148
9149    unpool_name_to_dim = {
9150        'nn.functional.max_unpool1d': 1,
9151        'nn.functional.max_unpool2d': 2,
9152        'nn.functional.max_unpool3d': 3
9153    }
9154
9155    unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in unpool_name_to_pool_method_dict.items()}
9156
9157    pool_dim = unpool_name_to_dim[op_info.name]
9158    pool_method = unpool_name_to_pool_method_dict[op_info.name]
9159
9160    pool_op_info = copy.copy(op_info)
9161    pool_op_info.name = unpool_to_pool_name_dict[op_info.name]
9162
9163    for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
9164        # shapes (C, ...) do not work as of now,
9165        # see https://github.com/pytorch/pytorch/issues/68337
9166        # TODO: remove once the issue is resolved
9167        if sample.input.dim() != pool_dim + 2:
9168            continue
9169
9170        # No dilation > 1 for max_unpool,
9171        # see https://github.com/pytorch/pytorch/issues/68420
9172        if sample.kwargs['dilation'] != 1:
9173            continue
9174
9175        # Can't unpool without indices
9176        if sample.kwargs['return_indices']:
9177            pool, indices = pool_method(sample.input, **sample.kwargs)
9178            # arg has to be a leaf
9179            arg = pool.detach().requires_grad_(requires_grad)
9180            sample_kwargs = {
9181                'kernel_size': sample.kwargs['kernel_size'],
9182                'stride': sample.kwargs['stride'],
9183                'padding': sample.kwargs['padding'],
9184                # output_size could be None but we specify it explicitly
9185                # to compensate for the information lost in pooling due
9186                # to the floor/ceil operations used to compute the output shape
9187                'output_size': sample.input.size()
9188            }
9189
9190            yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)
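# A rough sketch of the pool -> unpool round trip these samples exercise (2d case):
#   out, indices = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
#   restored = torch.nn.functional.max_unpool2d(out, indices, kernel_size=2, stride=2,
#                                               output_size=x.size())
# Passing output_size pins down the input shape that the pooling floor/ceil arithmetic would
# otherwise leave ambiguous.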
9191
9192def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
9193    for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
9194        indices = sample.args[0]
9195        # The samples for max_unpool are generated with max_pool.
9196        # It could be that a single element from the max_pool's
9197        # input is mapped to several locations in its output.
9198        # This situation leads to failed gradchecks because
9199        # the finite difference algorithm perturbs the elements
9200        # of the output one by one, and not in equivalence
9201        # classes determined by whether two elements
9202        # in the output are coming from the same location in the
9203        # input (simply put, they have the same corresponding index).
9204        # So, there are two ways to resolve this issue:
9205        # 1. Extract a perturbation for one element and apply it all
9206        #    the elements from the same equivalence class, or
9207        # 2. Make sure that the equivalence classes are all singletons,
9208        # i.e. the index tensor has to be comprised of only unique
9209        # indices.
9210        # Here we go with solution 2, the easiest of all.
9211        if indices.unique().numel() == indices.numel():
9212            yield sample
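# Example of the duplicate-index case filtered out above:
#   x = torch.tensor([[[1., 5., 1.]]])
#   _, idx = torch.nn.functional.max_pool1d(x, kernel_size=2, stride=1, return_indices=True)
#   # idx == tensor([[[1, 1]]]): the same input element feeds two pooled outputs, which is
#   # exactly the situation that breaks the finite-difference gradcheck described above.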
9213
9214def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs):
9215    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
9216
9217    if requires_grad:
9218        # backward tests would take too long to complete, causing the job to time out.
9219        bsz = 2
9220        is_batcheds = (True,)
9221        use_separate_proj_weights = (False,)
9222        emb_sizes = (2,)
9223        src_lens = (XS,)
9224        tgt_lens = (XS,)
9225        heads = (2,)
9226        dropouts = (0.5,)
9227        mask_types = ("2d",)
9228    else:
9229        bsz = 2
9230        is_batcheds = (False, True)
9231        use_separate_proj_weights = (False, True)
9232        emb_sizes = (2, 4)
9233        src_lens = (XS,)
9234        tgt_lens = (XS, S)
9235        heads = (1, 2)
9236        dropouts = (0.0, 0.5)
9237        mask_types = (None, "2d", "3d")
9238
9239    for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product(
9240        is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts
9241    ):
9242        attn_mask = None
9243        if mask_type == "2d":
9244            attn_mask = make_input(src_len, tgt_len)
9245        elif mask_type == "3d":
9246            attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len)
9247
9248        if is_batched:
9249            q = make_input(src_len, bsz, emb_size)
9250            k = make_input(tgt_len, bsz, emb_size)
9251            v = make_input(tgt_len, bsz, emb_size)
9252        else:
9253            q = make_input(src_len, emb_size)
9254            k = make_input(tgt_len, emb_size)
9255            v = make_input(tgt_len, emb_size)
9256        if use_separate_proj_weight:
9257            in_proj_weight = None
9258            q_proj_weight = make_input(emb_size, emb_size)
9259            k_proj_weight = make_input(emb_size, emb_size)
9260            v_proj_weight = make_input(emb_size, emb_size)
9261        else:
9262            in_proj_weight = make_input(emb_size * 3, emb_size)
9263            q_proj_weight = None
9264            k_proj_weight = None
9265            v_proj_weight = None
9266
9267        bias_k = make_input(emb_size)
9268        bias_v = make_input(emb_size)
9269        in_proj_bias = make_input(emb_size * 3)
9270        out_proj_weight = make_input(emb_size, emb_size)
9271        out_proj_bias = make_input(emb_size)
9272        sample_args = (
9273            k, v, emb_size, num_heads, in_proj_weight,
9274            in_proj_bias, bias_k, bias_v, False,
9275            dropout_p, out_proj_weight, out_proj_bias
9276        )
9277        sample_kwargs = {
9278            "q_proj_weight" : q_proj_weight,
9279            "k_proj_weight" : k_proj_weight,
9280            "v_proj_weight" : v_proj_weight,
9281            "attn_mask" : attn_mask,
9282            "training" : dropout_p > 0.0,
9283            "use_separate_proj_weight" : use_separate_proj_weight
9284        }
9285
9286        yield SampleInput(q, args=sample_args, kwargs=sample_kwargs)
9287
9288
9289# Includes some values such that N * N won't be a multiple of 4,
9290# which should ensure we test the vectorized and non-vectorized
9291# kernel code paths.
9292NUM_SIZE0_TENSORS = 10000
9293foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300]
9294_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None}
9295
9296
9297class ForeachRightmostArgType(enum.Enum):
9298    TensorList = enum.auto()
9299    ScalarList = enum.auto()
9300    Scalar = enum.auto()
9301    Tensor = enum.auto()
9302
9303
9304class ForeachSampleInput(SampleInput):
9305    # For TensorList <op> Scalar/Tensor, we compute the reference
9306    # by converting it into TensorList <op> ScalarList/TensorList and
9307    # then decomposing that into multiple Tensor <op> Scalar/Tensor calls.
9308    # ref_args contains the args converted to TensorList <op> ScalarList/TensorList
9309    ref_args: Any
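    # For example (a rough sketch): a binary foreach op with a Scalar right-hand side may be
    # sampled as
    #   ForeachSampleInput(tensorlist, 2.0, ref_args=[[2.0] * len(tensorlist)])
    # so a reference result can be computed per-tensor as
    #   [torch.mul(t, s) for t, s in zip(tensorlist, ref_args[0])]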
9310    disable_fastpath: bool
9311
9312    def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs):
9313        super().__init__(*args, **kwargs)
9314        self.ref_args = ref_args or self.args
9315        self.disable_fastpath = disable_fastpath
9316
9317
9318class foreach_inputs_sample_func:
9319    def __init__(
9320        self,
9321        arity: int,
9322        rightmost_supports_scalar: bool,
9323        rightmost_supports_scalarlist: bool,
9324        rightmost_supports_tensor: bool = False,
9325    ) -> None:
9326        self.arity = arity
9327        self._set_rightmost_arg_types(
9328            rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor,
9329        )
9330        self._intersperse_empty = (True, False)
9331
9332    def _set_rightmost_arg_types(
9333        self,
9334        rightmost_supports_scalar: bool,
9335        rightmost_supports_scalarlist: bool,
9336        rightmost_supports_tensor: bool,
9337    ) -> None:
9338        self._rightmost_arg_types = [ForeachRightmostArgType.TensorList]
9339        if self.arity > 1:
9340            if rightmost_supports_scalar:
9341                self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar)
9342            if rightmost_supports_scalarlist:
9343                self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList)
9344            if rightmost_supports_tensor:
9345                self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor)
9346
9347    def _sample_rightmost_arg(
9348        self,
9349        opinfo,
9350        rightmost_arg_type,
9351        device,
9352        dtype,
9353        num_tensors,
9354        allow_higher_dtype_scalars,
9355        **_foreach_inputs_kwargs,
9356    ):
9357        if rightmost_arg_type == ForeachRightmostArgType.TensorList:
9358            return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)]
9359        if rightmost_arg_type == ForeachRightmostArgType.Tensor:
9360            return [make_tensor(
9361                (), device=device, dtype=dtype,
9362                noncontiguous=_foreach_inputs_kwargs["noncontiguous"],
9363                requires_grad=_foreach_inputs_kwargs.get("requires_grad", False),
9364            )]
9365        should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16)
9366
9367        def sample_float():
9368            s = random.random()
9369            if should_use_simpler_scalars:
9370                return 1.0 if s > 0.5 else 2.0
9371            else:
9372                return 1.0 - s
9373
9374        high = 2 if should_use_simpler_scalars else 9
9375        if rightmost_arg_type == ForeachRightmostArgType.ScalarList:
9376            scalarlist_list = []
9377            scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)])
9378
9379            if allow_higher_dtype_scalars or dtype.is_floating_point:
9380                scalarlist_list.append([sample_float() for _ in range(num_tensors)])
9381            if allow_higher_dtype_scalars or dtype.is_complex:
9382                scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)])
9383                scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)])
9384                scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)])
9385            return scalarlist_list
9386        if rightmost_arg_type == ForeachRightmostArgType.Scalar:
9387            scalars = []
9388            scalars.append(random.randint(1, high + 1))
9389            if allow_higher_dtype_scalars or dtype.is_floating_point:
9390                scalars.append(sample_float())
9391            if allow_higher_dtype_scalars or dtype.is_complex:
9392                scalars.append(complex(sample_float(), sample_float()))
9393            scalars.append(True)
9394            return scalars
9395        raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
9396
9397    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
9398        if self.arity == 1:
9399            if "foreach_abs" in opinfo.name and dtype in complex_types():
9400                return True
9401            # unary
9402            if opinfo.ref in (torch.abs, torch.neg):
9403                return False
9404            if opinfo.ref_inplace in (torch.Tensor.zero_,):
9405                return False
9406            return dtype in integral_types_and(torch.bool)
9407        if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor:
9408            return None
9409        if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool):
9410            return True
9411        if any(
9412                foreach_name in opinfo.name
9413                for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum")
9414        ) and dtype in integral_types_and(torch.bool):
9415            return True
9416        if rightmost_arg_type == ForeachRightmostArgType.TensorList:
9417            disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
9418            if "foreach_add" in opinfo.name and dtype == torch.bool:
9419                disable_fastpath = True
9420            return disable_fastpath
9421        elif rightmost_arg_type == ForeachRightmostArgType.Scalar:
9422            disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool)
9423            if isinstance(rightmost_arg, bool):
9424                disable_fastpath |= dtype == torch.bool
9425                if opinfo.ref in (torch.add, torch.mul):
9426                    disable_fastpath = False
9427            elif isinstance(rightmost_arg, int):
9428                disable_fastpath |= dtype == torch.bool
9429            elif isinstance(rightmost_arg, float):
9430                disable_fastpath |= dtype in integral_types_and(torch.bool)
9431            elif isinstance(rightmost_arg, complex):
9432                disable_fastpath |= dtype not in complex_types()
9433            else:
9434                raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}")
9435            return disable_fastpath
9436        elif rightmost_arg_type == ForeachRightmostArgType.ScalarList:
9437            disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool)
9438            elmt_t = type(rightmost_arg[0])
9439            has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg)
9440            if not has_same_type:
9441                return dtype not in complex_types()
9442            if isinstance(rightmost_arg[0], bool):
9443                if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool:
9444                    disable_fastpath = False
9445            elif isinstance(rightmost_arg[0], int):
9446                disable_fastpath |= dtype == torch.bool
9447            elif isinstance(rightmost_arg[0], float):
9448                disable_fastpath |= dtype in integral_types_and(torch.bool)
9449            elif isinstance(rightmost_arg[0], complex):
9450                disable_fastpath |= dtype not in complex_types()
9451            else:
9452                raise AssertionError(f"Invalid scalarlist of {rightmost_arg}")
9453            return disable_fastpath
9454        else:
9455            raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}")
9456
9457    def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
9458        kwargs = {}
9459        if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param:
9460            if dtype in integral_types_and(torch.bool):
9461                kwargs["alpha"] = 3
9462            elif dtype.is_complex:
9463                kwargs["alpha"] = complex(3, 3)
9464            else:
9465                kwargs["alpha"] = 3.14
9466        if self.arity > 1:
9467            kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype)
9468        return kwargs
9469
9470    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
9471        assert "num_input_tensors" not in kwargs
9472        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9473        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9474        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
9475        for rightmost_arg_type in self._rightmost_arg_types:
9476            zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs)
9477            zero_size_foreach_inputs_kwargs["zero_size"] = True
9478            input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
9479            if self.arity > 1:
9480                args = [
9481                    sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs)
9482                    for _ in range(self.arity - 2)
9483                ]
9484                args.append(
9485                    self._sample_rightmost_arg(
9486                        opinfo,
9487                        ForeachRightmostArgType.TensorList,
9488                        device,
9489                        dtype,
9490                        NUM_SIZE0_TENSORS,
9491                        allow_higher_dtype_scalars=allow_higher_dtype_scalars,
9492                        **zero_size_foreach_inputs_kwargs,
9493                    )[0])
9494                kwargs = self._sample_kwargs(
9495                    opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)
9496            else:
9497                args = []
9498                kwargs = {}
9499                if opinfo.ref in (torch.abs, torch.neg):
9500                    kwargs["disable_fastpath"] = False
9501                else:
9502                    kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool)
9503            yield ForeachSampleInput(input, *args, **kwargs)
9504
9505    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
9506        num_input_tensors_specified = "num_input_tensors" in kwargs
9507        num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
9508        assert isinstance(num_input_tensors, list)
9509        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9510        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9511        _foreach_inputs_kwargs["zero_size"] = False
9512        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
9513
9514        # add empty tensor interspersion to fully exercise the fix for #100701
9515        for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product(
9516                num_input_tensors, self._rightmost_arg_types, self._intersperse_empty):
9517            if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'):
9518                # only generate interspersed empty tensors for the largest N, and only on non-cpu devices, to lessen redundancy
9519                continue
9520            _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors
9521            input = sample_inputs_foreach(
9522                None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
9523            args = []
9524            if self.arity > 1:
9525                args = [
9526                    sample_inputs_foreach(
9527                        None, device, dtype, num_tensors, **_foreach_inputs_kwargs)
9528                    for _ in range(self.arity - 2)
9529                ]
9530                rightmost_arg_list = self._sample_rightmost_arg(
9531                    opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars,
9532                    **_foreach_inputs_kwargs)
9533                for rightmost_arg in rightmost_arg_list:
9534                    args.append(rightmost_arg)
9535                    kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)
9536                    ref_args = args
9537                    if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor):
9538                        ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]]
9539                    sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs)
9540                    yield sample
9541                    args.pop()
9542            else:
9543                yield ForeachSampleInput(
9544                    input,
9545                    *args,
9546                    disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype),
9547                )
9548
9549
9550class foreach_max_sample_func(foreach_inputs_sample_func):
9551    def __init__(
9552        self,
9553        arity: int,
9554        rightmost_supports_scalar: bool,
9555        rightmost_supports_scalarlist: bool,
9556        rightmost_supports_tensor: bool = False,
9557    ) -> None:
9558        super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor)
9559        self._intersperse_empty = (False,)
9560
9561    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
9562        return []
9563
9564    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
9565        return False
9566
9567
9568class foreach_norm_sample_func(foreach_inputs_sample_func):
9569    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
9570        assert "num_input_tensors" not in kwargs
9571        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9572        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9573        for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')):
9574            input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
9575            disable_fastpath = True
9576            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
9577                disable_fastpath = False
9578            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath)
9579
9580    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
9581        num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors)
9582        assert isinstance(num_input_tensors, list)
9583        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9584        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9585        _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
9586
9587        for num_tensors, ord, out_dtype in product(
9588            num_input_tensors,
9589            (0, 1, 2, -1, -2, float('inf'), float('-inf')),
9590            (None, torch.complex128) if dtype in complex_types() else (torch.float64,),
9591        ):
9592            input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
9593            disable_fastpath = True
9594            if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
9595                disable_fastpath = False
9596            yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype)
9597
9598        # Also test nan propagation with a single tensor, but skip autograd testing
9599        if not requires_grad:
9600            nan_inputs = [
9601                [float('nan')],
9602                [float('nan'), 1.0],
9603                [1.0, float('nan')],
9604                [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0],
9605                [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0],
9606                [3.0, float('nan'), float('nan'), -1.5, 6.0],
9607            ]
9608            for input in nan_inputs:
9609                x = torch.tensor(input, device=device)
9610                disable_fastpath = True
9611                if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16):
9612                    disable_fastpath = False
9613                yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath)
9614
9615
9616class foreach_pointwise_sample_func(foreach_inputs_sample_func):
9617
9618    def __init__(
9619        self,
9620        arity: int = 3,
9621        rightmost_supports_scalar: bool = False,
9622        rightmost_supports_scalarlist: bool = False,
9623    ):
9624        super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist)
9625
9626    def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
9627        return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,)
9628
9629    def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs):
9630        assert "num_input_tensors" not in kwargs
9631        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9632        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9633        # zero_size tensor
9634        input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
9635        args = [
9636            sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
9637            for _ in range(2)
9638        ]
9639        if "scalars" in kwargs:
9640            del kwargs["scalars"]
9641        kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype))
9642        yield ForeachSampleInput(input, *args, **kwargs)
9643
9644    def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):
9645        num_input_tensors_specified = "num_input_tensors" in kwargs
9646        num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors
9647        assert isinstance(num_input_tensors, list)
9648        _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()}
9649        _foreach_inputs_kwargs["requires_grad"] = requires_grad
9650        allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False)
9651
9652        for num_tensors, rightmost_arg_type in itertools.product(num_input_tensors, self._rightmost_arg_types):
9653            input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
9654            args = [
9655                sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs)
9656                for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList))
9657            ]
9658            rightmost_arg_list = self._sample_rightmost_arg(
9659                opinfo,
9660                rightmost_arg_type,
9661                device,
9662                dtype,
9663                num_tensors,
9664                zero_size=False,
9665                allow_higher_dtype_scalars=allow_higher_dtype_scalars,
9666                **_foreach_inputs_kwargs,
9667            )
9668            for rightmost_arg in rightmost_arg_list:
9669                kwargs = {}
9670                if rightmost_arg_type == ForeachRightmostArgType.TensorList:
9671                    args.append(rightmost_arg)
9672                elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]:
9673                    kwargs["scalars"] = rightmost_arg
9674                else:
9675                    kwargs["value"] = rightmost_arg
9676                kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype))
9677                assert len(args) == 2, f"{len(args)=}"
9678                sample = ForeachSampleInput(input, *args, **kwargs)
9679                yield sample
9680                if rightmost_arg_type == ForeachRightmostArgType.TensorList:
9681                    args.pop()
9682
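# Illustrative usage of the sampler above (mirroring the "addcmul" entry in
# foreach_pointwise_op_db further below):
#     sample_inputs_func=foreach_pointwise_sample_func(4, True, True)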
9683
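# OpInfo database for the unary foreach ops (_foreach_exp, _foreach_acos, ...); most
# entries mark the in-place TestMeta variants as expected failures for integral/bool
# dtypes.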
9684foreach_unary_op_db: List[OpInfo] = [
9685    ForeachFuncInfo(
9686        'exp',
9687        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9688        backward_requires_result=True,
9689        supports_autograd=True,
9690        supports_inplace_autograd=True,
9691        supports_forward_ad=True,
9692        decorators=(
9693            DecorateInfo(
9694                unittest.expectedFailure,
9695                "TestMeta",
9696                "test_dispatch_meta_inplace",
9697                dtypes=integral_types_and(torch.bool,),
9698            ),
9699            DecorateInfo(
9700                unittest.expectedFailure,
9701                "TestMeta",
9702                "test_dispatch_symbolic_meta_inplace",
9703                dtypes=integral_types_and(torch.bool,),
9704            ),
9705            DecorateInfo(
9706                unittest.expectedFailure,
9707                "TestMeta",
9708                "test_meta_inplace",
9709                dtypes=integral_types_and(torch.bool,),
9710            ),
9711        ),
9712    ),
9713    ForeachFuncInfo(
9714        'acos',
9715        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9716        supports_autograd=True,
9717        supports_inplace_autograd=True,
9718        supports_forward_ad=True,
9719        decorators=(
9720            DecorateInfo(
9721                unittest.expectedFailure,
9722                "TestMeta",
9723                "test_dispatch_meta_inplace",
9724                dtypes=integral_types_and(torch.bool,),
9725            ),
9726            DecorateInfo(
9727                unittest.expectedFailure,
9728                "TestMeta",
9729                "test_dispatch_symbolic_meta_inplace",
9730                dtypes=integral_types_and(torch.bool,),
9731            ),
9732            DecorateInfo(
9733                unittest.expectedFailure,
9734                "TestMeta",
9735                "test_meta_inplace",
9736                dtypes=integral_types_and(torch.bool,),
9737            ),
9738        ),
9739    ),
9740    ForeachFuncInfo(
9741        'asin',
9742        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9743        supports_autograd=True,
9744        supports_inplace_autograd=True,
9745        supports_forward_ad=True,
9746        decorators=(
9747            DecorateInfo(
9748                unittest.expectedFailure,
9749                "TestMeta",
9750                "test_dispatch_meta_inplace",
9751                dtypes=integral_types_and(torch.bool,),
9752            ),
9753            DecorateInfo(
9754                unittest.expectedFailure,
9755                "TestMeta",
9756                "test_dispatch_symbolic_meta_inplace",
9757                dtypes=integral_types_and(torch.bool,),
9758            ),
9759            DecorateInfo(
9760                unittest.expectedFailure,
9761                "TestMeta",
9762                "test_meta_inplace",
9763                dtypes=integral_types_and(torch.bool,),
9764            ),
9765        ),
9766    ),
9767    ForeachFuncInfo(
9768        'atan',
9769        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9770        supports_autograd=True,
9771        supports_inplace_autograd=True,
9772        supports_forward_ad=True,
9773        decorators=(
9774            DecorateInfo(
9775                unittest.expectedFailure,
9776                "TestMeta",
9777                "test_dispatch_meta_inplace",
9778                dtypes=integral_types_and(torch.bool,),
9779            ),
9780            DecorateInfo(
9781                unittest.expectedFailure,
9782                "TestMeta",
9783                "test_dispatch_symbolic_meta_inplace",
9784                dtypes=integral_types_and(torch.bool,),
9785            ),
9786            DecorateInfo(
9787                unittest.expectedFailure,
9788                "TestMeta",
9789                "test_meta_inplace",
9790                dtypes=integral_types_and(torch.bool,),
9791            ),
9792        ),
9793    ),
9794    ForeachFuncInfo(
9795        'cos',
9796        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9797        supports_autograd=True,
9798        supports_inplace_autograd=True,
9799        supports_forward_ad=True,
9800        decorators=(
9801            DecorateInfo(
9802                unittest.expectedFailure,
9803                "TestMeta",
9804                "test_dispatch_meta_inplace",
9805                dtypes=integral_types_and(torch.bool,),
9806            ),
9807            DecorateInfo(
9808                unittest.expectedFailure,
9809                "TestMeta",
9810                "test_dispatch_symbolic_meta_inplace",
9811                dtypes=integral_types_and(torch.bool,),
9812            ),
9813            DecorateInfo(
9814                unittest.expectedFailure,
9815                "TestMeta",
9816                "test_meta_inplace",
9817                dtypes=integral_types_and(torch.bool,),
9818            ),
9819        ),
9820    ),
9821    ForeachFuncInfo(
9822        'cosh',
9823        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9824        supports_autograd=True,
9825        supports_inplace_autograd=True,
9826        supports_forward_ad=True,
9827        decorators=(
9828            DecorateInfo(
9829                unittest.expectedFailure,
9830                "TestMeta",
9831                "test_dispatch_meta_inplace",
9832                dtypes=integral_types_and(torch.bool,),
9833            ),
9834            DecorateInfo(
9835                unittest.expectedFailure,
9836                "TestMeta",
9837                "test_dispatch_symbolic_meta_inplace",
9838                dtypes=integral_types_and(torch.bool,),
9839            ),
9840            DecorateInfo(
9841                unittest.expectedFailure,
9842                "TestMeta",
9843                "test_meta_inplace",
9844                dtypes=integral_types_and(torch.bool,),
9845            ),
9846        ),
9847    ),
9848    ForeachFuncInfo(
9849        'log',
9850        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9851        supports_autograd=True,
9852        supports_inplace_autograd=True,
9853        supports_forward_ad=True,
9854        decorators=(
9855            DecorateInfo(
9856                unittest.expectedFailure,
9857                "TestMeta",
9858                "test_dispatch_meta_inplace",
9859                dtypes=integral_types_and(torch.bool,),
9860            ),
9861            DecorateInfo(
9862                unittest.expectedFailure,
9863                "TestMeta",
9864                "test_dispatch_symbolic_meta_inplace",
9865                dtypes=integral_types_and(torch.bool,),
9866            ),
9867            DecorateInfo(
9868                unittest.expectedFailure,
9869                "TestMeta",
9870                "test_meta_inplace",
9871                dtypes=integral_types_and(torch.bool,),
9872            ),
9873        ),
9874    ),
9875    ForeachFuncInfo(
9876        'log10',
9877        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9878        supports_autograd=True,
9879        supports_inplace_autograd=True,
9880        supports_forward_ad=True,
9881        decorators=(
9882            DecorateInfo(
9883                unittest.expectedFailure,
9884                "TestMeta",
9885                "test_dispatch_meta_inplace",
9886                dtypes=integral_types_and(torch.bool,),
9887            ),
9888            DecorateInfo(
9889                unittest.expectedFailure,
9890                "TestMeta",
9891                "test_dispatch_symbolic_meta_inplace",
9892                dtypes=integral_types_and(torch.bool,),
9893            ),
9894            DecorateInfo(
9895                unittest.expectedFailure,
9896                "TestMeta",
9897                "test_meta_inplace",
9898                dtypes=integral_types_and(torch.bool,),
9899            ),
9900        ),
9901    ),
9902    ForeachFuncInfo(
9903        'log2',
9904        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9905        supports_autograd=True,
9906        supports_inplace_autograd=True,
9907        supports_forward_ad=True,
9908        decorators=(
9909            DecorateInfo(
9910                unittest.expectedFailure,
9911                "TestMeta",
9912                "test_dispatch_meta_inplace",
9913                dtypes=integral_types_and(torch.bool,),
9914            ),
9915            DecorateInfo(
9916                unittest.expectedFailure,
9917                "TestMeta",
9918                "test_dispatch_symbolic_meta_inplace",
9919                dtypes=integral_types_and(torch.bool,),
9920            ),
9921            DecorateInfo(
9922                unittest.expectedFailure,
9923                "TestMeta",
9924                "test_meta_inplace",
9925                dtypes=integral_types_and(torch.bool,),
9926            ),
9927        ),
9928    ),
9929    ForeachFuncInfo(
9930        'tan',
9931        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9932        backward_requires_result=True,
9933        supports_autograd=True,
9934        supports_inplace_autograd=True,
9935        supports_forward_ad=True,
9936        decorators=(
9937            # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex
9938            DecorateInfo(
9939                unittest.expectedFailure,
9940                "TestMeta",
9941                "test_dispatch_meta_inplace",
9942                dtypes=integral_types_and(torch.bool,),
9943            ),
9944            DecorateInfo(
9945                unittest.expectedFailure,
9946                "TestMeta",
9947                "test_dispatch_symbolic_meta_inplace",
9948                dtypes=integral_types_and(torch.bool,),
9949            ),
9950            DecorateInfo(
9951                unittest.expectedFailure,
9952                "TestMeta",
9953                "test_meta_inplace",
9954                dtypes=integral_types_and(torch.bool,),
9955            ),
9956            DecorateInfo(
9957                toleranceOverride(
9958                    {
9959                        torch.complex64: tol(atol=3e-04, rtol=2e-05)
9960                    }
9961                ),
9962                'TestForeach',
9963                'test_parity',
9964                device_type='cuda'
9965            ),
9966        ),
9967    ),
9968    ForeachFuncInfo(
9969        'tanh',
9970        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
9971        backward_requires_result=True,
9972        supports_autograd=True,
9973        supports_inplace_autograd=True,
9974        supports_forward_ad=True,
9975        decorators=(
9976            DecorateInfo(
9977                unittest.expectedFailure,
9978                "TestMeta",
9979                "test_dispatch_meta_inplace",
9980                dtypes=integral_types_and(torch.bool,),
9981            ),
9982            DecorateInfo(
9983                unittest.expectedFailure,
9984                "TestMeta",
9985                "test_dispatch_symbolic_meta_inplace",
9986                dtypes=integral_types_and(torch.bool,),
9987            ),
9988            DecorateInfo(
9989                unittest.expectedFailure,
9990                "TestMeta",
9991                "test_meta_inplace",
9992                dtypes=integral_types_and(torch.bool,),
9993            ),
9994            DecorateInfo(
9995                toleranceOverride(
9996                    {torch.complex64: tol(atol=5e-03, rtol=1e-04)}
9997                ),
9998                'TestForeach',
9999                'test_parity',
10000                device_type='cuda'
10001            ),
10002        ),
10003    ),
10004    ForeachFuncInfo(
10005        'sin',
10006        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10007        supports_autograd=True,
10008        supports_inplace_autograd=True,
10009        supports_forward_ad=True,
10010        decorators=(
10011            DecorateInfo(
10012                unittest.expectedFailure,
10013                "TestMeta",
10014                "test_dispatch_meta_inplace",
10015                dtypes=integral_types_and(torch.bool,),
10016            ),
10017            DecorateInfo(
10018                unittest.expectedFailure,
10019                "TestMeta",
10020                "test_dispatch_symbolic_meta_inplace",
10021                dtypes=integral_types_and(torch.bool,),
10022            ),
10023            DecorateInfo(
10024                unittest.expectedFailure,
10025                "TestMeta",
10026                "test_meta_inplace",
10027                dtypes=integral_types_and(torch.bool,),
10028            ),
10029        ),
10030    ),
10031    ForeachFuncInfo(
10032        'sinh',
10033        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10034        supports_autograd=True,
10035        supports_inplace_autograd=True,
10036        supports_forward_ad=True,
10037        decorators=(
10038            DecorateInfo(
10039                unittest.expectedFailure,
10040                "TestMeta",
10041                "test_dispatch_meta_inplace",
10042                dtypes=integral_types_and(torch.bool),
10043            ),
10044            DecorateInfo(
10045                unittest.expectedFailure,
10046                "TestMeta",
10047                "test_dispatch_symbolic_meta_inplace",
10048                dtypes=integral_types_and(torch.bool),
10049            ),
10050            DecorateInfo(
10051                unittest.expectedFailure,
10052                "TestMeta",
10053                "test_meta_inplace",
10054                dtypes=integral_types_and(torch.bool),
10055            ),
10056        ),
10057    ),
10058    ForeachFuncInfo(
10059        'neg',
10060        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10061        supports_autograd=True,
10062        supports_inplace_autograd=True,
10063        supports_forward_ad=True,
10064        decorators=(
10065            DecorateInfo(
10066                unittest.expectedFailure,
10067                "TestMeta",
10068                "test_dispatch_meta_inplace",
10069                dtypes=(torch.bool,),
10070            ),
10071            DecorateInfo(
10072                unittest.expectedFailure,
10073                "TestMeta",
10074                "test_dispatch_meta_outplace",
10075                dtypes=(torch.bool,),
10076            ),
10077            DecorateInfo(
10078                unittest.expectedFailure,
10079                "TestMeta",
10080                "test_dispatch_symbolic_meta_inplace",
10081                dtypes=(torch.bool,),
10082            ),
10083            DecorateInfo(
10084                unittest.expectedFailure,
10085                "TestMeta",
10086                "test_dispatch_symbolic_meta_outplace",
10087                dtypes=(torch.bool,),
10088            ),
10089            DecorateInfo(
10090                unittest.expectedFailure,
10091                "TestMeta",
10092                "test_meta_inplace",
10093                dtypes=(torch.bool,),
10094            ),
10095            DecorateInfo(
10096                unittest.expectedFailure,
10097                "TestMeta",
10098                "test_meta_outplace",
10099                dtypes=(torch.bool,),
10100            ),
10101            DecorateInfo(
10102                unittest.expectedFailure,
10103                "TestForeach",
10104                "test_unary_op_tensors_on_different_devices",
10105                device_type="cuda",
10106                dtypes=(torch.bool,),
10107            ),
10108        ),
10109    ),
10110    ForeachFuncInfo(
10111        'sqrt',
10112        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10113        supports_autograd=True,
10114        supports_inplace_autograd=True,
10115        supports_forward_ad=True,
10116        backward_requires_result=True,
10117        decorators=(
10118            DecorateInfo(
10119                unittest.expectedFailure,
10120                "TestMeta",
10121                "test_dispatch_meta_inplace",
10122                dtypes=integral_types_and(torch.bool),
10123            ),
10124            DecorateInfo(
10125                unittest.expectedFailure,
10126                "TestMeta",
10127                "test_dispatch_symbolic_meta_inplace",
10128                dtypes=integral_types_and(torch.bool),
10129            ),
10130            DecorateInfo(
10131                unittest.expectedFailure,
10132                "TestMeta",
10133                "test_meta_inplace",
10134                dtypes=integral_types_and(torch.bool),
10135            ),
10136        ),
10137    ),
10138    ForeachFuncInfo(
10139        'ceil',
10140        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10141        supports_autograd=True,
10142        supports_inplace_autograd=True,
10143        supports_forward_ad=True,
10144        decorators=(
10145            DecorateInfo(
10146                unittest.expectedFailure,
10147                "TestMeta",
10148                "test_dispatch_meta_inplace",
10149                dtypes=complex_types_and(torch.bool),
10150            ),
10151            DecorateInfo(
10152                unittest.expectedFailure,
10153                "TestMeta",
10154                "test_dispatch_meta_outplace",
10155                dtypes=complex_types_and(torch.bool),
10156            ),
10157            DecorateInfo(
10158                unittest.expectedFailure,
10159                "TestMeta",
10160                "test_dispatch_symbolic_meta_inplace",
10161                dtypes=complex_types_and(torch.bool),
10162            ),
10163            DecorateInfo(
10164                unittest.expectedFailure,
10165                "TestMeta",
10166                "test_dispatch_symbolic_meta_outplace",
10167                dtypes=complex_types_and(torch.bool),
10168            ),
10169            DecorateInfo(
10170                unittest.expectedFailure,
10171                "TestMeta",
10172                "test_meta_inplace",
10173                dtypes=complex_types_and(torch.bool),
10174            ),
10175            DecorateInfo(
10176                unittest.expectedFailure,
10177                "TestMeta",
10178                "test_meta_outplace",
10179                dtypes=complex_types_and(torch.bool),
10180            ),
10181            DecorateInfo(
10182                unittest.expectedFailure,
10183                "TestForeach",
10184                "test_autodiff",
10185                device_type="cuda",
10186                dtypes=(torch.complex128,),
10187            ),
10188        ),
10189    ),
10190    ForeachFuncInfo(
10191        'erf',
10192        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10193        supports_autograd=True,
10194        supports_inplace_autograd=True,
10195        supports_forward_ad=True,
10196        decorators=(
10197            DecorateInfo(
10198                unittest.expectedFailure,
10199                "TestMeta",
10200                "test_dispatch_meta_inplace",
10201                dtypes=integral_types_and(torch.bool) + complex_types(),
10202            ),
10203            DecorateInfo(
10204                unittest.expectedFailure,
10205                "TestMeta",
10206                "test_dispatch_meta_outplace",
10207                dtypes=complex_types(),
10208            ),
10209            DecorateInfo(
10210                unittest.expectedFailure,
10211                "TestMeta",
10212                "test_dispatch_symbolic_meta_inplace",
10213                dtypes=integral_types_and(torch.bool) + complex_types(),
10214            ),
10215            DecorateInfo(
10216                unittest.expectedFailure,
10217                "TestMeta",
10218                "test_dispatch_symbolic_meta_outplace",
10219                dtypes=complex_types(),
10220            ),
10221            DecorateInfo(
10222                unittest.expectedFailure,
10223                "TestMeta",
10224                "test_meta_inplace",
10225                dtypes=integral_types_and(torch.bool) + complex_types(),
10226            ),
10227            DecorateInfo(
10228                unittest.expectedFailure,
10229                "TestMeta",
10230                "test_meta_outplace",
10231                dtypes=complex_types(),
10232            ),
10233            DecorateInfo(
10234                unittest.expectedFailure,
10235                "TestForeach",
10236                "test_autodiff",
10237                device_type="cuda",
10238                dtypes=(torch.complex128,),
10239            ),
10240        ),
10241    ),
10242    ForeachFuncInfo(
10243        'erfc',
10244        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10245        supports_autograd=True,
10246        supports_inplace_autograd=True,
10247        supports_forward_ad=True,
10248        decorators=(
10249            DecorateInfo(
10250                unittest.expectedFailure,
10251                "TestMeta",
10252                "test_dispatch_meta_inplace",
10253                dtypes=integral_types_and(torch.bool) + complex_types(),
10254            ),
10255            DecorateInfo(
10256                unittest.expectedFailure,
10257                "TestMeta",
10258                "test_dispatch_meta_outplace",
10259                dtypes=complex_types(),
10260            ),
10261            DecorateInfo(
10262                unittest.expectedFailure,
10263                "TestMeta",
10264                "test_dispatch_symbolic_meta_inplace",
10265                dtypes=integral_types_and(torch.bool) + complex_types(),
10266            ),
10267            DecorateInfo(
10268                unittest.expectedFailure,
10269                "TestMeta",
10270                "test_dispatch_symbolic_meta_outplace",
10271                dtypes=complex_types(),
10272            ),
10273            DecorateInfo(
10274                unittest.expectedFailure,
10275                "TestMeta",
10276                "test_meta_inplace",
10277                dtypes=integral_types_and(torch.bool) + complex_types(),
10278            ),
10279            DecorateInfo(
10280                unittest.expectedFailure,
10281                "TestMeta",
10282                "test_meta_outplace",
10283                dtypes=complex_types(),
10284            ),
10285            DecorateInfo(
10286                unittest.expectedFailure,
10287                "TestForeach",
10288                "test_autodiff",
10289                device_type="cuda",
10290                dtypes=(torch.complex128,),
10291            ),
10292        ),
10293    ),
10294    ForeachFuncInfo(
10295        'expm1',
10296        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10297        supports_autograd=True,
10298        supports_inplace_autograd=True,
10299        supports_forward_ad=True,
10300        backward_requires_result=True,
10301        decorators=(
10302            DecorateInfo(
10303                unittest.expectedFailure,
10304                "TestMeta",
10305                "test_dispatch_meta_inplace",
10306                dtypes=integral_types_and(torch.bool),
10307            ),
10308            DecorateInfo(
10309                unittest.expectedFailure,
10310                "TestMeta",
10311                "test_dispatch_symbolic_meta_inplace",
10312                dtypes=integral_types_and(torch.bool),
10313            ),
10314            DecorateInfo(
10315                unittest.expectedFailure,
10316                "TestMeta",
10317                "test_meta_inplace",
10318                dtypes=integral_types_and(torch.bool),
10319            ),
10320        ),
10321    ),
10322    ForeachFuncInfo(
10323        'floor',
10324        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10325        supports_autograd=True,
10326        supports_inplace_autograd=True,
10327        supports_forward_ad=True,
10328        decorators=(
10329            DecorateInfo(
10330                unittest.expectedFailure,
10331                "TestMeta",
10332                "test_dispatch_meta_inplace",
10333                dtypes=complex_types_and(torch.bool),
10334            ),
10335            DecorateInfo(
10336                unittest.expectedFailure,
10337                "TestMeta",
10338                "test_dispatch_meta_outplace",
10339                dtypes=complex_types_and(torch.bool),
10340            ),
10341            DecorateInfo(
10342                unittest.expectedFailure,
10343                "TestMeta",
10344                "test_dispatch_symbolic_meta_inplace",
10345                dtypes=complex_types_and(torch.bool),
10346            ),
10347            DecorateInfo(
10348                unittest.expectedFailure,
10349                "TestMeta",
10350                "test_dispatch_symbolic_meta_outplace",
10351                dtypes=complex_types_and(torch.bool),
10352            ),
10353            DecorateInfo(
10354                unittest.expectedFailure,
10355                "TestMeta",
10356                "test_meta_inplace",
10357                dtypes=complex_types_and(torch.bool),
10358            ),
10359            DecorateInfo(
10360                unittest.expectedFailure,
10361                "TestMeta",
10362                "test_meta_outplace",
10363                dtypes=complex_types_and(torch.bool),
10364            ),
10365            DecorateInfo(
10366                unittest.expectedFailure,
10367                "TestForeach",
10368                "test_autodiff",
10369                device_type="cuda",
10370                dtypes=(torch.complex128,),
10371            ),
10372        ),
10373    ),
10374    ForeachFuncInfo(
10375        'log1p',
10376        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10377        supports_autograd=True,
10378        supports_inplace_autograd=True,
10379        supports_forward_ad=True,
10380        decorators=(
10381            DecorateInfo(
10382                unittest.expectedFailure,
10383                "TestMeta",
10384                "test_dispatch_meta_inplace",
10385                dtypes=integral_types_and(torch.bool),
10386            ),
10387            DecorateInfo(
10388                unittest.expectedFailure,
10389                "TestMeta",
10390                "test_dispatch_symbolic_meta_inplace",
10391                dtypes=integral_types_and(torch.bool),
10392            ),
10393            DecorateInfo(
10394                unittest.expectedFailure,
10395                "TestMeta",
10396                "test_meta_inplace",
10397                dtypes=integral_types_and(torch.bool),
10398            ),
10399        ),
10400    ),
10401    ForeachFuncInfo(
10402        'round',
10403        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10404        supports_autograd=True,
10405        supports_inplace_autograd=True,
10406        supports_forward_ad=True,
10407        decorators=(
10408            DecorateInfo(
10409                unittest.expectedFailure,
10410                "TestMeta",
10411                "test_dispatch_meta_inplace",
10412                dtypes=complex_types_and(torch.bool),
10413            ),
10414            DecorateInfo(
10415                unittest.expectedFailure,
10416                "TestMeta",
10417                "test_dispatch_meta_outplace",
10418                dtypes=complex_types_and(torch.bool),
10419            ),
10420            DecorateInfo(
10421                unittest.expectedFailure,
10422                "TestMeta",
10423                "test_dispatch_symbolic_meta_inplace",
10424                dtypes=complex_types_and(torch.bool),
10425            ),
10426            DecorateInfo(
10427                unittest.expectedFailure,
10428                "TestMeta",
10429                "test_dispatch_symbolic_meta_outplace",
10430                dtypes=complex_types_and(torch.bool),
10431            ),
10432            DecorateInfo(
10433                unittest.expectedFailure,
10434                "TestMeta",
10435                "test_meta_inplace",
10436                dtypes=complex_types_and(torch.bool),
10437            ),
10438            DecorateInfo(
10439                unittest.expectedFailure,
10440                "TestMeta",
10441                "test_meta_outplace",
10442                dtypes=complex_types_and(torch.bool),
10443            ),
10444            DecorateInfo(
10445                unittest.expectedFailure,
10446                "TestForeach",
10447                "test_autodiff",
10448                device_type="cuda",
10449                dtypes=(torch.complex128,),
10450            ),
10451        ),
10452    ),
10453    ForeachFuncInfo(
10454        'frac',
10455        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10456        supports_autograd=True,
10457        supports_inplace_autograd=True,
10458        supports_forward_ad=True,
10459        decorators=(
10460            DecorateInfo(
10461                unittest.expectedFailure,
10462                "TestMeta",
10463                "test_dispatch_meta_inplace",
10464                dtypes=integral_types_and(torch.bool) + complex_types(),
10465            ),
10466            DecorateInfo(
10467                unittest.expectedFailure,
10468                "TestMeta",
10469                "test_dispatch_meta_outplace",
10470                dtypes=integral_types_and(torch.bool) + complex_types(),
10471            ),
10472            DecorateInfo(
10473                unittest.expectedFailure,
10474                "TestMeta",
10475                "test_dispatch_symbolic_meta_inplace",
10476                dtypes=integral_types_and(torch.bool) + complex_types(),
10477            ),
10478            DecorateInfo(
10479                unittest.expectedFailure,
10480                "TestMeta",
10481                "test_dispatch_symbolic_meta_outplace",
10482                dtypes=integral_types_and(torch.bool) + complex_types(),
10483            ),
10484            DecorateInfo(
10485                unittest.expectedFailure,
10486                "TestMeta",
10487                "test_meta_inplace",
10488                dtypes=integral_types_and(torch.bool) + complex_types(),
10489            ),
10490            DecorateInfo(
10491                unittest.expectedFailure,
10492                "TestMeta",
10493                "test_meta_outplace",
10494                dtypes=integral_types_and(torch.bool) + complex_types(),
10495            ),
10496            DecorateInfo(
10497                unittest.expectedFailure,
10498                "TestForeach",
10499                "test_autodiff",
10500                device_type="cuda",
10501                dtypes=(torch.complex128,),
10502            ),
10503        ),
10504    ),
10505    ForeachFuncInfo(
10506        'reciprocal',
10507        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10508        supports_autograd=True,
10509        supports_inplace_autograd=True,
10510        supports_forward_ad=True,
10511        backward_requires_result=True,
10512        decorators=(
10513            DecorateInfo(
10514                unittest.expectedFailure,
10515                "TestMeta",
10516                "test_dispatch_meta_inplace",
10517                dtypes=integral_types_and(torch.bool),
10518            ),
10519            DecorateInfo(
10520                unittest.expectedFailure,
10521                "TestMeta",
10522                "test_dispatch_symbolic_meta_inplace",
10523                dtypes=integral_types_and(torch.bool),
10524            ),
10525            DecorateInfo(
10526                unittest.expectedFailure,
10527                "TestMeta",
10528                "test_meta_inplace",
10529                dtypes=integral_types_and(torch.bool),
10530            ),
10531        ),
10532    ),
10533    ForeachFuncInfo(
10534        'sigmoid',
10535        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10536        supports_autograd=True,
10537        supports_inplace_autograd=True,
10538        supports_forward_ad=True,
10539        backward_requires_result=True,
10540        decorators=(
10541            DecorateInfo(
10542                unittest.expectedFailure,
10543                "TestMeta",
10544                "test_dispatch_meta_inplace",
10545                dtypes=integral_types_and(torch.bool),
10546            ),
10547            DecorateInfo(
10548                unittest.expectedFailure,
10549                "TestMeta",
10550                "test_dispatch_symbolic_meta_inplace",
10551                dtypes=integral_types_and(torch.bool),
10552            ),
10553            DecorateInfo(
10554                unittest.expectedFailure,
10555                "TestMeta",
10556                "test_meta_inplace",
10557                dtypes=integral_types_and(torch.bool),
10558            ),
10559        ),
10560    ),
10561    ForeachFuncInfo(
10562        'trunc',
10563        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10564        supports_autograd=True,
10565        supports_inplace_autograd=True,
10566        supports_forward_ad=True,
10567        decorators=(
10568            DecorateInfo(
10569                unittest.expectedFailure,
10570                "TestMeta",
10571                "test_dispatch_meta_inplace",
10572                dtypes=complex_types_and(torch.bool),
10573            ),
10574            DecorateInfo(
10575                unittest.expectedFailure,
10576                "TestMeta",
10577                "test_dispatch_meta_outplace",
10578                dtypes=complex_types_and(torch.bool),
10579            ),
10580            DecorateInfo(
10581                unittest.expectedFailure,
10582                "TestMeta",
10583                "test_dispatch_symbolic_meta_inplace",
10584                dtypes=complex_types_and(torch.bool),
10585            ),
10586            DecorateInfo(
10587                unittest.expectedFailure,
10588                "TestMeta",
10589                "test_dispatch_symbolic_meta_outplace",
10590                dtypes=complex_types_and(torch.bool),
10591            ),
10592            DecorateInfo(
10593                unittest.expectedFailure,
10594                "TestMeta",
10595                "test_meta_inplace",
10596                dtypes=complex_types_and(torch.bool),
10597            ),
10598            DecorateInfo(
10599                unittest.expectedFailure,
10600                "TestMeta",
10601                "test_meta_outplace",
10602                dtypes=complex_types_and(torch.bool),
10603            ),
10604            DecorateInfo(
10605                unittest.expectedFailure,
10606                "TestForeach",
10607                "test_autodiff",
10608                device_type="cuda",
10609                dtypes=(torch.complex128,),
10610            ),
10611        ),
10612    ),
10613    ForeachFuncInfo(
10614        'abs',
10615        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10616        supports_autograd=True,
10617        supports_inplace_autograd=True,
10618        supports_forward_ad=True,
10619        supports_fwgrad_bwgrad=True,
10620        decorators=(
10621            DecorateInfo(
10622                unittest.expectedFailure,
10623                "TestMeta",
10624                "test_dispatch_symbolic_meta_inplace",
10625                dtypes=complex_types(),
10626            ),
10627            DecorateInfo(
10628                unittest.expectedFailure,
10629                "TestMeta",
10630                "test_dispatch_meta_inplace",
10631                dtypes=complex_types(),
10632            ),
10633            DecorateInfo(
10634                unittest.expectedFailure,
10635                "TestMeta",
10636                "test_dispatch_meta_outplace",
10637                device_type="cpu",
10638                dtypes=(torch.bool,),
10639            ),
10640            DecorateInfo(
10641                unittest.expectedFailure,
10642                "TestMeta",
10643                "test_dispatch_symbolic_meta_inplace",
10644                device_type="cpu",
10645                dtypes=(torch.bool,),
10646            ),
10647            DecorateInfo(
10648                unittest.expectedFailure,
10649                "TestMeta",
10650                "test_dispatch_symbolic_meta_outplace",
10651                device_type="cpu",
10652                dtypes=(torch.bool,),
10653            ),
10654            DecorateInfo(
10655                unittest.expectedFailure,
10656                "TestMeta",
10657                "test_meta_inplace",
10658                device_type="cpu",
10659                dtypes=(torch.bool,),
10660            ),
10661            DecorateInfo(
10662                unittest.expectedFailure,
10663                "TestMeta",
10664                "test_meta_outplace",
10665                device_type="cpu",
10666                dtypes=(torch.bool,),
10667            ),
10668            DecorateInfo(
10669                unittest.expectedFailure,
10670                "TestMeta",
10671                "test_dispatch_meta_inplace",
10672                device_type="cpu",
10673                dtypes=(torch.bool,),
10674            ),
10675            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()),
10676        ),
10677    ),
10678    ForeachFuncInfo(
10679        'zero',
10680        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10681        supports_autograd=True,
10682        supports_inplace_autograd=True,
10683        supports_forward_ad=True,
10684        supports_out=False,
10685    ),
10686    ForeachFuncInfo(
10687        'sign',
10688        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10689        supports_autograd=True,
10690        supports_inplace_autograd=True,
10691        supports_forward_ad=True,
10692        decorators=(
10693            DecorateInfo(
10694                unittest.expectedFailure,
10695                "TestMeta",
10696                "test_dispatch_meta_inplace",
10697                dtypes=complex_types(),
10698            ),
10699            DecorateInfo(
10700                unittest.expectedFailure,
10701                "TestMeta",
10702                "test_dispatch_meta_outplace",
10703                dtypes=complex_types(),
10704            ),
10705            DecorateInfo(
10706                unittest.expectedFailure,
10707                "TestMeta",
10708                "test_dispatch_symbolic_meta_inplace",
10709                dtypes=complex_types(),
10710            ),
10711            DecorateInfo(
10712                unittest.expectedFailure,
10713                "TestMeta",
10714                "test_dispatch_symbolic_meta_outplace",
10715                dtypes=complex_types(),
10716            ),
10717            DecorateInfo(
10718                unittest.expectedFailure,
10719                "TestMeta",
10720                "test_meta_inplace",
10721                dtypes=complex_types(),
10722            ),
10723            DecorateInfo(
10724                unittest.expectedFailure,
10725                "TestMeta",
10726                "test_meta_outplace",
10727                dtypes=complex_types(),
10728            ),
10729            DecorateInfo(
10730                unittest.expectedFailure,
10731                "TestForeach",
10732                "test_autodiff",
10733                device_type="cuda",
10734                dtypes=(torch.complex128,),
10735            ),
10736        ),
10737    ),
10738    ForeachFuncInfo(
10739        'lgamma',
10740        sample_inputs_func=foreach_inputs_sample_func(1, False, False),
10741        supports_autograd=True,
10742        supports_inplace_autograd=True,
10743        supports_forward_ad=True,
10744        decorators=(
10745            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
10746                         "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)),
10747            # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
10748            #              "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)),
10749            DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta",
10750                         "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
10751            DecorateInfo(
10752                unittest.expectedFailure,
10753                "TestMeta",
10754                "test_dispatch_meta_inplace",
10755                dtypes=complex_types() + integral_types_and(torch.bool),
10756            ),
10757            DecorateInfo(
10758                unittest.expectedFailure,
10759                "TestMeta",
10760                "test_dispatch_meta_outplace",
10761                dtypes=complex_types(),
10762            ),
10763            DecorateInfo(
10764                unittest.expectedFailure,
10765                "TestMeta",
10766                "test_dispatch_symbolic_meta_inplace",
10767                dtypes=complex_types() + integral_types_and(torch.bool),
10768            ),
10769            DecorateInfo(
10770                unittest.expectedFailure,
10771                "TestMeta",
10772                "test_dispatch_symbolic_meta_outplace",
10773                dtypes=complex_types(),
10774            ),
10775            DecorateInfo(
10776                unittest.expectedFailure,
10777                "TestMeta",
10778                "test_meta_inplace",
10779                dtypes=complex_types() + integral_types_and(torch.bool),
10780            ),
10781            DecorateInfo(
10782                unittest.expectedFailure,
10783                "TestMeta",
10784                "test_meta_outplace",
10785                dtypes=complex_types(),
10786            ),
10787            DecorateInfo(
10788                unittest.expectedFailure,
10789                "TestForeach",
10790                "test_autodiff",
10791                device_type="cuda",
10792                dtypes=(torch.complex128,),
10793            ),
10794        ),
10795    ),
10796]
10797
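# OpInfo database for the binary foreach ops (_foreach_add, _foreach_sub, _foreach_mul,
# _foreach_div, clamp_min/clamp_max, minimum/maximum, pow, and copy).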
10798foreach_binary_op_db: List[OpInfo] = [
10799    ForeachFuncInfo(
10800        "add",
10801        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
10802        supports_alpha_param=True,
10803        supports_autograd=True,
10804        supports_inplace_autograd=True,
10805        supports_forward_ad=True,
10806        decorators=(
10807            # These tests fail with aten._local_scalar_dense not being implemented.
10808            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
10809            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
10810                         dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
10811            # Samples have complex types and inplace only works if the dtype is complex.
10812            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10813                         dtypes=(torch.bool,)),
10814            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10815                         dtypes=(torch.bool,)),
10816            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10817                         dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)),
10818        ),
10819    ),
10820    ForeachFuncInfo(
10821        "sub",
10822        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
10823        supports_alpha_param=True,
10824        supports_autograd=True,
10825        supports_inplace_autograd=True,
10826        supports_forward_ad=True,
10827        decorators=(
10828            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
10829            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
10830            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
10831            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"),
10832            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"),
10833            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"),
10834            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
10835            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
10836        ),
10837    ),
10838    ForeachFuncInfo(
10839        "mul",
10840        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
10841        supports_autograd=True,
10842        supports_inplace_autograd=True,
10843        supports_forward_ad=True,
10844        decorators=(
10845            # Samples have complex types and inplace only works if the dtype is complex.
10846            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10847                         dtypes=(torch.bool,)),
10848            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10849                         dtypes=(torch.bool,)),
10850            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
10851            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10852                         dtypes=(torch.bool,)),
10853        ),
10854    ),
10855    ForeachFuncInfo(
10856        "div",
10857        sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
10858        supports_autograd=True,
10859        supports_inplace_autograd=True,
10860        supports_forward_ad=True,
10861        decorators=(
10862            # Samples have complex types and inplace only works if the dtype is complex.
10863            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10864                         dtypes=integral_types_and(torch.bool)),
10865            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10866                         dtypes=integral_types_and(torch.bool)),
10867            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
10868                         dtypes=integral_types_and(torch.bool)),
10869            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10870                         dtypes=integral_types_and(torch.bool)),
10871        ),
10872    ),
10873    ForeachFuncInfo(
10874        "clamp_min",
10875        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
10876        supports_autograd=True,
10877        supports_inplace_autograd=True,
10878        supports_forward_ad=True,
10879        decorators=(
10880            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10881                         dtypes=complex_types_and(torch.bool)),
10882            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10883                         dtypes=complex_types_and(torch.bool)),
10884            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
10885                         dtypes=complex_types_and(torch.bool)),
10886            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
10887                         dtypes=complex_types_and(torch.bool)),
10888            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
10889                         dtypes=complex_types_and(torch.bool)),
10890            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
10891                         dtypes=complex_types_and(torch.bool)),
10892            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10893                         dtypes=complex_types_and(torch.bool)),
10894            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
10895                         dtypes=complex_types_and(torch.bool)),
10896            DecorateInfo(
10897                unittest.expectedFailure,
10898                "TestForeach",
10899                "test_autodiff",
10900                device_type="cuda",
10901                dtypes=(torch.complex128,),
10902            ),
10903            DecorateInfo(
10904                unittest.expectedFailure,
10905                "TestForeach",
10906                "test_binary_op_scalar_with_overlapping_tensors",
10907                dtypes=complex_types(),
10908            ),
10909        ),
10910    ),
10911    ForeachFuncInfo(
10912        "clamp_max",
10913        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
10914        supports_autograd=True,
10915        supports_inplace_autograd=True,
10916        supports_forward_ad=True,
10917        decorators=(
10918            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10919                         dtypes=complex_types_and(torch.bool)),
10920            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10921                         dtypes=complex_types_and(torch.bool)),
10922            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
10923                         dtypes=complex_types_and(torch.bool)),
10924            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
10925                         dtypes=complex_types_and(torch.bool)),
10926            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
10927                         dtypes=complex_types_and(torch.bool)),
10928            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
10929                         dtypes=complex_types_and(torch.bool)),
10930            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10931                         dtypes=complex_types_and(torch.bool)),
10932            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
10933                         dtypes=complex_types_and(torch.bool)),
10934            DecorateInfo(
10935                unittest.expectedFailure,
10936                "TestForeach",
10937                "test_autodiff",
10938                device_type="cuda",
10939                dtypes=(torch.complex128,),
10940            ),
10941            DecorateInfo(
10942                unittest.expectedFailure,
10943                "TestForeach",
10944                "test_binary_op_scalar_with_overlapping_tensors",
10945                dtypes=complex_types(),
10946            ),
10947        ),
10948    ),
10949    # note(crcrpar): forward ad not implemented.
10950    ForeachFuncInfo(
10951        "minimum",
10952        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
10953        supports_autograd=True,
10954        supports_inplace_autograd=False,
10955        supports_forward_ad=False,
10956        decorators=(
10957            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10958                         dtypes=complex_types_and(torch.bool)),
10959            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10960                         dtypes=complex_types_and(torch.bool)),
10961            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
10962                         dtypes=complex_types_and(torch.bool)),
10963            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
10964                         dtypes=complex_types_and(torch.bool)),
10965            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
10966                         dtypes=complex_types_and(torch.bool)),
10967            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
10968                         dtypes=complex_types_and(torch.bool)),
10969            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
10970                         dtypes=complex_types_and(torch.bool)),
10971            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
10972                         dtypes=complex_types_and(torch.bool)),
10973            DecorateInfo(
10974                unittest.expectedFailure,
10975                "TestForeach",
10976                "test_autodiff",
10977                device_type="cuda",
10978                dtypes=(torch.complex128,),
10979            ),
10980            DecorateInfo(
10981                unittest.expectedFailure,
10982                "TestForeach",
10983                "test_binary_op_scalar_with_overlapping_tensors",
10984                dtypes=complex_types(),
10985            ),
10986        ),
10987    ),
10988    # note(crcrpar): forward ad not implemented.
10989    ForeachFuncInfo(
10990        "maximum",
10991        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
10992        supports_autograd=True,
10993        supports_forward_ad=False,
10994        supports_inplace_autograd=False,
10995        decorators=(
10996            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
10997                         dtypes=complex_types_and(torch.bool)),
10998            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
10999                         dtypes=complex_types_and(torch.bool)),
11000            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
11001                         dtypes=complex_types_and(torch.bool)),
11002            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
11003                         dtypes=complex_types_and(torch.bool)),
11004            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
11005                         dtypes=complex_types_and(torch.bool)),
11006            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
11007                         dtypes=complex_types_and(torch.bool)),
11008            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
11009                         dtypes=complex_types_and(torch.bool)),
11010            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
11011                         dtypes=complex_types_and(torch.bool)),
11012            DecorateInfo(
11013                unittest.expectedFailure,
11014                "TestForeach",
11015                "test_autodiff",
11016                device_type="cuda",
11017                dtypes=(torch.complex128,),
11018            ),
11019            DecorateInfo(
11020                unittest.expectedFailure,
11021                "TestForeach",
11022                "test_binary_op_scalar_with_overlapping_tensors",
11023                dtypes=complex_types(),
11024            ),
11025        ),
11026    ),
11027    ForeachFuncInfo(
11028        "pow",
11029        supports_alpha_param=False,
11030        supports_scalar_self_arg=True,
11031        sample_inputs_func=foreach_inputs_sample_func(2, True, True),
11032        supports_autograd=True,
11033        supports_inplace_autograd=True,
11034        supports_forward_ad=True,
11035        decorators=(
11036            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)),
11037            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
11038                         dtypes=(torch.bool,)),
11039            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
11040            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)),
11041            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
11042                         dtypes=(torch.bool,)),
11043            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
11044                         dtypes=(torch.bool,)),
11049            DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)),
11050            DecorateInfo(
11051                unittest.skip("failed starting on ROCm 6.2"),
11052                "TestForeach",
11053                "test_parity",
11054                device_type="cuda",
11055                dtypes=(torch.complex64,),
11056                active_if=TEST_WITH_ROCM),
11057            DecorateInfo(
11058                unittest.expectedFailure,
11059                "TestForeach",
11060                "test_binary_op_with_scalar_self_support",
11061                device_type="cuda",
11062                dtypes=(torch.bool,),
11063                active_if=lambda kwargs: kwargs["is_fastpath"],
11064            ),
11065        ),
11066        backward_requires_result=True,
11067    ),
11068    ForeachFuncInfo(
11069        "copy",
11070        sample_inputs_func=foreach_inputs_sample_func(2, False, False),
11071        supports_out=False,
11072        supports_forward_ad=False,
11073        supports_autograd=False,
11074        supports_inplace_autograd=False,
11075    )
11076]
11077
11078foreach_pointwise_op_db: List[ForeachFuncInfo] = [
11079    ForeachFuncInfo(
11080        "addcmul",
11081        sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
11082        supports_autograd=True,
11083        supports_inplace_autograd=True,
11084        supports_forward_ad=True,
11085        decorators=(
11086            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)),
11087            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
11088                         dtypes=(torch.bool,)),
11089            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)),
11090            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
11091                         dtypes=(torch.bool,)),
11092            # Samples have complex types and inplace only works if the dtype is complex.
11093            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)),
11094            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
11095                         dtypes=(torch.bool,)),
11096            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)),
11097            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
11098                         dtypes=integral_types() + complex_types_and(torch.bool)),
11099        ),
11100    ),
11101    ForeachFuncInfo(
11102        "addcdiv",
11103        sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
11104        supports_autograd=True,
11105        supports_inplace_autograd=True,
11106        supports_forward_ad=True,
11107        decorators=(
11108            # Samples have complex types and inplace only works if the dtype is complex.
11109            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace",
11110                         dtypes=integral_types_and(torch.bool)),
11111            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace",
11112                         dtypes=integral_types_and(torch.bool)),
11113            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace",
11114                         dtypes=integral_types_and(torch.bool)),
11115            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides",
11116                         dtypes=integral_types() + complex_types_and(torch.bool)),
11117            # fails because "div_cpu" is not implemented for ComplexHalf
11118            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace",
11119                         dtypes=integral_types_and(torch.bool)),
11120            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace",
11121                         dtypes=integral_types_and(torch.bool)),
11122            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace",
11123                         dtypes=integral_types_and(torch.bool)),
11124            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
11125                         dtypes=integral_types() + complex_types_and(torch.bool)),
11126        ),
11127    ),
11128]
11129
11130foreach_reduce_op_db: List[ForeachFuncInfo] = [
11131    ForeachFuncInfo(
11132        "max",
11133        sample_inputs_func=foreach_max_sample_func(1, False, False),
11134        supports_autograd=True,
11135        supports_inplace_autograd=True,
11136        supports_forward_ad=True,
11137        decorators=(
11138            # no complex support for ordering ops like max
11139            DecorateInfo(
11140                unittest.expectedFailure,
11141                "TestForeach",
11142                "test_autodiff",
11143                dtypes=(torch.complex128, torch.complex64),
11144            ),
11145            DecorateInfo(
11146                unittest.expectedFailure,
11147                "TestForeach",
11148                "test_foreach_reduce_large_input",
11149                dtypes=(torch.complex128, torch.complex64),
11150            ),
11151            DecorateInfo(
11152                unittest.expectedFailure,
11153                "TestMeta",
11154                "test_dispatch_symbolic_meta_outplace",
11155                dtypes=(torch.complex128, torch.complex64),
11156            ),
11157            DecorateInfo(
11158                unittest.expectedFailure,
11159                "TestMeta",
11160                "test_meta_outplace",
11161                dtypes=(torch.complex128, torch.complex64),
11162            ),
11163            DecorateInfo(
11164                unittest.expectedFailure,
11165                "TestMeta",
11166                "test_dispatch_meta_outplace",
11167                dtypes=(torch.complex128, torch.complex64),
11168            ),
11169        ),
11170    ),
11171    ForeachFuncInfo(
11172        "norm",
11173        sample_inputs_func=foreach_norm_sample_func(1, False, False),
11174        supports_autograd=True,
11175        supports_inplace_autograd=True,
11176        supports_forward_ad=True,
11177        decorators=(
11178            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"),
11179            DecorateInfo(
11180                unittest.expectedFailure,
11181                "TestMeta",
11182                "test_dispatch_meta_outplace",
11183                dtypes=integral_types_and(torch.bool),
11184            ),
11185            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"),
11186            DecorateInfo(
11187                unittest.expectedFailure,
11188                "TestMeta",
11189                "test_dispatch_symbolic_meta_outplace",
11190                dtypes=integral_types_and(torch.bool),
11191            ),
11192            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"),
11193            DecorateInfo(
11194                unittest.expectedFailure,
11195                "TestMeta",
11196                "test_meta_outplace",
11197                dtypes=integral_types_and(torch.bool),
11198            ),
11199            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"),
11200            DecorateInfo(
11201                unittest.expectedFailure,
11202                "TestForeach",
11203                "test_foreach_reduce_large_input",
11204                device_type="cuda",
11205                dtypes=integral_types_and(torch.bool),
11206            ),
11207        ),
11208    ),
11209]
11210
11211foreach_other_op_db: List[ForeachFuncInfo] = [
11212    ForeachFuncInfo(
11213        "lerp",
11214        sample_inputs_func=foreach_inputs_sample_func(3, True, False),
11215        supports_autograd=True,
11216        supports_inplace_autograd=True,
11217        supports_forward_ad=True,
11218        decorators=(
11219            DecorateInfo(
11220                unittest.expectedFailure,
11221                "TestMeta",
11222                "test_dispatch_meta_inplace",
11223                dtypes=integral_types_and(torch.bool),
11224            ),
11225            DecorateInfo(
11226                unittest.expectedFailure,
11227                "TestMeta",
11228                "test_dispatch_meta_outplace",
11229                dtypes=integral_types_and(torch.bool),
11230            ),
11231            DecorateInfo(
11232                unittest.expectedFailure,
11233                "TestMeta",
11234                "test_dispatch_symbolic_meta_outplace",
11235                dtypes=integral_types_and(torch.bool),
11236            ),
11237            DecorateInfo(
11238                unittest.expectedFailure,
11239                "TestMeta",
11240                "test_dispatch_symbolic_meta_inplace",
11241                dtypes=integral_types_and(torch.bool),
11242            ),
11243            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool)),
11244            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=integral_types_and(torch.bool)),
11245        ),
11246    ),
11247]
11248
11249def reference_sign(x):
11250    if x.dtype == np.bool_:
11251        # `np.sign` doesn't support `bool`.
11252        # >>> np.sign(True)
11253        # ufunc 'sign' did not contain a loop
11254        # with signature matching types dtype('bool') -> dtype('bool')
11255        return np.sign(x, dtype=np.uint8).astype(np.bool_)
11256    return np.sign(x)
11257
11258
11259def reference_sgn(x):
11260    # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex.
11261    # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j,
11262    # while `torch.sgn` returns 0 if abs(input) == 0 else input / abs(input).
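    # For example, for x = 3 + 4j the np.sign behavior described above yields (1 + 0j),
    # while torch.sgn yields x / abs(x) = 0.6 + 0.8j.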
11263    if x.dtype not in [np.complex64, np.complex128]:
11264        return reference_sign(x)
11265
11266    out = (x / np.abs(x))
11267    if out.ndim == 0:
11268        # Handle x == 0 case
11269        if (x == 0):
11270            # Can't assign into a NumPy complex scalar,
11271            # so make a new array instead.
11272            return np.array(complex(0, 0), dtype=x.dtype)
11273        return out
11274
11275    # Handle x == 0 case
11276    mask = (x == 0)
11277    out[mask] = complex(0, 0)
11278    return out
11279
11280
11281def reference_sigmoid(x):
11282    # `scipy.special.expit` doesn't support complex inputs, so use the definition directly.
11283    if x.dtype in [np.complex64, np.complex128]:
11284        return (1 / (1 + np.exp(-x)))
11285    return scipy.special.expit(x)
11286
11287
11288def reference_logsigmoid(x):
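    # log(sigmoid(x)) = -log1p(exp(-x)) = x - log1p(exp(x)); each branch below only ever
    # exponentiates a non-positive value, which keeps the computation overflow-free.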
11289    return np.where(
11290        x < 0,
11291        x - np.log1p(np.exp(x)),
11292        -np.log1p(np.exp(-x)))
11293
11294
11295def reference_hardsigmoid(x):
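    # hardsigmoid(x) = clamp(x / 6 + 1 / 2, 0, 1), implemented below with clip + where.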
11296    intermediate = x / 6 + 0.5
11297    y = np.clip(intermediate, 0, None)
11298    return np.where(y > 1, 1, y).astype(x.dtype)
11299
11300
11301def reference_lgamma(x):
11302    # scipy.special.gammaln returns `-inf` when the input is `-inf`,
11303    # while PyTorch, C, and C++ all return `inf` when the input is `-inf`.
11304    # Reference:
11305    # https://en.cppreference.com/w/cpp/numeric/math/lgamma
11306    # https://en.cppreference.com/w/c/numeric/math/lgamma
11307
11308    # To handle the above discrepancy,
11309    # we replace -inf with inf so values
11310    # that were originally -inf map to inf as expected
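    # e.g. scipy.special.gammaln(float('-inf')) evaluates to -inf, whereas
    # torch.lgamma maps -inf to inf; the substitution below reproduces the latter.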
11311    if x.dtype.kind == 'f':
11312        x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)
11313
11314    out = scipy.special.gammaln(x)
11315
11316    if x.dtype == np.float16:
11317        # `scipy.special.gammaln` returns float32 output when the input is float16,
11318        # while `torch.lgamma` preserves `float16`. Due to float16's smaller range,
11319        # the PyTorch version outputs `inf` where SciPy returns finite values.
11320        out = out.astype(np.float16)
11321
11322    return out
11323
11324
11325def reference_mvlgamma(x, d):
11326    if x.dtype == np.float16:
11327        return scipy.special.multigammaln(x, d).astype(np.float16)
11328
11329    return scipy.special.multigammaln(x, d)
11330
11331def reference_softplus(input, beta=1, threshold=20):
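    # softplus(x) = log(1 + exp(beta * x)) / beta; entries with input * beta > threshold
    # are numerically indistinguishable from the input itself, so they are left unchanged
    # (mirroring the threshold behavior of torch.nn.functional.softplus).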
11332    non_linear = input * beta <= threshold
11333    output = input.copy()
11334    output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta
11335    return output
11336
11337def reference_gelu(X, *, approximate='none'):
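    # GELU(x) = x * Phi(x), with Phi the standard normal CDF; the 'tanh' variant uses the
    # approximation 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).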
11338    def _gelu_ref(X):
11339        return X * stats.norm.cdf(X)
11340
11341    def _tanh_gelu_ref(X):
11342        M_SQRT_2_PI = math.sqrt(2 / math.pi)
11343        Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0))
11344        return 0.5 * X * (1.0 + np.tanh(Z))
11345
11346    if approximate == 'tanh':
11347        return _tanh_gelu_ref(X)
11348    else:
11349        return _gelu_ref(X)
11350
11351
11352def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray:
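    # e.g. reference_one_hot(np.array([0, 2]), num_classes=3)
    #      -> [[1, 0, 0], [0, 0, 1]] (same dtype as the input)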
11353    if num_classes == -1:
11354        num_classes = int(np.amax(a) + 1)
11355
11356    idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes
11357    one_hot = np.zeros((a.size, num_classes), dtype=a.dtype)
11358    np.put(one_hot, idcs, 1)
11359    return one_hot.reshape(*a.shape, -1)
11360
11361
11362def reference_mse_loss(input, target, reduction="mean"):
11363    se = (input - target) ** 2
11364    if reduction == "mean":
11365        return np.mean(se)
11366    elif reduction == "sum":
11367        return np.sum(se)
11368    else:  # reduction == "none"
11369        return se
11370
11371
11372def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
11373    return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0]
11374
11375
11376def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight, bias, eps):
11377    feature_size = np.prod(normalized_shape)
11378    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
11379    mean = inp_view.mean(axis=-1, keepdims=True)
11380    var = inp_view.var(axis=-1, ddof=0, keepdims=True)
11381    Y = (inp_view - mean) / np.sqrt(var + eps)
11382    if weight is None and bias is not None:
11383        Y = Y + bias.reshape(-1)
11384    elif weight is not None and bias is None:
11385        Y = Y * weight.reshape(-1)
11386    elif weight is not None and bias is not None:
11387        Y = Y * weight.reshape(-1) + bias.reshape(-1)
11388    axis = inp.ndim - len(normalized_shape)
11389    stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape)
11390    return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape)
11391
11392
11393def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None):
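    # RMSNorm: y = x / sqrt(mean(x ** 2) + eps), optionally scaled elementwise by weight.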
11394    if eps is None:
11395        eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps
11396    feature_size = np.prod(normalized_shape)
11397    inp_view = inp.reshape(-1, feature_size)  # type: ignore[call-overload]
11398    rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps)
11399    Y = inp_view / rms
11400    if weight is not None:
11401        Y = Y * weight.reshape(-1)
11402    return Y.reshape(*inp.shape)
11403
11404
11405def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5):
11406    inp_view = inp
11407    if np.prod(inp.shape) != 0:
11408        inp_view = inp.reshape((inp.shape[0], num_groups, -1))
11409    mean = inp_view.mean(axis=-1, keepdims=True)
11410    var = inp_view.var(axis=-1, ddof=0, keepdims=True)
11411    Y = (inp_view - mean) / np.sqrt(var + eps)
11412    Y = Y.reshape(inp.shape)
11413    if weight is not None:
11414        # weight is a vector whose length equals the number of channels
11415        if len(Y.shape) > 2:
11416            weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
11417        Y = Y * weight
11418    if bias is not None:
11419        # bias is a vector whose length equals the number of channels
11420        if len(Y.shape) > 2:
11421            bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)])
11422        Y = Y + bias
11423    return Y
11424
11425
11426# Using a custom reference function since numpy only has a string `side` arg (instead of both `right` and `side`) and
11427# doesn't have an `out_int32` arg. Additionally, numpy doesn't support searchsorted with N-D arrays, so this splits
11428# those into stacked 1-D cases
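# e.g. a (2, 5) sorted_sequence with a (2, 3) boundary is handled as two independent 1-D searches,
# with the per-row results stacked back into a (2, 3) output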
11429def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None):
11430    side = 'right' if (right or side == 'right') else 'left'
11431    if len(sorted_sequence.shape) == 1:
11432        ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter)
11433        return ret.astype(np.int32) if out_int32 else ret
11434    elif sorted_sequence.shape[0] == 0:
11435        if sorter is not None:
11436            sorter = sorter.flatten()
11437        ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter)
11438        ret = ret.astype(np.int32) if out_int32 else ret
11439        return ret.reshape(boundary.shape)
11440    else:
11441        # numpy searchsorted only supports 1D inputs so we split up ND inputs
11442        orig_shape = boundary.shape
11443        num_splits = np.prod(sorted_sequence.shape[:-1])
11444        splits = range(0, num_splits)
11445        sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1)
11446        if sorter is not None:
11447            sorter = sorter.reshape(num_splits, -1)
11448
11449        split_sequence = [sorted_sequence[i] for i in splits]
11450        split_boundary = [boundary[i] for i in splits]
11451        split_sorter = [sorter[i] if (sorter is not None) else None for i in splits]
11452
11453        split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort)
11454                     for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter)]
11455        split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret
11456        return np.stack(split_ret).reshape(orig_shape)
11457
11458def loss_reference_reduction_wrapper(fn):
11459    def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs):
11460        if size_average is not None or reduce is not None:
11461            raise RuntimeError(
11462                "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper"
11463            )
11464        output = fn(input, target, **other_kwargs)
11465        if reduction == "mean":
11466            return np.mean(output)
11467        elif reduction == "sum":
11468            return np.sum(output)
11469        else:  # reduction == "none"
11470            return output
11471
11472    return wrapper
11473
11474@loss_reference_reduction_wrapper
11475def reference_smooth_l1_loss(input, target, beta=1.0):
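    # Piecewise definition: |d| - beta / 2 where |d| >= beta, else d ** 2 / (2 * beta), with d = input - target.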
11476    diff = input - target
11477    abs_diff = np.abs(diff)
11478    above_threshold = abs_diff >= beta
11479
11480    loss = np.empty_like(input)
11481    loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta
11482    loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta)
11483
11484    return loss
11485
11486def reference_std_var(f):
11487    """Forwards unbiased/correction kwargs as NumPy's equivalent ddof"""
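    # unbiased=True maps to ddof=1 and unbiased=False to ddof=0; an explicit
    # `correction` value is forwarded to NumPy's ddof unchanged.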
11488    g = reference_reduction_numpy(f)
11489
11490    @wraps(g)
11491    def wrapper(x: np.ndarray, *args, **kwargs):
11492        assert not ('unbiased' in kwargs and 'correction' in kwargs)
11493
11494        if 'unbiased' in kwargs:
11495            kwargs['ddof'] = int(kwargs.pop('unbiased'))
11496        elif 'correction' in kwargs:
11497            kwargs['ddof'] = kwargs.pop('correction')
11498
11499        return g(x, *args, **kwargs)
11500
11501    return wrapper
11502
11503def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
11504    """Generates unbiased/correction kwargs for std/var operators"""
11505    yield ((), {'unbiased': True})
11506    yield ((), {'unbiased': False})
11507
11508    # Currently, calling std with correction is only enabled when
11509    # both dim and keepdim are provided.
11510    if 'dim' in kwargs and 'keepdim' in kwargs:
11511        yield ((), {'correction': 0})
11512        yield ((), {'correction': 1})
11513
11514        numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
11515        yield ((), {'correction': numel // 2})
11516
11517def error_inputs_mean(op_info, device, is_ref=False, **kwargs):
11518    if is_ref:
11519        err_msg1 = (r"mean\(\): could not infer output dtype. "
11520                    r"Input dtype must be either a floating point or complex dtype. "
11521                    r"Got: torch.int64")
11522    else:
11523        err_msg1 = (r"mean\(\): could not infer output dtype. "
11524                    r"Input dtype must be either a floating point or complex dtype. "
11525                    r"Got: Long")
11526    yield ErrorInput(
11527        SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []),
11528        error_regex=err_msg1,
11529    )
11530
11531    if is_ref:
11532        err_msg2 = (r"mean\(\): could not infer output dtype. "
11533                    r"Optional dtype must be either a floating point or complex dtype. "
11534                    r"Got: torch.int64")
11535    else:
11536        err_msg2 = (r"mean\(\): could not infer output dtype. "
11537                    r"Optional dtype must be either a floating point or complex dtype. "
11538                    r"Got: Long")
11539    yield ErrorInput(
11540        SampleInput(
11541            make_tensor((3, 4, 5), dtype=torch.float32, device=device),
11542            [],
11543            dtype=torch.int64),
11544        error_regex=err_msg2
11545    )
11546
11547    if is_ref:
11548        err_msg3 = "Expected out tensor to have dtype torch.float64, but got torch.float32 instead"
11549    else:
11550        err_msg3 = "Expected out tensor to have dtype double, but got float instead"
11551    yield ErrorInput(
11552        SampleInput(
11553            make_tensor((3, 4, 5), dtype=torch.int64, device=device),
11554            [],
11555            dtype=torch.float64,
11556            out=make_tensor([], dtype=torch.float32, device=device),
11557        ),
11558        error_regex=err_msg3
11559    )
11560
11561# numpy implementation of torch.flatten
11562# unfortunately there's no np.flatten with start/end dims, so we figure out the desired shape and call np.reshape
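# e.g. flattening a (2, 3, 4) input with start_dim=1, end_dim=-1 reshapes it to (2, 12)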
11563def reference_flatten(input, start_dim=0, end_dim=-1):
11564    in_shape = input.shape
11565    in_rank = len(in_shape)
11566    for d in start_dim, end_dim:
11567        if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank):
11568            raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d})")
11569    end_dim = end_dim if end_dim >= 0 else in_rank + end_dim
11570    start_dim = start_dim if start_dim >= 0 else in_rank + start_dim
11571    if in_rank == 0:
11572        end_dim = start_dim
11573    if end_dim < start_dim:
11574        raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim")
11575    flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1)
11576    out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:]
11577    return np.reshape(input, out_shape)
11578
11579
11580def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
11581    yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad))
11582    yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad))
11583
11584
11585# Operator database (sorted alphabetically)
11586op_db: List[OpInfo] = [
11587    UnaryUfuncInfo('abs',
11588                   aliases=('absolute', ),
11589                   ref=np.abs,
11590                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
11591                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
11592                   skips=(
11593                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
11594                                    'test_inplace_grad', dtypes=(torch.cdouble,)),
11595                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients',
11596                                    'test_inplace_gradgrad', dtypes=(torch.cdouble,)),
11597                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients',
11598                                    'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)),
11599                       DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs",
11600                                    "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
11601                       # Reference: https://github.com/pytorch/pytorch/issues/49224
11602                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
11603                                    dtypes=[torch.int8], active_if=TEST_WITH_ASAN),
11604                       # TODO: Fix test_out_arg_all_dtypes, which uses torch.empty_like(expected_output) where expected_output = op(input).
11605                       # This can break the logic of the loop over all possible types, but it is OK for now.
11606                       # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449
11607                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes',
11608                                    dtypes=[torch.cfloat, torch.cdouble]),
11609                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
11610                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
11611                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace',
11612                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
11613                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace',
11614                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
11615                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides',
11616                                    dtypes=(torch.cdouble, torch.cfloat, torch.chalf)),
11617                   ),
11618                   supports_fwgrad_bwgrad=True,
11619                   assert_autodiffed=True,
11620                   supports_sparse=True,
11621                   supports_sparse_csr=True,
11622                   supports_sparse_csc=True,
11623                   supports_sparse_bsr=True,
11624                   supports_sparse_bsc=True,
11625                   supports_forward_ad=True),
11626    # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
11627    UnaryUfuncInfo('acos',
11628                   aliases=('arccos', ),
11629                   ref=np.arccos,
11630                   domain=(-1, 1),
11631                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
11632                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
11633                   assert_autodiffed=True,
11634                   supports_forward_ad=True,
11635                   supports_fwgrad_bwgrad=True,
11636                   promotes_int_to_float=True,
11637                   decorators=(precisionOverride({torch.float16: 1e-2,
11638                                                  torch.bfloat16: 1e-1,
11639                                                  torch.complex64: 1e-2}),),
11640                   skips=(
11641                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
11642                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11643                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
11644                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11645                       # Failing with wrong imaginary sign on at least some Windows jobs
11646                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
11647                                    device_type='cuda', dtypes=[torch.cdouble],
11648                                    active_if=IS_WINDOWS),
11649                       # Failing with wrong imaginary sign on at least some Windows jobs
11650                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
11651                                    device_type='cuda', dtypes=[torch.cdouble],
11652                                    active_if=IS_WINDOWS),
11653                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
11654                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
11655                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
11656                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
11657                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad',
11658                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11659                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad',
11660                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11661                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad',
11662                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11663                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
11664                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11665                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_inplace_forward_mode_AD',
11666                                    dtypes=[torch.cdouble], active_if=IS_WINDOWS),)),
11667    # NOTE: the derivative for inplace acosh is not implemented
11668    UnaryUfuncInfo('acosh',
11669                   aliases=('arccosh', ),
11670                   ref=np.arccosh,
11671                   domain=(1, None),
11672                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
11673                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
11674                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
11675                   supports_inplace_autograd=False,
11676                   supports_forward_ad=True,
11677                   supports_fwgrad_bwgrad=True,
11678                   promotes_int_to_float=True,
11679                   skips=(
11680                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
11681                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11682                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
11683                                    device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS),
11684                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
11685                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
11686                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
11687                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
11688                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
11689                                    device_type='cuda', dtypes=[torch.cdouble],
11690                                    active_if=IS_WINDOWS),
11691                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
11692                                    device_type='cuda', dtypes=[torch.cdouble],
11693                                    active_if=IS_WINDOWS),
11694                       # Failing with wrong imaginary sign on at least some Windows jobs
11695                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
11696                                    device_type='cuda', dtypes=[torch.cdouble],
11697                                    active_if=IS_WINDOWS),
11698                   ),
11699                   # acosh is not defined at x < 1 (real)
11700                   reference_numerics_filter=NumericsFilter(
11701                       condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)),
11702                       safe_val=2)),
11703    BinaryUfuncInfo('add',
11704                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
11705                    ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
11706                    else np.add(input, np.multiply(alpha, other)),
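                    # i.e. torch.add(a, b, alpha=2) is checked against np.add(a, np.multiply(2, b))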
11707                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
11708                                                     torch.float16, torch.chalf),
11709                    assert_autodiffed=True,
11710                    sample_inputs_func=sample_inputs_add_sub,
11711                    supports_fwgrad_bwgrad=True,
11712                    supports_forward_ad=True,
11713                    supports_two_python_scalars=True,
11714                    decorators=(
11715                        DecorateInfo(
11716                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
11717                            'TestBinaryUfuncs', 'test_reference_numerics'),
11718                    ),
11719                    skips=(
11720                        # boolean alpha not handled properly
11721                        DecorateInfo(unittest.expectedFailure,
11722                                     'TestNNCOpInfo',
11723                                     'test_nnc_correctness',
11724                                     dtypes=(torch.bool,)),
11725                        DecorateInfo(unittest.skip("Skipped!"),
11726                                     'TestCommon',
11727                                     'test_numpy_refs',
11728                                     dtypes=(torch.complex128,)),
11729                        DecorateInfo(unittest.skip("Skipped!"),
11730                                     'TestBinaryUfuncs',
11731                                     'test_reference_numerics_extremal_values',
11732                                     dtypes=(torch.complex64, torch.complex128)),
11733                    )),
11734    OpInfo('item',
11735           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs),
11736           ref=np.ndarray.item,
11737           method_variant=None,
11738           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool),
11739           supports_out=False,
11740           supports_autograd=False,
11741           error_inputs_func=error_inputs_item,
11742           sample_inputs_func=sample_inputs_item,
11743           skips=(
11744               # Error testing item function variant
11745               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
11746                            dtypes=(torch.float32, torch.complex64)),
11747               # FX failed to normalize op - add the op to the op_skip list.
11748               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11749               # RuntimeError: Composite compliance check failed with the above error.
11750               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
11751               # Booleans mismatch: AssertionError: False is not true
11752               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
11753               # Booleans mismatch: AssertionError: False is not true
11754               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
11755           )),
11756    OpInfo('arange',
11757           dtypes=all_types_and(torch.bfloat16, torch.float16),
11758           supports_out=True,
11759           supports_autograd=False,
11760           is_factory_function=True,
11761           error_inputs_func=error_inputs_arange,
11762           sample_inputs_func=sample_inputs_arange,
11763           skips=(
11764               # https://github.com/pytorch/pytorch/issues/81774
11765               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11766
11767               # Tests that assume input is a tensor or sequence of tensors
11768               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11769               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11770               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
11771               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
11772
11773               # Lazy tensor failures
11774               DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
11775               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
11776               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
11777
11778               # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608
11779               # We don't have an op for aten::arange but it isn't a special case.
11780               # Argument types: bool, bool, bool, int, int, Device, boo
11781               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
11782
11783               # Captured graph does not contain aten::arange (succeeds on complex!)
11784               # g: graph():
11785               #   %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]()
11786               #   return (%25)
11787               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
11788
11789               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
11790               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
11791           )),
11792    OpInfo('cauchy',
11793           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs),
11794           inplace_variant=torch.Tensor.cauchy_,
11795           dtypes=floating_types_and(torch.float16, torch.bfloat16),
11796           supports_out=False,
11797           supports_autograd=False,
11798           allow_cow_input_materialize_forward=[0],
11799           sample_inputs_func=sample_inputs_cauchy,
11800           error_inputs_func=error_inputs_cauchy,
11801           skips=(
11802               # Tests that assume input tensor has a meaningful effect on output tensor
11803               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11804               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11805
11806               # AssertionError: JIT Test does not execute any logic
11807               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11808
11809               # AssertionError: Tensor-likes are not close!
11810               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11811
11812               # FX failed to normalize op - add the op to the op_skip list.
11813               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11814
11815               # vmap: calling random operator not supported
11816               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
11817               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
11818
11819               DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
11820
11821               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11822           )),
11823    OpInfo('exponential',
11824           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs),
11825           inplace_variant=torch.Tensor.exponential_,
11826           dtypes=floating_types_and(torch.float16, torch.bfloat16),
11827           supports_out=False,
11828           supports_autograd=False,
11829           allow_cow_input_materialize_forward=[0],
11830           sample_inputs_func=sample_inputs_exponential,
11831           error_inputs_func=error_inputs_exponential,
11832           skips=(
11833               # Tests that assume input tensor has a meaningful effect on output tensor
11834               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11835               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11836
11837               # AssertionError: JIT Test does not execute any logic
11838               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11839
11840               # AssertionError: Tensor-likes are not close!
11841               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11842
11843               # FX failed to normalize op - add the op to the op_skip list.
11844               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11845
11846               # vmap: calling random operator not supported
11847               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
11848               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
11849
11850               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11852           )),
11853    OpInfo('geometric',
11854           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs),
11855           inplace_variant=torch.Tensor.geometric_,
11856           dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8),
11857           supports_out=False,
11858           supports_autograd=False,
11859           allow_cow_input_materialize_forward=[0],
11860           sample_inputs_func=sample_inputs_geometric,
11861           error_inputs_func=error_inputs_geometric,
11862           skips=(
11863               # Tests that assume input tensor has a meaningful effect on output tensor
11864               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11865               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11866
11867               # AssertionError: JIT Test does not execute any logic
11868               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11869
11870               # AssertionError: Tensor-likes are not close!
11871               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11872
11873               # FX failed to normalize op - add the op to the op_skip list.
11874               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11875
11876               # vmap: calling random operator not supported
11877               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
11878               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
11879
11880               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11881           )),
11882    OpInfo('log_normal',
11883           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs),
11884           inplace_variant=torch.Tensor.log_normal_,
11885           dtypes=floating_types_and(torch.float16, torch.bfloat16),
11886           supports_out=False,
11887           supports_autograd=False,
11888           allow_cow_input_materialize_forward=[0],
11889           sample_inputs_func=sample_inputs_log_normal,
11890           error_inputs_func=error_inputs_log_normal,
11891           skips=(
11892               # Tests that assume input tensor has a meaningful effect on output tensor
11893               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11894               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11895
11896               # AssertionError: JIT Test does not execute any logic
11897               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11898
11899               # AssertionError: Tensor-likes are not close!
11900               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11901               # FX failed to normalize op - add the op to the op_skip list.
11902               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11903
11904               # vmap: calling random operator not supported
11905               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
11906               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
11907               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11908           )),
11909    OpInfo('normal',
11910           variant_test_name='in_place',
11911           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs),
11912           inplace_variant=torch.Tensor.normal_,
11913           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
11914           supports_out=False,
11915           supports_autograd=False,
11916           allow_cow_input_materialize_forward=[0],
11917           sample_inputs_func=sample_inputs_normal,
11918           error_inputs_func=error_inputs_normal,
11919           skips=(
11920               # Tests that assume input is a tensor or sequence of tensors
11921               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
11922
11923               # Tests that assume input tensor has a meaningful effect on output tensor
11924               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11925               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11926               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
11927               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
11928               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11929               # AssertionError: JIT Test does not execute any logic
11930               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11931               # AssertionError: Tensor-likes are not close!
11932               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11933               # FX failed to normalize op - add the op to the op_skip list.
11934               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11935               # vmap: calling random operator not supported
11936               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
11937               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
11938           )),
11939    OpInfo('uniform',
11940           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs),
11941           method_variant=None,
11942           inplace_variant=torch.Tensor.uniform_,
11943           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
11944           supports_out=False,
11945           supports_autograd=False,
11946           is_factory_function=False,
11947           allow_cow_input_materialize_forward=[0],
11948           sample_inputs_func=sample_inputs_uniform,
11949           error_inputs_func=error_inputs_uniform,
11950           skips=(
11951               # FX failed to normalize op - add the op to the op_skip list.
11952               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
11953               # Tests that assume input tensor has a meaningful effect on output tensor
11954               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
11955               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
11956               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
11957               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
11958               # AssertionError: JIT Test does not execute any logic
11959               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
11960               # aten.uniform was not decomposed
11961               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
11962               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
11963           )),
11964    BinaryUfuncInfo('clamp_max',
11965                    ref=_clamp_max_numpy,
11966                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
11967                    supports_forward_ad=True,
11968                    supports_rhs_python_scalar=False,
11969                    supports_fwgrad_bwgrad=True,
11970                    rhs_make_tensor_kwargs=dict(exclude_zero=False),
11971                    skips=(
11972                        # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
11973                        DecorateInfo(unittest.expectedFailure,
11974                                     'TestBinaryUfuncs',
11975                                     'test_type_promotion',
11976                                     device_type='cuda'),
11977                        # dispatch to lazy test failed
11978                        DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
11979                        # test error disabled since rhs non-tensor python scalar is supported
11980                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
11981                    )),
11982    BinaryUfuncInfo('clamp_min',
11983                    ref=_clamp_min_numpy,
11984                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
11985                    supports_forward_ad=True,
11986                    supports_rhs_python_scalar=False,
11987                    supports_fwgrad_bwgrad=True,
11988                    rhs_make_tensor_kwargs=dict(exclude_zero=False),
11989                    skips=(
11990                        # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
11991                        DecorateInfo(unittest.expectedFailure,
11992                                     'TestBinaryUfuncs',
11993                                     'test_type_promotion',
11994                                     device_type='cuda'),
11995                        # dispatch to lazy test failed
11996                        DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
11997                        # test error disabled since rhs non-tensor python scalar is supported
11998                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'),
11999                    )),
12000    BinaryUfuncInfo('mul',
12001                    aliases=('multiply',),
12002                    dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool),
12003                    assert_autodiffed=True,
12004                    supports_forward_ad=True,
12005                    supports_fwgrad_bwgrad=True,
12006                    supports_two_python_scalars=True,
12007                    error_inputs_sparse_func=error_inputs_sparse_mul,
12008                    sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo),
12009                    sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr),
12010                    sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc),
12011                    sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr),
12012                    sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)),
12013    BinaryUfuncInfo('sub',
12014                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
12015                    ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
12016                    aliases=('subtract',),
12017                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf),
12018                    assert_autodiffed=True,
12019                    supports_forward_ad=True,
12020                    supports_fwgrad_bwgrad=True,
12021                    sample_inputs_func=sample_inputs_add_sub,
12022                    supports_two_python_scalars=True,
12023                    decorators=(
12024                        DecorateInfo(
12025                            toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
12026                                               torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
12027                                               torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
12028                            'TestBinaryUfuncs', 'test_reference_numerics'),
12029                        DecorateInfo(
12030                            toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
12031                            'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
12032                        DecorateInfo(
12033                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
12034                            'TestDecomp', 'test_comprehensive', device_type='cpu'),
12035                        DecorateInfo(
12036                            toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
12037                            'TestDecomp', 'test_quick', device_type='cpu'),
12038                    ),
12039                    skips=(
12040                        DecorateInfo(unittest.skip("Skipped!"),
12041                                     'TestBinaryUfuncs',
12042                                     'test_reference_numerics',
12043                                     dtypes=(torch.uint8,)),
12044                        DecorateInfo(unittest.skip("Skipped!"),
12045                                     'TestBinaryUfuncs',
12046                                     'test_reference_numerics_small_values',
12047                                     dtypes=(torch.uint8,)),
12048                    )),
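    # A minimal sketch of the alpha kwarg that the NumPy ref above emulates; the operands and
    # alpha value are made up for illustration:
    #   torch.sub(a, b, alpha=2)  # == a - 2 * b, i.e. np.subtract(a, np.multiply(2, b))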
12049    OpInfo('addmm',
12050           # This addmm OpInfo is for when alpha and beta are not both equal to 1.
12051           # alpha=beta=1 is tested in the following opinfo, because that special case will
12052           # trigger addmm being decomposed by a jit pass.
12053           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12054           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12055           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12056           assert_autodiffed=True,
12057           supports_forward_ad=True,
12058           supports_fwgrad_bwgrad=True,
12059           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
12060           sample_inputs_func=sample_inputs_addmm,
12061           skips=(
12062               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12063               DecorateInfo(
12064                   unittest.skip("Skipped!"),
12065                   'TestSchemaCheckModeOpInfo',
12066                   'test_schema_correctness',
12067                   dtypes=(torch.complex64, torch.complex128)),
12068           )),
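    # A minimal sketch of the semantics this addmm entry covers (non-unit alpha/beta); the
    # scaling factors below are made-up example values:
    #   torch.addmm(M, mat1, mat2, beta=0.5, alpha=2.0)  # == 0.5 * M + 2.0 * (mat1 @ mat2)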
12069    OpInfo('addmm',
12070           # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add.
12071           variant_test_name='decomposed',
12072           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12073           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12074           assert_autodiffed=True,
12075           supports_forward_ad=True,
12076           supports_fwgrad_bwgrad=True,
12077           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
12078           autodiff_nonfusible_nodes=['aten::add', 'aten::mm'],
12079           sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1),
12080           skips=(
12081               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12082               DecorateInfo(
12083                   unittest.skip("Skipped!"),
12084                   'TestSchemaCheckModeOpInfo',
12085                   'test_schema_correctness',
12086                   dtypes=(torch.complex64, torch.complex128)),
12087               # https://github.com/pytorch/pytorch/issues/71784
12088               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
12089                            device_type='cpu', dtypes=(torch.float16,)),
12090           )),
12091    OpInfo('addmv',
12092           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
12093           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
12094                                           torch.bfloat16),
12095           supports_forward_ad=True,
12096           supports_fwgrad_bwgrad=True,
12097           decorators=[
12098               DecorateInfo(
12099                   toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}),
12100                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
12101           ],
12102           sample_inputs_func=sample_inputs_addmv),
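    # A minimal sketch, with made-up scaling factors: addmv is the matrix-vector analogue of addmm:
    #   torch.addmv(v, mat, vec, beta=0.5, alpha=2.0)  # == 0.5 * v + 2.0 * (mat @ vec)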
12103    OpInfo('addbmm',
12104           ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M),
12105                                                                 np.multiply(np.asarray(alpha, dtype=batch1.dtype),
12106                                                                             np.sum(np.matmul(batch1, batch2), axis=0))),
12107           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
12108           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
12109                                                       *[torch.bfloat16]
12110                                                       if SM53OrLater or TEST_WITH_ROCM else []),
12111           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
12112           gradcheck_fast_mode=True,
12113           supports_forward_ad=True,
12114           supports_fwgrad_bwgrad=True,
12115           decorators=[
12116               DecorateInfo(
12117                   toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05),
12118                                      torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
12119                   'TestCommon', 'test_numpy_refs'),
12120               # MPS has slightly worse precision. Is this acceptable?
12121               DecorateInfo(
12122                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04),
12123                                      torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
12124                   'TestCommon', 'test_numpy_ref_mps'),
12125               DecorateInfo(
12126                   toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
12127                   'TestConsistency',
12128                   'test_output_match',
12129               ),
12130               DecorateInfo(
12131                   toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}),
12132                   'TestCommon', 'test_out'),
12133               DecorateInfo(
12134                   toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}),
12135                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
12136           ],
12137           skips=(
12138               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
12139               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
12140               # addbmm does not correctly warn when resizing out= inputs
12141               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
12142               # https://github.com/pytorch/pytorch/issues/55907
12143               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
12144           ),
12145           sample_inputs_func=sample_inputs_addbmm),
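    # A minimal sketch of the reference above, with made-up scaling factors: addbmm reduces the
    # batched matmul over the batch dimension before adding, so the result has no batch dim:
    #   torch.addbmm(M, batch1, batch2, beta=0.5, alpha=2.0)
    #   # == 0.5 * M + 2.0 * torch.bmm(batch1, batch2).sum(dim=0)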
12146    OpInfo('baddbmm',
12147           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
12148           dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128,
12149                                           torch.bfloat16),
12150           backward_dtypesIfCUDA=floating_types_and(torch.float16,
12151                                                    *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [],
12152                                                    torch.complex64, torch.complex128),
12153           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
12154           gradcheck_fast_mode=True,
12155           supports_forward_ad=True,
12156           supports_fwgrad_bwgrad=True,
12157           decorators=[
12158               DecorateInfo(
12159                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
12160                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
12161               DecorateInfo(
12162                   toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
12163                   'TestMathBits', 'test_conj_view', device_type='cuda'),
12164           ],
12165           sample_inputs_func=sample_inputs_baddbmm,
12166           skips=(
12167               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12168               DecorateInfo(
12169                   unittest.skip("Skipped!"),
12170                   'TestSchemaCheckModeOpInfo',
12171                   'test_schema_correctness',
12172                   dtypes=(torch.complex64, torch.complex128)),
12173           )),
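    # A minimal sketch, with made-up scaling factors: unlike addbmm, baddbmm keeps the batch dim:
    #   torch.baddbmm(inp, batch1, batch2, beta=0.5, alpha=2.0)[i]
    #   # == 0.5 * inp[i] + 2.0 * (batch1[i] @ batch2[i])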
12174    OpInfo('dot',
12175           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12176           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12177           assert_autodiffed=True,
12178           sample_inputs_func=sample_inputs_dot_vdot,
12179           error_inputs_func=error_inputs_dot_vdot,
12180           supports_forward_ad=True,
12181           supports_fwgrad_bwgrad=True,
12182           skips=(
12183               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12184               DecorateInfo(
12185                   unittest.skip("Skipped!"),
12186                   'TestSchemaCheckModeOpInfo',
12187                   'test_schema_correctness',
12188                   dtypes=(torch.complex64, torch.complex128)),
12189           )),
12190    OpInfo('vdot',
12191           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12192           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12193           sample_inputs_func=sample_inputs_dot_vdot,
12194           error_inputs_func=error_inputs_dot_vdot,
12195           supports_forward_ad=True,
12196           supports_fwgrad_bwgrad=True,
12197           skips=(
12198               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12199               DecorateInfo(
12200                   unittest.skip("Skipped!"),
12201                   'TestSchemaCheckModeOpInfo',
12202                   'test_schema_correctness',
12203                   dtypes=(torch.complex64, torch.complex128)),
12204           )),
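    # A minimal sketch of the difference between the dot and vdot entries above:
    #   torch.dot(a, b)   # sum(a * b), no conjugation
    #   torch.vdot(a, b)  # sum(a.conj() * b); conjugates the first argument for complex inputs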
12205    OpInfo('bmm',
12206           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12207           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
12208                                                       *[torch.bfloat16]
12209                                                       if SM53OrLater or TEST_WITH_ROCM else []),
12210           assert_autodiffed=True,
12211           assert_jit_shape_analysis=True,
12212           supports_forward_ad=True,
12213           supports_fwgrad_bwgrad=True,
12214           skips=(
12215               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
12216               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
12217               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
12218                            "TestCommon", "test_out")
12219           ),
12220           sample_inputs_func=sample_inputs_bmm),
12221    OpInfo('mv',
12222           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12223           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12224           assert_autodiffed=True,
12225           supports_forward_ad=True,
12226           supports_fwgrad_bwgrad=True,
12227           sample_inputs_func=sample_inputs_mv),
12228    OpInfo('addr',
12229           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
12230           # Reference: https://github.com/pytorch/pytorch/issues/50747
12231           supports_forward_ad=True,
12232           supports_fwgrad_bwgrad=True,
12233           skips=(
12234               # Reference: https://github.com/pytorch/pytorch/issues/50747
12235               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
12236                            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)),
12237           ),
12238           sample_inputs_func=sample_inputs_addr,
12239           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
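    # A minimal sketch, with made-up scaling factors: addr adds a scaled outer product:
    #   torch.addr(M, vec1, vec2, beta=0.5, alpha=2.0)  # == 0.5 * M + 2.0 * torch.outer(vec1, vec2)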
12240    OpInfo('addcmul',
12241           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12242           assert_autodiffed=True,
12243           supports_forward_ad=True,
12244           supports_fwgrad_bwgrad=True,
12245           skips=(
12246               # TODO: update sample inputs with for_inplace_variant kwarg to support this test
12247               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
12248           ),
12249           sample_inputs_func=sample_inputs_addcmul_addcdiv,
12250           reference_inputs_func=partial(
12251               reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
12252    OpInfo('addcdiv',
12253           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12254           supports_forward_ad=True,
12255           supports_fwgrad_bwgrad=True,
12256           skips=(
12257               # TODO: update sample inputs with for_inplace_variant kwarg to support this test
12258               DecorateInfo(unittest.expectedFailure,
12259                            'TestCommon',
12260                            'test_variant_consistency_eager'),
12261           ),
12262           sample_inputs_func=sample_inputs_addcmul_addcdiv,
12263           reference_inputs_func=partial(
12264               reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
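    # A minimal sketch of the two ternary ops above; the value kwarg below is a made-up example:
    #   torch.addcmul(inp, t1, t2, value=0.5)  # == inp + 0.5 * t1 * t2
    #   torch.addcdiv(inp, t1, t2, value=0.5)  # == inp + 0.5 * t1 / t2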
12265    UnaryUfuncInfo('asin',
12266                   aliases=('arcsin', ),
12267                   ref=np.arcsin,
12268                   domain=(-1, 1),
12269                   supports_sparse=True,
12270                   supports_sparse_csr=True,
12271                   supports_sparse_csc=True,
12272                   supports_sparse_bsr=True,
12273                   supports_sparse_bsc=True,
12274                   supports_forward_ad=True,
12275                   supports_fwgrad_bwgrad=True,
12276                   promotes_int_to_float=True,
12277                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12278                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12279                   assert_autodiffed=True,
12280                   decorators=[
12281                       DecorateInfo(
12282                           toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
12283                           'TestUnaryUfuncs', device_type='cuda'
12284                       ),
12285                       DecorateInfo(
12286                           toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}),
12287                           'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
12288                       ),
12289                       precisionOverride({torch.bfloat16: 1e-2}),
12290                   ],
12291                   skips=(
12292                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12293                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12294                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12295                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12296                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12297                                    device_type='cuda', dtypes=[torch.cdouble],
12298                                    active_if=IS_WINDOWS),
12299                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12300                                    device_type='cuda', dtypes=[torch.cdouble],
12301                                    active_if=IS_WINDOWS),
12302                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
12303                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
12304                   )),
12305    # NOTE: derivative for inplace asinh is not implemented
12306    UnaryUfuncInfo('asinh',
12307                   aliases=('arcsinh', ),
12308                   ref=np.arcsinh,
12309                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12310                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12311                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
12312                   supports_inplace_autograd=False,
12313                   supports_forward_ad=True,
12314                   supports_fwgrad_bwgrad=True,
12315                   supports_sparse=True,
12316                   supports_sparse_csr=True,
12317                   supports_sparse_csc=True,
12318                   supports_sparse_bsr=True,
12319                   supports_sparse_bsc=True,
12320                   promotes_int_to_float=True,
12321                   skips=(
12322                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12323                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12324                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12325                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12326                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
12327                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12328                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
12329                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12330                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12331                                    device_type='cuda', dtypes=[torch.cdouble],
12332                                    active_if=IS_WINDOWS),
12333                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12334                                    device_type='cuda', dtypes=[torch.cdouble],
12335                                    active_if=IS_WINDOWS),
12336                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
12337                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
12338                   )),
12339    UnaryUfuncInfo('atan',
12340                   aliases=('arctan', ),
12341                   ref=np.arctan,
12342                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12343                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12344                   assert_autodiffed=True,
12345                   supports_forward_ad=True,
12346                   supports_fwgrad_bwgrad=True,
12347                   supports_sparse=True,
12348                   supports_sparse_csr=True,
12349                   supports_sparse_csc=True,
12350                   supports_sparse_bsr=True,
12351                   supports_sparse_bsc=True,
12352                   promotes_int_to_float=True,
12353                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
12354                   skips=(
12355                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12356                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12357                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12358                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12359                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
12360                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12361                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12362                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
12363                                    active_if=IS_WINDOWS),
12364                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12365                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
12366                                    active_if=IS_WINDOWS),
12367                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
12368                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
12369                   )),
12370    BinaryUfuncInfo('atan2',
12371                    aliases=('arctan2',),
12372                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
12373                    supports_forward_ad=True,
12374                    supports_fwgrad_bwgrad=True,
12375                    promotes_int_to_float=True,
12376                    supports_rhs_python_scalar=False,
12377                    skips=(
12378                        # Incorrectly attempts to use a scalar for the second argument
12379                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
12380                    )),
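    # A minimal sketch: atan2(input, other) is the quadrant-aware arctangent of input / other:
    #   torch.atan2(torch.tensor(1.), torch.tensor(-1.))  # ~ 3*pi/4, while atan(1. / -1.) ~ -pi/4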
12381    UnaryUfuncInfo('atanh',
12382                   aliases=('arctanh', ),
12383                   ref=np.arctanh,
12384                   domain=(-1, 1),
12385                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12386                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12387                   decorators=[
12388                       precisionOverride({torch.bfloat16: 1e-2}),
12389                       DecorateInfo(
12390                           toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}),
12391                           "TestInductorOpInfo",
12392                           "test_comprehensive",
12393                           device_type="cuda"
12394                       ),
12395                   ],
12396                   supports_inplace_autograd=False,
12397                   supports_forward_ad=True,
12398                   supports_fwgrad_bwgrad=True,
12399                   supports_sparse=True,
12400                   supports_sparse_csr=True,
12401                   supports_sparse_csc=True,
12402                   supports_sparse_bsr=True,
12403                   supports_sparse_bsc=True,
12404                   promotes_int_to_float=True,
12405                   skips=(
12406                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
12407                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12408                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12409                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12410                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12411                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
12412                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12413                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
12414                                    active_if=IS_WINDOWS),
12415                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12416                                    device_type='cuda', dtypes=[torch.cfloat],
12417                                    active_if=IS_WINDOWS),
12418                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
12419                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
12420                   )),
12421    OpInfo('allclose',
12422           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12423           ref=np.allclose,
12424           supports_autograd=False,
12425           supports_forward_ad=False,
12426           sample_inputs_func=sample_inputs_allclose,
12427           skips=(
12428               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
12429               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
12430               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
12431           ),
12432           supports_out=False),
12433    OpInfo('broadcast_to',
12434           ref=np.broadcast_to,
12435           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12436           supports_out=False,
12437           supports_forward_ad=True,
12438           supports_fwgrad_bwgrad=True,
12439           # See https://github.com/pytorch/pytorch/pull/78358
12440           check_batched_forward_grad=False,
12441           sample_inputs_func=sample_inputs_broadcast_to),
12442    OpInfo('broadcast_shapes',
12443           op=torch.broadcast_shapes,
12444           ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
12445           dtypes=_dispatch_dtypes((torch.float32,)),
12446           supports_out=False,
12447           supports_gradgrad=False,
12448           assert_autodiffed=False,
12449           supports_autograd=False,
12450           supports_scripting=False,
12451           sample_inputs_func=sample_inputs_broadcast_shapes,
12452           skips=(
12453               # https://github.com/pytorch/pytorch/issues/64997
12454               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
12455               # skip dtype tests since broadcast_shapes is not device-dependent.
12456               # having dtypes limited to torch.float32 would cause test_dtypes to report an unexpected success
12457               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
12458               # skip these tests since we have non-tensor inputs
12459               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
12460               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
12461               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
12462           )),
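    # A minimal sketch: broadcast_shapes consumes shapes rather than tensors, hence the
    # non-tensor-input skips above; the shapes below are made-up examples:
    #   torch.broadcast_shapes((2, 1, 3), (1, 4, 3))  # -> torch.Size([2, 4, 3])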
12463    OpInfo('broadcast_tensors',
12464           ref=np.broadcast_arrays,
12465           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12466           sample_inputs_func=sample_inputs_broadcast_tensors,
12467           reference_inputs_func=reference_inputs_broadcast_tensors,
12468           supports_out=False,
12469           supports_forward_ad=True,
12470           supports_fwgrad_bwgrad=True,
12471           # See https://github.com/pytorch/pytorch/pull/78358
12472           check_batched_forward_grad=False,
12473           skips=(
12474               # https://github.com/pytorch/pytorch/issues/64997
12475               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
12476               # JIT does not support variadic tensors.
12477               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
12478               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
12479               # please report a bug to PyTorch.
12480               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
12481           )),
12482    OpInfo('block_diag',
12483           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
12484           supports_out=False,
12485           supports_forward_ad=True,
12486           supports_fwgrad_bwgrad=True,
12487           # Default batching rule in core doesn't work for ops with TensorList args
12488           check_batched_forward_grad=False,
12489           skips=(
12490               # https://github.com/pytorch/pytorch/issues/64997
12491               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
12492               # JIT does not support variadic tensors.
12493               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
12494               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
12495               # please report a bug to PyTorch.
12496               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
12497           ),
12498           sample_inputs_func=sample_inputs_block_diag),
12499    UnaryUfuncInfo('bitwise_not',
12500                   ref=np.bitwise_not,
12501                   dtypes=integral_types_and(torch.bool),
12502                   operator_variant=operator.invert,
12503                   supports_autograd=False),
12504    BinaryUfuncInfo('bitwise_left_shift',
12505                    op=torch.bitwise_left_shift,
12506                    dtypes=integral_types(),
12507                    dtypesIfCUDA=integral_types(),
12508                    operator_variant=operator.lshift,
12509                    inplace_operator_variant=operator.ilshift,
12510                    supports_autograd=False,
12511                    supports_one_python_scalar=True,
12512                    rhs_make_tensor_kwargs=dict(low=0),
12513                    skips=(
12514                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
12515                        # https://github.com/pytorch/pytorch/issues/70904
12516                        DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
12517                    )),
12518    BinaryUfuncInfo('bitwise_right_shift',
12519                    op=torch.bitwise_right_shift,
12520                    dtypes=integral_types(),
12521                    dtypesIfCUDA=integral_types(),
12522                    operator_variant=operator.rshift,
12523                    inplace_operator_variant=operator.irshift,
12524                    supports_autograd=False,
12525                    supports_one_python_scalar=True,
12526                    rhs_make_tensor_kwargs=dict(low=0),
12527                    skips=(
12528                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
12529                        # https://github.com/pytorch/pytorch/issues/70904
12530                        DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
12531                    )),
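    # A minimal sketch of the operator variants registered for the two shift ops above,
    # with made-up operands:
    #   torch.bitwise_left_shift(torch.tensor(1), 3)   # -> 8, same as torch.tensor(1) << 3
    #   torch.bitwise_right_shift(torch.tensor(8), 2)  # -> 2, same as torch.tensor(8) >> 2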
12532    OpInfo('combinations',
12533           op=torch.combinations,
12534           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12535           supports_forward_ad=True,
12536           supports_fwgrad_bwgrad=True,
12537           # See https://github.com/pytorch/pytorch/pull/78358
12538           check_batched_forward_grad=False,
12539           supports_out=False,
12540           sample_inputs_func=sample_inputs_combinations),
12541    OpInfo('cartesian_prod',
12542           op=torch.cartesian_prod,
12543           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12544           supports_out=False,
12545           supports_forward_ad=True,
12546           supports_fwgrad_bwgrad=True,
12547           # See https://github.com/pytorch/pytorch/pull/78358
12548           check_batched_forward_grad=False,
12549           sample_inputs_func=sample_inputs_cartesian_prod,
12550           skips=(
12551               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
12552               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
12553               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
12554               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
12555               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
12556               DecorateInfo(unittest.expectedFailure,
12557                            'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
12558           )),
12559    OpInfo('cdist',
12560           dtypes=floating_types(),
12561           supports_out=False,
12562           supports_gradgrad=False,
12563           assert_autodiffed=False,
12564           sample_inputs_func=sample_inputs_cdist),
12565    UnaryUfuncInfo('ceil',
12566                   ref=np.ceil,
12567                   dtypes=all_types_and(torch.half, torch.bfloat16),
12568                   supports_forward_ad=True,
12569                   supports_fwgrad_bwgrad=True,
12570                   skips=(
12571                       DecorateInfo(unittest.expectedFailure,
12572                                    'TestNNCOpInfo',
12573                                    'test_nnc_correctness',
12574                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
12575                   ),
12576                   supports_sparse=True,
12577                   supports_sparse_csr=True,
12578                   supports_sparse_csc=True,
12579                   supports_sparse_bsr=True,
12580                   supports_sparse_bsc=True,
12581                   assert_autodiffed=True),
12582    OpInfo('cholesky',
12583           dtypes=floating_and_complex_types(),
12584           sample_inputs_func=sample_inputs_linalg_cholesky,
12585           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
12586           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],),
12587    OpInfo('cholesky_inverse',
12588           dtypes=floating_and_complex_types(),
12589           backward_dtypes=floating_and_complex_types(),
12590           # https://github.com/pytorch/pytorch/issues/80411
12591           gradcheck_fast_mode=True,
12592           supports_fwgrad_bwgrad=True,
12593           supports_forward_ad=True,
12594           check_batched_gradgrad=True,
12595           sample_inputs_func=sample_inputs_linalg_cholesky_inverse,
12596           gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal,
12597           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
12598           skips=(
12599               # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),)
12600               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)),
12601    OpInfo('cholesky_solve',
12602           op=torch.cholesky_solve,
12603           dtypes=floating_and_complex_types(),
12604           sample_inputs_func=sample_inputs_cholesky_solve,
12605           check_batched_gradgrad=False,
12606           supports_forward_ad=True,
12607           supports_fwgrad_bwgrad=True,
12608           gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
12609           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
12610    OpInfo('chunk',
12611           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
12612           sample_inputs_func=sample_inputs_chunk,
12613           reference_inputs_func=reference_inputs_chunk,
12614           supports_forward_ad=True,
12615           supports_fwgrad_bwgrad=True,
12616           supports_out=False),
12617    OpInfo('unsafe_chunk',
12618           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
12619           sample_inputs_func=sample_inputs_chunk,
12620           check_batched_forward_grad=False,
12621           reference_inputs_func=reference_inputs_chunk,
12622           supports_forward_ad=True,
12623           supports_fwgrad_bwgrad=True,
12624           supports_out=False),
12625    OpInfo('clone',
12626           ref=np.copy,
12627           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
12628           sample_inputs_func=sample_inputs_clone_contiguous,
12629           reference_inputs_func=reference_inputs_clone_contiguous,
12630           supports_forward_ad=True,
12631           supports_fwgrad_bwgrad=True,
12632           supports_out=False,
12633           skips=(
12634               # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format'
12635               # (NumPy reference needs to be extended with memory_format)
12636               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'),
12637               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
12638           ),),
12639    OpInfo('contiguous',
12640           op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
12641           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
12642           sample_inputs_func=sample_inputs_clone_contiguous,
12643           reference_inputs_func=reference_inputs_clone_contiguous,
12644           supports_forward_ad=True,
12645           supports_fwgrad_bwgrad=True,
12646           autodiff_fusible_nodes=['aten::contiguous'],
12647           assert_jit_shape_analysis=True,
12648           supports_out=False,
12649           skips=(
12650               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
12651           )),
12652    OpInfo('sum_to_size',
12653           op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs),
12654           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12655           sample_inputs_func=sample_inputs_sum_to_size,
12656           error_inputs_func=error_inputs_sum_to_size,
12657           supports_forward_ad=True,
12658           supports_fwgrad_bwgrad=True,
12659           supports_out=False,
12660           skips=(
12661               # lambda impl; the lambda wrapper is not supported by operator normalization or the JIT
12662               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
12663               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)),
12664           )),
12665    OpInfo('clamp',
12666           aliases=('clip',),
12667           ref=_clamp_numpy,
12668           dtypes=all_types_and(torch.bfloat16, torch.half),
12669           sample_inputs_func=sample_inputs_clamp,
12670           reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp),
12671           assert_autodiffed=True,
12672           supports_forward_ad=True,
12673           supports_fwgrad_bwgrad=True,
12674           skips=(
12675               # NNC appears not to handle boolean clamp
12676               DecorateInfo(unittest.expectedFailure,
12677                            'TestNNCOpInfo',
12678                            'test_nnc_correctness',
12679                            dtypes=(torch.bool,)),
12680               # MPS does not support float64, while numpy does internal computations in float64.
12681               # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264
12682               DecorateInfo(unittest.expectedFailure,
12683                            'TestCommon',
12684                            'test_numpy_ref_mps'),
12685           )),
12686    UnaryUfuncInfo('positive',
12687                   ref=np.positive,
12688                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
12689                   supports_out=False,
12690                   supports_forward_ad=True,
12691                   supports_fwgrad_bwgrad=True,
12692                   supports_sparse=True,
12693                   supports_sparse_csr=True,
12694                   supports_sparse_csc=True,
12695                   supports_sparse_bsr=True,
12696                   supports_sparse_bsc=True,
12697                   ),
12698    UnaryUfuncInfo('conj',
12699                   ref=np.conj,
12700                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
12701                                                    torch.half, torch.chalf),
12702                   supports_sparse=True,
12703                   supports_forward_ad=True,
12704                   supports_fwgrad_bwgrad=True,
12705                   # See https://github.com/pytorch/pytorch/pull/78358
12706                   check_batched_forward_grad=False,
12707                   supports_out=False),
12708    UnaryUfuncInfo('conj_physical',
12709                   decomp_aten_name='_conj_physical',
12710                   ref=np.conj,
12711                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
12712                                                    torch.half, torch.chalf),
12713                   supports_forward_ad=True,
12714                   supports_fwgrad_bwgrad=True,
12715                   supports_sparse=True,
12716                   supports_sparse_csr=True,
12717                   supports_sparse_csc=True,
12718                   supports_sparse_bsr=True,
12719                   supports_sparse_bsc=True,
12720                   skips=(
12721                       # RuntimeError: inputSet && outputSet
12722                       # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118,
12723                       # please report a bug to PyTorch.
12724                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
12725                       DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"),
12726                                    'TestSparseUnaryUfuncs', 'test_inplace'),
12727                   )),
12728    OpInfo('resolve_conj',
12729           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12730           sample_inputs_func=sample_inputs_view_as_real,
12731           supports_forward_ad=True,
12732           supports_fwgrad_bwgrad=True,
12733           supports_out=False,
12734           ),
12735    OpInfo('resolve_neg',
12736           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
12737           sample_inputs_func=sample_inputs_view_as_real,
12738           supports_forward_ad=True,
12739           supports_fwgrad_bwgrad=True,
12740           supports_out=False,
12741           ),
12742    OpInfo('view_as_real',
12743           dtypes=complex_types(),
12744           supports_forward_ad=True,
12745           supports_out=False,
12746           supports_fwgrad_bwgrad=True,
12747           sample_inputs_func=sample_inputs_view_as_real,
12748           test_conjugated_samples=False,
12749           ),
12750    OpInfo('view_as_complex',
12751           dtypes=floating_types_and(torch.half),
12752           supports_out=False,
12753           supports_forward_ad=True,
12754           supports_fwgrad_bwgrad=True,
12755           test_neg_view=False,
12756           sample_inputs_func=sample_inputs_view_as_complex,
12757           skips=(
12758               # RuntimeError: Tensor must have a last dimension with stride 1
12759               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"),
12760               # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf'
12761               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)),
12762               # RuntimeError: view size is not compatible with input tensor's size and stride
12763               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
12764           )),
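    # A minimal sketch of the view pair above, with made-up values:
    #   z = torch.tensor([1 + 2j, 3 - 4j])
    #   torch.view_as_real(z)     # -> [[1., 2.], [3., -4.]], real view with a trailing dim of size 2
    #   torch.view_as_complex(torch.view_as_real(z))  # round-trips back to z (last dim needs stride 1)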
12765    BinaryUfuncInfo('complex',
12766                    dtypes=floating_types_and(torch.half),
12767                    supports_forward_ad=True,
12768                    supports_fwgrad_bwgrad=True,
12769                    supports_rhs_python_scalar=False,
12770                    error_inputs_func=error_inputs_complex,
12771                    skips=(
12772                        # Tests don't account for complex's type promotion semantics
12773                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
12774                        DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'),
12775                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)),
12776    BinaryUfuncInfo('copysign',
12777                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
12778                    promotes_int_to_float=True,
12779                    # https://github.com/pytorch/pytorch/issues/80411
12780                    gradcheck_fast_mode=True,
12781                    supports_forward_ad=True,
12782                    supports_fwgrad_bwgrad=True),
12783    OpInfo('corrcoef',
12784           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
12785           sample_inputs_func=sample_inputs_corrcoef,
12786           supports_forward_ad=True,
12787           supports_fwgrad_bwgrad=True,
12788           # See https://github.com/pytorch/pytorch/pull/78358
12789           check_batched_forward_grad=False,
12790           skips=(
12791               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12792               DecorateInfo(
12793                   unittest.skip("Skipped!"),
12794                   'TestSchemaCheckModeOpInfo',
12795                   'test_schema_correctness',
12796                   dtypes=(torch.complex64, torch.complex128)),
12797           ),
12798           supports_out=False),
12799    UnaryUfuncInfo('cos',
12800                   ref=np.cos,
12801                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12802                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12803                   assert_autodiffed=True,
12804                   handles_large_floats=False,
12805                   supports_forward_ad=True,
12806                   supports_fwgrad_bwgrad=True,
12807                   promotes_int_to_float=True,
12808                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
12809                   skips=(
12810                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12811                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
12812                       # This fails on CUDA but passes on ROCm
12813                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12814                                    dtypes=(torch.cdouble,), device_type='cuda'),
12815                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12816                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
12817                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12818                                    device_type='cpu',
12819                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
12820                       # AssertionError: Tensor-likes are not close!
12821                       # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
12822                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
12823                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
12824                                    device_type='cuda',
12825                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
12826                   )),
12827    UnaryUfuncInfo('cosh',
12828                   ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh),
12829                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12830                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
12831                   assert_autodiffed=True,
12832                   supports_forward_ad=True,
12833                   supports_fwgrad_bwgrad=True,
12834                   promotes_int_to_float=True,
12835                   skips=(
12836                       # Reference: https://github.com/pytorch/pytorch/issues/48641
12837                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12838                                    device_type='cpu', dtypes=[torch.int8]),
12839                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12840                                    dtypes=[torch.cdouble]),
12841                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12842                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
12843                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12844                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
12845                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
12846                                    device_type='cpu',
12847                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
12848                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
12849                                    device_type='cpu',
12850                                    dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
12851                       # AssertionError: Tensor-likes are not close!
12852                       # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
12853                       # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
12854                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
12855                                    device_type='cuda',
12856                                    dtypes=(torch.chalf,), active_if=IS_WINDOWS),
12857                   )),
12858    OpInfo('cov',
12859           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
12860           sample_inputs_func=sample_inputs_cov,
12861           error_inputs_func=error_inputs_cov,
12862           supports_out=False,
12863           supports_forward_ad=True,
12864           supports_fwgrad_bwgrad=True,
12865           skips=(
12866               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
12867               DecorateInfo(
12868                   unittest.skip("Skipped!"),
12869                   'TestSchemaCheckModeOpInfo',
12870                   'test_schema_correctness',
12871                   dtypes=(torch.complex64, torch.complex128)),
12872               # Float did not match double
12873               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
12874               # Jacobian mismatch
12875               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
12876               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
12877               DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
12878               # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507)
12879               # RuntimeError:
12880               # undefined value tensor:
12881               #   File "<string>", line 3
12882               # def the_method(i0):
12883               #     return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950
12884               #                                                                ~~~~~~ <--- HERE
12885               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
12886               DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}),
12887                            "TestInductorOpInfo", "test_comprehensive", device_type="cpu"),
12888           )),
12889    OpInfo('cross',
12890           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
12891           sample_inputs_func=sample_inputs_cross,
12892           supports_fwgrad_bwgrad=True,
12893           supports_out=True,
12894           supports_forward_ad=True),
12895    OpInfo('cumsum',
12896           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
12897           supports_forward_ad=True,
12898           supports_fwgrad_bwgrad=True,
12899           skips=(
12900               # cumsum does not correctly handle out= dtypes
12901               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
12902           ),
12903           sample_inputs_func=sample_inputs_cumulative_ops),
12904    OpInfo('cumprod',
12905           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
12906           supports_forward_ad=True,
12907           supports_fwgrad_bwgrad=True,
12908           skips=(
12909               # cumprod does not handle out= dtypes correctly
12910               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
12911           ),
12912           # gradgradcheck fails in fast_mode=True: #56275
12913           sample_inputs_func=sample_inputs_cumprod,
12914           gradcheck_fast_mode=False),
12915    OpInfo('cummax',
12916           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
12917           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
12918           supports_forward_ad=True,
12919           supports_fwgrad_bwgrad=True,
12920           skips=(
12921           ),
12922           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
12923    OpInfo('cummin',
12924           dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
12925           sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
12926           supports_forward_ad=True,
12927           supports_fwgrad_bwgrad=True,
12928           skips=(
12929           ),
12930           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
12931    UnaryUfuncInfo('deg2rad',
12932                   ref=np.radians,
12933                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
12934                                                  torch.float16: 7e-1}),),
12935                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
12936                   supports_forward_ad=True,
12937                   supports_fwgrad_bwgrad=True,
12938                   supports_sparse=True,
12939                   supports_sparse_csr=True,
12940                   supports_sparse_csc=True,
12941                   supports_sparse_bsr=True,
12942                   supports_sparse_bsc=True,
12943                   promotes_int_to_float=True),
12944    OpInfo('diff',
12945           op=torch.diff,
12946           # np.diff uses np._NoValue as the default for prepend and append; compare_with_reference breaks if prepend/append
12947           # are passed as None when converting to numpy (see the sketch after this entry)
12948           ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: (
12949               np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append)
12950           ),
12951           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
12952           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
12953           gradcheck_fast_mode=True,
12954           supports_forward_ad=True,
12955           supports_fwgrad_bwgrad=True,
12956           sample_inputs_func=sample_inputs_diff,
12957           error_inputs_func=error_inputs_diff,
12958           # See https://github.com/pytorch/pytorch/pull/78358
12959           check_batched_forward_grad=False,
12960           skips=(
12961           )),
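    # Sketch of the behaviour the 'diff' ref wrapper above accounts for (illustrative values):
    # >>> t = torch.tensor([1, 3, 6])
    # >>> torch.diff(t)
    # tensor([2, 3])
    # >>> torch.diff(t, prepend=torch.tensor([0]))
    # tensor([1, 2, 3])
    # The wrapper maps prepend/append of None back to np._NoValue because np.diff does not
    # accept None for those arguments.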
12962    BinaryUfuncInfo('div',
12963                    aliases=('divide',),
12964                    variant_test_name='no_rounding_mode',
12965                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
12966                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
12967                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
12968                    gradcheck_fast_mode=True,
12969                    supports_forward_ad=True,
12970                    promotes_int_to_float=True,
12971                    supports_fwgrad_bwgrad=True,
12972                    supports_two_python_scalars=True,
12973                    assert_autodiffed=True,
12974                    rhs_make_tensor_kwargs=dict(exclude_zero=True),),
12975    BinaryUfuncInfo('div',
12976                    aliases=('divide',),
12977                    variant_test_name='trunc_rounding',
12978                    dtypes=all_types_and(torch.half, torch.bfloat16),
12979                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")),
12980                    # https://github.com/pytorch/pytorch/issues/80411
12981                    gradcheck_fast_mode=True,
12982                    supports_forward_ad=True,
12983                    supports_fwgrad_bwgrad=True,
12984                    supports_two_python_scalars=True,
12985                    assert_autodiffed=True,
12986                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
12987                    decorators=(
12988                        # See https://github.com/pytorch/pytorch/issues/111126
12989                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
12990                    ),
12991                    skips=(
12992                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
12993                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
12994                        # FIXME:
12995                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
12996                        # output 0 with respect to input 1,
12997                        # numerical:tensor(-17746.9307, dtype=torch.float64)
12998                        # analytical:tensor(0., dtype=torch.float64)
12999                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
13000                                     'test_fn_grad', device_type='cpu',
13001                                     dtypes=(torch.float64,)),
13002                    )),
13003    BinaryUfuncInfo('div',
13004                    aliases=('divide',),
13005                    variant_test_name='floor_rounding',
13006                    dtypes=all_types_and(torch.half, torch.bfloat16),
13007                    sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")),
13008                    # https://github.com/pytorch/pytorch/issues/80411
13009                    gradcheck_fast_mode=True,
13010                    supports_forward_ad=True,
13011                    supports_fwgrad_bwgrad=True,
13012                    supports_two_python_scalars=True,
13013                    assert_autodiffed=True,
13014                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
13015                    decorators=(
13016                        # See https://github.com/pytorch/pytorch/issues/111126
13017                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
13018                    ),
13019                    skips=(
13020                        # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
13021                        DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
13022                        # FIXME:
13023                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
13024                        # output 0 with respect to input 1,
13025                        # numerical:tensor(-17746.9307, dtype=torch.float64)
13026                        # analytical:tensor(0., dtype=torch.float64)
13027                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
13028                                     'test_fn_grad',
13029                                     dtypes=(torch.float64,),
13030                                     device_type='cpu'),
13031                    )),
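    # Illustration of the two rounding-mode variants above (illustrative values):
    # >>> torch.div(torch.tensor([ 7., -7.]), 2, rounding_mode='trunc')
    # tensor([ 3., -3.])
    # >>> torch.div(torch.tensor([ 7., -7.]), 2, rounding_mode='floor')
    # tensor([ 3., -4.])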
13032    BinaryUfuncInfo('true_divide',
13033                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13034                    dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
13035                    supports_forward_ad=True,
13036                    promotes_int_to_float=True,
13037                    supports_fwgrad_bwgrad=True,
13038                    supports_two_python_scalars=True,
13039                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
13040    OpInfo('equal',
13041           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
13042           ref=lambda input, other: (input == other).all(),
13043           sample_inputs_func=sample_inputs_equal,
13044           supports_autograd=False,
13045           supports_tracing=False,
13046           skips=(
13047           )),
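    # Sketch of why the 'equal' entry above uses (input == other).all() only as a reference
    # (illustrative values; torch.equal also handles mismatched shapes, where elementwise == would raise):
    # >>> a = torch.tensor([1, 2])
    # >>> torch.equal(a, a.clone())
    # True
    # >>> torch.equal(torch.ones(2), torch.ones(3))   # shape mismatch -> False
    # False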
13048    UnaryUfuncInfo('exp',
13049                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
13050                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13051                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
13052                   skips=(
13053                       # Reference: https://github.com/pytorch/pytorch/issues/48010
13054                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13055                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
13056                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
13057                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
13058                   ),
13059                   assert_autodiffed=True,
13060                   supports_forward_ad=True,
13061                   supports_fwgrad_bwgrad=True,
13062                   promotes_int_to_float=True),
13063    OpInfo('expand',
13064           op=lambda self, shape: self.expand(shape),
13065           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13066           sample_inputs_func=sample_inputs_expand,
13067           supports_forward_ad=True,
13068           supports_fwgrad_bwgrad=True,
13069           assert_jit_shape_analysis=True,
13070           supports_out=False,
13071           skips=(
13072               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13073           )),
13074    OpInfo('expand_as',
13075           op=lambda self, other: self.expand_as(other),
13076           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13077           supports_forward_ad=True,
13078           supports_fwgrad_bwgrad=True,
13079           sample_inputs_func=sample_inputs_expand_as,
13080           supports_out=False,
13081           skips=(
13082               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),),
13083           ),
13084    OpInfo('expand_copy',
13085           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13086           sample_inputs_func=sample_inputs_expand,
13087           supports_forward_ad=True,
13088           supports_fwgrad_bwgrad=True,
13089           assert_jit_shape_analysis=True,
13090           supports_out=True,
13091           skips=(
13092               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
13093           )),
13094    OpInfo('diag',
13095           ref=np.diag,
13096           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
13097           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
13098           supports_forward_ad=True,
13099           supports_fwgrad_bwgrad=True,
13100           check_batched_forward_grad=False,
13101           sample_inputs_func=sample_inputs_diag,
13102           error_inputs_func=error_inputs_diag),
13103    OpInfo('diag_embed',
13104           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
13105           supports_out=False,
13106           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13107           gradcheck_fast_mode=True,
13108           supports_forward_ad=True,
13109           supports_fwgrad_bwgrad=True,
13110           sample_inputs_func=sample_inputs_diagonal_diag_embed,
13111           reference_inputs_func=reference_inputs_diagonal_diag_embed,
13112           error_inputs_func=error_inputs_diagonal_diag_embed),
13113    OpInfo('diagonal',
13114           aten_backward_name='diagonal_backward',
13115           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
13116           supports_out=False,
13117           supports_forward_ad=True,
13118           supports_fwgrad_bwgrad=True,
13119           sample_inputs_func=sample_inputs_diagonal_diag_embed,
13120           reference_inputs_func=reference_inputs_diagonal_diag_embed,
13121           error_inputs_func=error_inputs_diagonal_diag_embed),
13122    OpInfo('diagonal_copy',
13123           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
13124           supports_forward_ad=True,
13125           supports_fwgrad_bwgrad=True,
13126           sample_inputs_func=sample_inputs_diagonal_diag_embed,
13127           reference_inputs_func=reference_inputs_diagonal_diag_embed,
13128           error_inputs_func=error_inputs_diagonal_diag_embed),
13129    OpInfo('diagonal_scatter',
13130           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
13131           supports_out=False,
13132           supports_forward_ad=True,
13133           supports_fwgrad_bwgrad=True,
13134           sample_inputs_func=sample_inputs_diagonal_scatter),
13135    OpInfo('alias_copy',
13136           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
13137           sample_inputs_func=sample_inputs_alias_copy,
13138           supports_forward_ad=True,
13139           supports_fwgrad_bwgrad=True,
13140           supports_out=True),
13141    BinaryUfuncInfo('eq',
13142                    ref=np.equal,
13143                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
13144                    always_returns_bool=True,
13145                    supports_autograd=False,
13146                    sample_inputs_func=sample_inputs_comparison_ops,
13147                    skips=(
13148                    )),
13149    BinaryUfuncInfo('fmax',
13150                    op=torch.fmax,
13151                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
13152                    supports_forward_ad=True,
13153                    supports_fwgrad_bwgrad=True,
13154                    supports_rhs_python_scalar=False,
13155                    skips=(
13156                        # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
13157                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
13158                    )),
13159    BinaryUfuncInfo('fmin',
13160                    op=torch.fmin,
13161                    dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
13162                    supports_forward_ad=True,
13163                    supports_fwgrad_bwgrad=True,
13164                    supports_rhs_python_scalar=False,
13165                    skips=(
13166                        # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
13167                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
13168                    )),
13169    BinaryUfuncInfo('fmod',
13170                    ref=np.fmod,
13171                    dtypes=all_types_and(torch.float16, torch.bfloat16),
13172                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
13173                    # https://github.com/pytorch/pytorch/issues/80411
13174                    gradcheck_fast_mode=True,
13175                    supports_forward_ad=True,
13176                    supports_fwgrad_bwgrad=True,
13177                    assert_autodiffed=None,
13178                    rhs_make_tensor_kwargs={'exclude_zero': True},
13179                    decorators=(
13180                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13181                                     'test_contig_vs_every_other',
13182                                     dtypes=(torch.bfloat16,)),
13183                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13184                                     'test_non_contig',
13185                                     dtypes=(torch.bfloat16,)),
13186                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13187                                     'test_reference_numerics',
13188                                     dtypes=(torch.bfloat16,)),
13189                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13190                                     'test_reference_numerics_small_values',
13191                                     dtypes=(torch.uint8,)),
13192                        # FIXME:
13193                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
13194                        # output 0 with respect to input 1,
13195                        # numerical:tensor(101.6283, dtype=torch.float64)
13196                        # analytical:tensor(-18.3575, dtype=torch.float64)
13197                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
13198                                     'test_fn_grad',
13199                                     dtypes=(torch.float64,),
13200                                     device_type='cpu'),
13201                    )),
13202    BinaryUfuncInfo('remainder',
13203                    ref=np.remainder,
13204                    dtypes=all_types_and(torch.float16, torch.bfloat16),
13205                    dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
13206                    # https://github.com/pytorch/pytorch/issues/80411
13207                    gradcheck_fast_mode=True,
13208                    supports_forward_ad=True,
13209                    supports_fwgrad_bwgrad=True,
13210                    assert_autodiffed=None,
13211                    operator_variant=operator.mod,
13212                    inplace_operator_variant=operator.imod,
13213                    supports_one_python_scalar=True,
13214                    rhs_make_tensor_kwargs={'exclude_zero': True},
13215                    decorators=(
13216                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13217                                     'test_contig_vs_every_other',
13218                                     dtypes=(torch.bfloat16,)),
13219                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13220                                     'test_non_contig',
13221                                     dtypes=(torch.bfloat16,)),
13222                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13223                                     'test_reference_numerics',
13224                                     dtypes=(torch.bfloat16,)),
13225                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
13226                                     'test_reference_numerics_small_values',
13227                                     dtypes=(torch.uint8,)),
13228                        DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo',
13229                                     'test_nnc_correctness',
13230                                     dtypes=(torch.bfloat16,)),
13231                        # Fails on XLA
13232                        # False is not true : Tensors failed to compare as equal!
13233                        # Attempted to compare equality of tensors with different dtypes
13234                        DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
13235                        # FIXME:
13236                        # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
13237                        # output 0 with respect to input 1,
13238                        # numerical:tensor(102.4676, dtype=torch.float64)
13239                        # analytical:tensor(-17.5182, dtype=torch.float64)
13240                        DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
13241                                     'test_fn_grad', device_type='cpu',
13242                                     dtypes=(torch.float64,)),
13243                        DecorateInfo(
13244                            toleranceOverride({
13245                                torch.float16: tol(atol=5e-4, rtol=3e-3),
13246                            }),
13247                            "TestInductorOpInfo",
13248                            "test_comprehensive",
13249                            device_type="cuda"
13250                        ),
13251                    )),
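    # Sign conventions exercised by the 'fmod' and 'remainder' entries above (illustrative values);
    # the operator variant (operator.mod) follows Python's % semantics:
    # >>> torch.fmod(torch.tensor([-3.]), 2)       # sign follows the dividend
    # tensor([-1.])
    # >>> torch.remainder(torch.tensor([-3.]), 2)  # sign follows the divisor, like -3 % 2 == 1
    # tensor([1.])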
13252    UnaryUfuncInfo('frac',
13253                   ref=lambda x: np.modf(x)[0],
13254                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
13255                   dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
13256                   assert_autodiffed=True,
13257                   supports_forward_ad=True,
13258                   supports_fwgrad_bwgrad=True,
13259                   supports_sparse=True,
13260                   supports_sparse_csr=True,
13261                   supports_sparse_csc=True,
13262                   supports_sparse_bsr=True,
13263                   supports_sparse_bsc=True,
13264                   skips=(
13265                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13266                                    dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)),
13267                       # 76047
13268                       DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
13269                                    dtypes=(torch.bfloat16, torch.float32, torch.float64)),
13270                   )),
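    # The 'frac' ref above keeps np.modf's fractional part, which preserves the sign of the
    # input (illustrative values):
    # >>> torch.frac(torch.tensor([1.5, -1.5]))
    # tensor([ 0.5000, -0.5000])
    # >>> np.modf(np.array([1.5, -1.5]))[0]
    # array([ 0.5, -0.5])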
13271    OpInfo('stft',
13272           decorators=[
13273               skipCPUIfNoFFT,
13274               DecorateInfo(unittest.skip("Skipped! stft does not match the native function"),
13275                            'TestJit', 'test_variant_consistency_jit'),
13276           ],
13277           dtypes=floating_and_complex_types(),
13278           sample_inputs_func=sample_inputs_stft,
13279           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13280           gradcheck_fast_mode=True,
13281           supports_forward_ad=True,
13282           supports_fwgrad_bwgrad=True,
13283           check_batched_forward_grad=False,
13284           check_batched_grad=False,
13285           check_batched_gradgrad=False,
13286           supports_out=False,
13287           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
13288           ),
13289    OpInfo('istft',
13290           dtypes=complex_types(),
13291           sample_inputs_func=sample_inputs_istft,
13292           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13293           gradcheck_fast_mode=True,
13294           supports_forward_ad=True,
13295           supports_fwgrad_bwgrad=True,
13296           check_batched_forward_grad=False,
13297           check_batched_grad=False,
13298           check_batched_gradgrad=False,
13299           supports_out=False,
13300           decorators=(
13301               DecorateInfo(unittest.skip("Skipped! istft does not match the native function"),
13302                            'TestJit', 'test_variant_consistency_jit'),
13303           ),
13304           skips=(
13305               skipCPUIfNoFFT,
13306               # gradcheck fails on ROCm (gh-68429)
13307               # grad is computed improperly (probably for weights tensor)
13308               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
13309               # Pre-existing condition (calls .item); needs to be fixed
13310               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
13311           )),
13312    UnaryUfuncInfo('floor',
13313                   ref=np.floor,
13314                   dtypes=all_types_and(torch.half, torch.bfloat16),
13315                   supports_forward_ad=True,
13316                   supports_fwgrad_bwgrad=True,
13317                   skips=(
13318                       DecorateInfo(unittest.expectedFailure,
13319                                    'TestNNCOpInfo',
13320                                    'test_nnc_correctness',
13321                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
13322                   ),
13323                   supports_sparse=True,
13324                   supports_sparse_csr=True,
13325                   supports_sparse_csc=True,
13326                   supports_sparse_bsr=True,
13327                   supports_sparse_bsc=True,
13328                   assert_autodiffed=True),
13329    OpInfo('flip',
13330           op=torch.flip,
13331           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13332           sample_inputs_func=sample_inputs_flip,
13333           supports_forward_ad=True,
13334           supports_fwgrad_bwgrad=True,
13335           supports_out=False),
13336    OpInfo('fliplr',
13337           op=torch.fliplr,
13338           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13339           sample_inputs_func=sample_inputs_fliplr_flipud,
13340           error_inputs_func=error_inputs_fliplr,
13341           supports_forward_ad=True,
13342           supports_fwgrad_bwgrad=True,
13343           supports_out=False),
13344    OpInfo('flipud',
13345           op=torch.flipud,
13346           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13347           sample_inputs_func=sample_inputs_fliplr_flipud,
13348           error_inputs_func=error_inputs_flipud,
13349           supports_forward_ad=True,
13350           supports_fwgrad_bwgrad=True,
13351           supports_out=False),
13352    OpInfo('sparse.sampled_addmm',
13353           dtypes=floating_and_complex_types(),
13354           supports_autograd=True,
13355           sample_inputs_func=sample_inputs_sparse_sampled_addmm,
13356           decorators=[
13357               skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3))
13358                               or (_get_torch_rocm_version() >= (5, 2))),
13359                          "cusparseSDDMM was added in 11.2.1"),
13360               skipCPUIfNoMklSparse, ],
13361           skips=(
13362               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
13363               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
13364               # RuntimeError: Sparse CSR tensors do not have strides.
13365               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
13366               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
13367               # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided
13368               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
13369               # RuntimeError: Sparse CSR tensors do not have strides
13370               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
13371               # RuntimeError: Sparse CSR tensors do not have strides
13372               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
13373               # RuntimeError: Sparse CSR tensors do not have strides
13374               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
13375               # RuntimeError: Sparse CSR tensors do not have strides
13376               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
13377               # RuntimeError: Sparse CSR tensors do not have strides
13378               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
13379               # RuntimeError: Sparse CSR tensors do not have strides
13380               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
13381               # RuntimeError: Sparse CSR tensors do not have strides
13382               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
13383               # RuntimeError: unsupported memory format option Preserve
13384               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
13385               # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype
13386               # RuntimeError: Sparse CSR tensors do not have strides
13387               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
13388               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
13389               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
13390               # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype.
13391               # RuntimeError: Sparse CSR tensors do not have is_contiguous
13392               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
13393               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
13394               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
13395               # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
13396               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
13397               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
13398               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
13399               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
13400               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
13401           )),
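    # Rough usage sketch for 'sparse.sampled_addmm' (illustrative shapes; availability depends on
    # the CUDA/MKL decorators above). The result is a sparse CSR tensor restricted to the
    # sparsity pattern of the input:
    # >>> inp = torch.eye(3).to_sparse_csr()
    # >>> mat1, mat2 = torch.randn(3, 5), torch.randn(5, 3)
    # >>> out = torch.sparse.sampled_addmm(inp, mat1, mat2)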
13402    OpInfo('sparse.mm',
13403           dtypes=floating_types_and(torch.bfloat16, torch.float16),
13404           variant_test_name='reduce',
13405           supports_autograd=True,
13406           supports_out=False,
13407           supports_gradgrad=False,
13408           supports_forward_ad=False,
13409           sample_inputs_func=sample_inputs_sparse_mm_reduce,
13410           decorators=[onlyCPU],
13411           skips=(
13412               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
13413               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
13414               # RuntimeError: Sparse CSR tensors do not have strides.
13415               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
13416               # RuntimeError: Sparse CSR tensors do not have strides
13417               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
13418               # RuntimeError: Sparse CSR tensors do not have strides
13419               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
13420               # RuntimeError: Sparse CSR tensors do not have strides
13421               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
13422               # RuntimeError: Sparse CSR tensors do not have strides
13423               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
13424               # RuntimeError: Sparse CSR tensors do not have strides
13425               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
13426               # RuntimeError: Sparse CSR tensors do not have strides
13427               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
13428               # RuntimeError: Sparse CSR tensors do not have strides
13429               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
13430               # RuntimeError: unsupported memory format option Preserve
13431               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
13432               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
13433               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
13434               # RuntimeError: Sparse CSR tensors do not have is_contiguous
13435               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
13436               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
13437               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
13438               # RuntimeError: Sparse CSR tensors do not have strides
13439               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
13440               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
13441               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
13442               # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
13443               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
13444               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
13445               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
13446           )),
13447    UnaryUfuncInfo('i0',
13448                   ref=np_unary_ufunc_integer_promotion_wrapper(
13449                       scipy.special.i0) if TEST_SCIPY else None,
13450                   aliases=('special.i0',),
13451                   decorators=(precisionOverride({torch.bfloat16: 3e-1,
13452                                                  torch.float16: 5e-1}),),
13453                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
13454                   backward_dtypes=floating_types(),
13455                   supports_forward_ad=True,
13456                   supports_fwgrad_bwgrad=True,
13457                   promotes_int_to_float=True,
13458                   sample_inputs_func=sample_inputs_i0_i1,
13459                   skips=(
13460                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
13461                                    dtypes=(torch.int8,)),
13462                   )),
13463    BinaryUfuncInfo('floor_divide',
13464                    ref=_floor_divide_np,
13465                    dtypes=all_types_and(torch.half, torch.bfloat16),
13466                    supports_autograd=False,
13467                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
13468                    supports_two_python_scalars=True,
13469                    skips=(
13470                        # AssertionError: Results of original model and exported/imported version of model differed
13471                        DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
13472                        # bfloat16 floor_divide compared with a float32 reference works inconsistently
13473                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
13474                                     dtypes=(torch.bfloat16,)),
13475                        # int8 floor divide has different results for -128 // -1 vs. NumPy
13476                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
13477                                     dtypes=(torch.int8,)),
13478                        # The following tests fails on some jobs
13479                        # The following tests fail on some jobs
13480                                     dtypes=(torch.float16,)),
13481                        DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
13482                                     'TestBinaryUfuncs', 'test_reference_numerics'),
13483                    )),
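    # floor_divide rounds toward negative infinity; the int8 skip above covers the one overflowing
    # case (-128 // -1), whose wrapped result differs from NumPy's (illustrative values; the
    # overflowing outputs are intentionally not shown):
    # >>> torch.floor_divide(torch.tensor([ 7., -7.]), 2)
    # tensor([ 3., -4.])
    # >>> torch.floor_divide(torch.tensor([-128], dtype=torch.int8), torch.tensor([-1], dtype=torch.int8))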
13484    UnaryUfuncInfo('frexp',
13485                   op=torch.frexp,
13486                   ref=np.frexp,
13487                   dtypes=floating_types_and(torch.half, torch.bfloat16),
13488                   decorators=[],
13489                   supports_forward_ad=True,
13490                   supports_fwgrad_bwgrad=True,
13491                   skips=(
13492                       # Skips the tests below because torch.frexp returns a tuple-like (mantissa, exponent) output,
13493                       # while these tests currently require the output to be a single tensor (see the sketch after this entry).
13494                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
13495                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
13496                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
13497                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
13498                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'),
13499                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
13500
13501                       # Skips test_reference_numerics due to an error in Windows CI:
13502                       # np.frexp returns the exponent as np.intc dtype on Windows,
13503                       # and np.intc does not have a corresponding torch dtype
13504                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
13505                                    active_if=IS_WINDOWS),
13506                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
13507                                    active_if=IS_WINDOWS),
13508                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13509                                    active_if=IS_WINDOWS),
13510                   )),
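    # torch.frexp decomposes x into mantissa * 2**exponent with the mantissa in [0.5, 1), returning
    # a (mantissa, exponent) pair -- which is why the single-output tests above are skipped
    # (illustrative values; the exponent is an integer tensor):
    # >>> mantissa, exponent = torch.frexp(torch.tensor([8., 0.5]))
    # >>> mantissa, exponent
    # (tensor([0.5000, 0.5000]), tensor([4, 0], dtype=torch.int32))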
13511    UnaryUfuncInfo('log1p',
13512                   ref=np.log1p,
13513                   aliases=('special.log1p',),
13514                   domain=(-1, None),
13515                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13516                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
13517                   supports_forward_ad=True,
13518                   supports_fwgrad_bwgrad=True,
13519                   supports_sparse=True,
13520                   supports_sparse_csr=True,
13521                   supports_sparse_csc=True,
13522                   supports_sparse_bsr=True,
13523                   supports_sparse_bsc=True,
13524                   assert_autodiffed=True,
13525                   promotes_int_to_float=True),
13526    BinaryUfuncInfo('ge',
13527                    ref=np.greater_equal,
13528                    aliases=('greater_equal',),
13529                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
13530                    always_returns_bool=True,
13531                    supports_autograd=False,
13532                    skips=(
13533                    )),
13534    OpInfo('geqrf',
13535           dtypes=floating_and_complex_types(),
13536           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
13537           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
13538           supports_autograd=False,
13539           skips=(
13540               # FIXME: geqrf can't forward with complex inputs that require grad
13541               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
13542               # Strides are not the same!
13543               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
13544           )),
13545    BinaryUfuncInfo('gt',
13546                    ref=np.greater,
13547                    aliases=('greater',),
13548                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
13549                    always_returns_bool=True,
13550                    supports_autograd=False,
13551                    skips=(
13552                    )),
13553    UnaryUfuncInfo('imag',
13554                   ref=np.imag,
13555                   dtypes=complex_types_and(torch.chalf),
13556                   supports_out=False,
13557                   supports_forward_ad=True,
13558                   supports_fwgrad_bwgrad=True,
13559                   # See https://github.com/pytorch/pytorch/issues/66357
13560                   # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.
13561                   check_batched_forward_grad=False,
13562                   skips=(
13563                       # Skip since real and imag don't have out variants.
13564                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
13565                   )),
13566    OpInfo('gradient',
13567           dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
13568                                                 torch.int32, torch.int64,
13569                                                 torch.bfloat16, torch.half),
13570           supports_out=False,
13571           supports_forward_ad=True,
13572           supports_fwgrad_bwgrad=True,
13573           # See https://github.com/pytorch/pytorch/pull/78358
13574           check_batched_forward_grad=False,
13575           skips=(
13576               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13577               # The following tests give a runtime error with an undefined value tensor;
13578               # see discussion: https://github.com/pytorch/pytorch/issues/56660
13579               # RuntimeError:
13580               # Arguments for call are not valid.
13581               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)),  # noqa: B950
13582               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
13583               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
13584           ),
13585           supports_inplace_autograd=False,
13586           sample_inputs_func=sample_inputs_gradient,
13587           error_inputs_func=error_inputs_gradient),
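    # torch.gradient estimates derivatives with central differences in the interior and one-sided
    # differences at the boundary (default edge behaviour), returning a tuple of tensors
    # (illustrative values):
    # >>> torch.gradient(torch.tensor([1., 2., 4., 8.]))
    # (tensor([1.0000, 1.5000, 3.0000, 4.0000]),)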
13588    OpInfo('isin',
13589           dtypes=all_types(),
13590           dtypesIfCUDA=all_types_and(torch.half),
13591           supports_autograd=False,
13592           sample_inputs_func=sample_inputs_isin),
13593    OpInfo('kthvalue',
13594           dtypes=all_types_and(torch.bfloat16, torch.float16),
13595           supports_forward_ad=True,
13596           supports_fwgrad_bwgrad=True,
13597           sample_inputs_func=sample_inputs_kthvalue,
13598           error_inputs_func=error_inputs_kthvalue),
13599    BinaryUfuncInfo('le',
13600                    ref=np.less_equal,
13601                    aliases=('less_equal',),
13602                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
13603                    always_returns_bool=True,
13604                    supports_autograd=False,
13605                    skips=(
13606                    )),
13607    OpInfo('linspace',
13608           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
13609           is_factory_function=True,
13610           supports_out=True,
13611           supports_autograd=False,
13612           error_inputs_func=error_inputs_linspace,
13613           sample_inputs_func=sample_inputs_linspace,
13614           skips=(
13615               # FX failed to normalize op - add the op to the op_skip list.
13616               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13617               # Tests that assume input is a tensor or sequence of tensors
13618               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
13619               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
13620               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
13621               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
13622
13623               # Same failure as arange: cannot find linspace in captured graph
13624               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
13625
13626               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
13627               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
13628               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
13629               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
13630               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
13631               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
13632               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
13633                            dtypes=(torch.cfloat,), device_type="cuda"),
13634           )),
13635    OpInfo('linspace',
13636           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
13637           is_factory_function=True,
13638           supports_out=True,
13639           supports_autograd=False,
13640           error_inputs_func=error_inputs_linspace,
13641           sample_inputs_func=sample_inputs_linspace_tensor_overload,
13642           variant_test_name="tensor_overload",
13643           skips=(
13644               # FX failed to normalize op - add the op to the op_skip list.
13645               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13646               # TypeError: 'int' object is not subscriptable
13647               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
13648               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
13649               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
13650
13651               # Same failure as arange: cannot find linspace in captured graph
13652               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
13653
13654               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
13655               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
13656               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
13657               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
13658               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
13659               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
13660               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
13661                            dtypes=(torch.cfloat,), device_type="cuda"),
13662           )),
13663    OpInfo('logspace',
13664           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
13665           is_factory_function=True,
13666           supports_out=True,
13667           supports_autograd=False,
13668           error_inputs_func=error_inputs_linspace,
13669           sample_inputs_func=sample_inputs_logspace,
13670           skips=(
13671               # FX failed to normalize op - add the op to the op_skip list.
13672               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13673               # Tests that assume input is a tensor or sequence of tensors
13674               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
13675               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
13676               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
13677               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
13678               # Same failure as arange: cannot find linspace in captured graph
13679               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
13680
13681               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
13682               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
13683
13684               # Off-by-one issue when casting floats to ints
13685               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
13686                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
13687               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
13688                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
13689               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
13690               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
13691               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
13692               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
13693               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
13694                            dtypes=(torch.cfloat,), device_type="cuda"),
13695           )),
13696    OpInfo('logspace',
13697           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
13698           is_factory_function=True,
13699           supports_out=True,
13700           supports_autograd=False,
13701           error_inputs_func=error_inputs_linspace,
13702           sample_inputs_func=sample_inputs_logspace_tensor_overload,
13703           variant_test_name="tensor_overload",
13704           skips=(
13705               # FX failed to normalize op - add the op to the op_skip list.
13706               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
13707               # TypeError: 'int' object is not subscriptable
13708               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
13709               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
13710               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
13711               # Same failure as arange: cannot find linspace in captured graph
13712               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
13713
13714               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
13715               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
13716
13717               # Off-by-one issue when casting floats to ints
13718               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick',
13719                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
13720               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive',
13721                            dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"),
13722               # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API
13723               # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64!
13724               # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0.
13725               # CUDA driver allocated memory was 1254555648 and is now 1242955776.
13726               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
13727                            dtypes=(torch.cfloat,), device_type="cuda"),
13728           )),
13729    UnaryUfuncInfo('log',
13730                   ref=np.log,
13731                   domain=(0, None),
13732                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13733                   dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
13734                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf),
13735                   assert_autodiffed=True,
13736                   supports_forward_ad=True,
13737                   supports_fwgrad_bwgrad=True,
13738                   promotes_int_to_float=True,
13739                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
13740                   skips=(
13741                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13742                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
13743                                    active_if=IS_WINDOWS),
13744                   ),
13745                   # log(z)->-inf for |z|->0
13746                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
13747    UnaryUfuncInfo('log10',
13748                   ref=np.log10,
13749                   domain=(0, None),
13750                   decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
13751                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13752                   assert_autodiffed=True,
13753                   supports_forward_ad=True,
13754                   supports_fwgrad_bwgrad=True,
13755                   promotes_int_to_float=True,
13756                   skips=(
13757                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13758                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
13759                                    active_if=IS_WINDOWS),
13760                   ),
13761                   # log10(z)->-inf for |z|->0
13762                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
13763    UnaryUfuncInfo('log2',
13764                   ref=np.log2,
13765                   domain=(0, None),
13766                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13767                   assert_autodiffed=True,
13768                   supports_forward_ad=True,
13769                   supports_fwgrad_bwgrad=True,
13770                   promotes_int_to_float=True,
13771                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
13772                   skips=(
13773                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
13774                                    dtypes=[torch.cfloat, torch.cdouble]),
13775                   ),
13776                   # log2(z)->-inf for |z|->0
13777                   reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)),
13778    BinaryUfuncInfo('ldexp',
13779                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13780                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13781                    gradcheck_fast_mode=True,
13782                    supports_forward_ad=True,
13783                    supports_fwgrad_bwgrad=True,
13784                    supports_inplace_autograd=False,
13785                    promotes_int_to_float=True,
13786                    supports_out=True,
13787                    supports_rhs_python_scalar=False,
13788                    skips=(
13789                        # RuntimeError: mul(): functions with out=... arguments don't support
13790                        # automatic differentiation, but one of the arguments requires grad
13791                        # https://github.com/pytorch/pytorch/issues/68966
13792                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
13793                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
13794                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
13795                        DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
13796                    ),
13797                    decorators=[
13798                        DecorateInfo(
13799                            toleranceOverride({
13800                                torch.complex64: tol(atol=1e-05, rtol=1e-05)
13801                            }),
13802                            'TestCommon', device_type='cpu',
13803                        ),
13804                    ], ),
13805    BinaryUfuncInfo('logaddexp',
13806                    dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
13807                    dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
13808                    supports_forward_ad=True,
13809                    supports_fwgrad_bwgrad=True,
13810                    supports_rhs_python_scalar=False,
13811                    skips=(
13812                        # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat'
13813                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
13814                    )),
13815    OpInfo('logaddexp2',
13816           dtypes=floating_types_and(torch.bfloat16, torch.half),
13817           supports_forward_ad=True,
13818           supports_fwgrad_bwgrad=True,
13819           sample_inputs_func=sample_inputs_logaddexp),
13820    UnaryUfuncInfo('logical_not',
13821                   ref=np.logical_not,
13822                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
13823                                                  torch.float16: 5e-1}),),
13824                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13825                   supports_autograd=False,
13826                   skips=(
13827                       # The function variant always returns BoolTensor
13828                       # while the inplace variant preserves the input dtype.
13829                       # >>> t = torch.randn(3)
13830                       # >>> torch.logical_not(t)
13831                       # tensor([False, False, False])
13832                       # >>> torch.logical_not(t).dtype
13833                       # torch.bool
13834                       # >>> t.logical_not_().dtype
13835                       # torch.float32
13836                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency',
13837                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
13838                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
13839                                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)),
13840                   )),
13841    BinaryUfuncInfo('lt',
13842                    ref=np.less,
13843                    aliases=('less',),
13844                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
13845                    always_returns_bool=True,
13846                    supports_autograd=False,
13847                    skips=(
13848                    )),
13849    OpInfo('lu_unpack',
13850           op=torch.lu_unpack,
13851           dtypes=floating_and_complex_types(),
13852           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13853           gradcheck_fast_mode=True,
13854           supports_forward_ad=True,
13855           supports_fwgrad_bwgrad=True,
13856           skips=(skipCPUIfNoLapack,),
13857           sample_inputs_func=sample_inputs_lu_unpack),
13858    OpInfo('lu',
13859           op=torch.lu,
13860           dtypes=floating_and_complex_types(),
13861           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13862           gradcheck_fast_mode=True,
13863           supports_forward_ad=True,
13864           supports_fwgrad_bwgrad=True,
13865           # https://github.com/pytorch/pytorch/issues/66357
13866           check_batched_forward_grad=False,
13867           sample_inputs_func=sample_inputs_lu,
13868           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
13869           skips=(
13870               # we skip jit tests because `lu` is a torch function
13871               # RuntimeError:
13872               # 'Tensor (inferred)' object has no attribute or method 'lu'.:
13873               # File "<string>", line 3
13874               # def the_method(i0):
13875               #     return i0.lu(True, True)
13876               #            ~~~~~ <--- HERE
13877               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
13878               # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda
13879               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
13880               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
13881               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
13882           )),
13883    OpInfo('lu_solve',
13884           op=torch.lu_solve,
13885           dtypes=floating_and_complex_types(),
13886           supports_forward_ad=True,
13887           # See https://github.com/pytorch/pytorch/issues/66357
13888           check_batched_forward_grad=False,
13889           supports_fwgrad_bwgrad=True,
13890           sample_inputs_func=sample_inputs_lu_solve,
13891           skips=(
13892               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
13893                            device_type='mps', dtypes=[torch.float32]),
13894               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
13895                            device_type='mps', dtypes=[torch.float32]),
13896               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
13897                            device_type='mps', dtypes=[torch.float32]),
13898               DecorateInfo(unittest.skip("Tests different backward paths"),
13899                            "TestCommon", "test_floating_inputs_are_differentiable"),),
13900           decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]),
13901    OpInfo('masked_fill',
13902           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
13903           sample_inputs_func=sample_inputs_masked_fill,
13904           error_inputs_func=error_inputs_masked_fill,
13905           supports_forward_ad=True,
13906           supports_fwgrad_bwgrad=True,
13907           check_batched_forward_grad=False,
13908           supports_out=False),
13909    OpInfo('masked_scatter',
13910           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13911           sample_inputs_func=sample_inputs_masked_scatter,
13912           error_inputs_func=error_inputs_masked_scatter,
13913           supports_forward_ad=True,
13914           supports_fwgrad_bwgrad=True,
13915           # https://github.com/pytorch/pytorch/issues/66357
13916           check_batched_forward_grad=False,
13917           supports_out=False,
13918           skips=(
13919           )),
13920    OpInfo('masked_select',
13921           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13922           supports_forward_ad=True,
13923           supports_fwgrad_bwgrad=True,
13924           sample_inputs_func=sample_inputs_masked_select,
13925           error_inputs_func=error_inputs_masked_select,
13926           skips=(
13927               # Compiler issue on ROCm. Might need to skip until ROCm 5.5
13928               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
13929                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
13930           )),
13931    OpInfo('matrix_exp',
13932           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
13933           aliases=('linalg.matrix_exp',),
13934           sample_inputs_func=sample_inputs_matrix_exp,
13935           # Needs to construct a 2n x 2n matrix by copy_-ing into it
13936           check_batched_grad=False,
13937           check_batched_gradgrad=False,
13938           supports_forward_ad=True,
13939           supports_fwgrad_bwgrad=True,
13940           # https://github.com/pytorch/pytorch/issues/66357
13941           check_batched_forward_grad=False,
13942           skips=(
13943               # mexp does not support bf16 and fp16
13944               DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive',
13945                            dtypes=[torch.half], device_type="cpu"),
13946           ),
13947           supports_out=False,
13948           ),
13949    OpInfo('matmul',
13950           aliases=('linalg.matmul',),
13951           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
13952           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
13953                                                       *[torch.bfloat16]
13954                                                       if SM53OrLater or TEST_WITH_ROCM else []),
13955           assert_autodiffed=True,
13956           assert_jit_shape_analysis=True,
13957           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
13958           gradcheck_fast_mode=True,
13959           supports_forward_ad=True,
13960           supports_fwgrad_bwgrad=True,
13961           check_batched_forward_grad=False,
13962           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False),
13963           decorators=[
13964               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
13965               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
13966               # ROCm intermittently fails the test with standard atol/rtol
13967               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
13968                            'TestCommon', 'test_noncontiguous_samples', device_type='cuda',
13969                            active_if=TEST_WITH_ROCM),
13970               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}),
13971                            'TestCommon', 'test_out', device_type='cuda',
13972                            active_if=TEST_WITH_ROCM),
13973               # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the
13974               # backward on CPU
13975               DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}),
13976                            'TestCommon', 'test_noncontiguous_samples',
13977                            device_type='cpu'),
13978               DecorateInfo(
13979                   toleranceOverride({
13980                       torch.float32: tol(atol=1e-5, rtol=1e-5),
13981                       torch.complex64: tol(atol=1e-5, rtol=1e-5),
13982                   }),
13983                   "TestDecomp", "test_comprehensive", device_type="cuda",
13984               ),
13985           ],
13986           skips=(
13987               # Strides are not the same!
13988               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
13989               # https://github.com/pytorch/pytorch/issues/67470
13990               DecorateInfo(unittest.skip("67470!"),
13991                            'TestCommon', 'test_noncontiguous_samples',
13992                            device_type='cpu', dtypes=(torch.long,)),
13993               # AssertionError: False is not true : Tensors failed to compare as equal!
13994               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo',
13995                            device_type='xla', dtypes=(torch.long,)),
13996               # https://github.com/pytorch/pytorch/issues/71774
13997               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
13998                            device_type='cpu', dtypes=(torch.long,)),
13999           )),
14000    OpInfo('max',
14001           variant_test_name='reduction_with_dim',
14002           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14003           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
14004           supports_fwgrad_bwgrad=True,
14005           skips=(
14006           ),
14007           supports_forward_ad=True),
14008    OpInfo('max',
14009           variant_test_name='reduction_no_dim',
14010           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14011           supports_out=True,
14012           supports_forward_ad=True,
14013           supports_fwgrad_bwgrad=True,
14014           sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
14015           skips=(
14016           )),
14017    OpInfo('median',
14018           dtypes=all_types_and(torch.bfloat16, torch.float16),
14019           # TODO: some signatures of median do support out
14020           supports_out=False,
14021           supports_forward_ad=True,
14022           supports_fwgrad_bwgrad=True,
14023           error_inputs_func=error_inputs_median,
14024           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
14025    OpInfo('nanmedian',
14026           dtypes=all_types_and(torch.bfloat16, torch.float16),
14027           # TODO: some signatures of nanmedian do support out
14028           supports_out=False,
14029           supports_forward_ad=True,
14030           supports_fwgrad_bwgrad=True,
14031           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
14032    OpInfo('var_mean',
14033           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
14034           sample_inputs_func=sample_inputs_std_var,
14035           # TODO: some signatures of var_mean do support out
14036           supports_out=False,
14037           supports_forward_ad=True,
14038           check_batched_forward_grad=False,
14039           supports_fwgrad_bwgrad=True,
14040           decorators=(
14041               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
14042                            "TestDecomp", "test_comprehensive", device_type="cuda"),
14043               DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
14044                            "TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
14045           )),
14046    OpInfo('var_mean',
14047           variant_test_name='unbiased',
14048           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
14049           sample_inputs_func=sample_inputs_std_var_unbiased,
14050           # TODO: some signatures of var_mean do support out
14051           supports_out=False,
14052           supports_forward_ad=True,
14053           check_batched_forward_grad=False,
14054           supports_fwgrad_bwgrad=True,
14055           decorators=(
14056               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
14057                            "TestDecomp", "test_comprehensive", device_type="cuda"),
14058               DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
14059                            "TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
14060           )),
14061    OpInfo('std_mean',
14062           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
14063           sample_inputs_func=sample_inputs_std_var,
14064           # TODO: some signatures of std_mean do support out
14065           supports_out=False,
14066           supports_forward_ad=True,
14067           check_batched_forward_grad=False,
14068           supports_fwgrad_bwgrad=True,
14069           decorators=(
14070               DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
14071                            "TestDecomp", "test_comprehensive", device_type="cuda"),
14072           )),
14073    OpInfo('std_mean',
14074           variant_test_name='unbiased',
14075           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
14076           sample_inputs_func=sample_inputs_std_var_unbiased,
14077           # TODO: some signatures of std_mean do support out
14078           supports_out=False,
14079           supports_forward_ad=True,
14080           check_batched_forward_grad=False,
14081           supports_fwgrad_bwgrad=True,
14082           decorators=(
14083               DecorateInfo(
14084                   toleranceOverride({
14085                       torch.float16: tol(atol=4e-5, rtol=9e-3),
14086                       torch.float64: tol(atol=2e-7, rtol=2e-7),
14087                   }),
14088                   "TestDecomp",
14089                   "test_comprehensive",
14090                   device_type="cuda"
14091               ),
14092               DecorateInfo(
14093                   toleranceOverride({
14094                       torch.float16: tol(atol=4e-5, rtol=9e-3),
14095                       torch.float64: tol(atol=2e-7, rtol=2e-7),
14096                   }),
14097                   "TestInductorOpInfo",
14098                   "test_comprehensive",
14099                   device_type="cuda"
14100               ),
14101           )),
14102    OpInfo('meshgrid',
14103           variant_test_name='variadic_tensors',
14104           ref=np.meshgrid,
14105           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
14106           sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'),
14107           skips=[
14108               # JIT does not support variadic tensors.
14109               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
14110               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
14111               # please report a bug to PyTorch.
14112               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
14113               # meshgrid is defined in torch.functional to take a
14114               # variadic list of tensors. Variadic parameters are not
14115               # compatible with the normalize operator tests.
14116               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
14117               # Skip operator schema test because this is a functional and not an operator
14118               DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
14119           ],
14120           supports_out=False,
14121           supports_fwgrad_bwgrad=True,
14122           supports_forward_ad=True,
14123           # See https://github.com/pytorch/pytorch/pull/78358
14124           check_batched_forward_grad=False,),
14125    OpInfo('meshgrid',
14126           variant_test_name='list_of_tensors',
14127           # Unlike the variant above, we do not use np.meshgrid as a
14128           # ref since it does not officially support a list of numpy
14129           # arrays.
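                # (A minimal usage sketch, illustration only: both meshgrid variants hit the
                # same torch.meshgrid, just through different calling conventions, e.g.
                # >>> x, y = torch.arange(2), torch.arange(3)
                # >>> torch.meshgrid(x, y, indexing='ij')    # variadic form (variant above)
                # >>> torch.meshgrid([x, y], indexing='ij')  # list form exercised here
                # each call returns a tuple of two 2x3 tensors.)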
14130           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16),
14131           sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'),
14132           skips=[
14133               # meshgrid is defined in torch.functional to take a
14134               # variadic list of tensors. Variadic parameters are not
14135               # compatible with the normalize operator tests.
14136               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
14137           ],
14138           assert_autodiffed=True,
14139           supports_out=False,
14140           autodiff_nonfusible_nodes=[],
14141           supports_fwgrad_bwgrad=True,
14142           supports_forward_ad=True,
14143           # See https://github.com/pytorch/pytorch/pull/78358
14144           check_batched_forward_grad=False,),
14145    OpInfo('min',
14146           variant_test_name='reduction_with_dim',
14147           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14148           sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
14149           supports_fwgrad_bwgrad=True,
14150           supports_forward_ad=True,
14151           skips=(
14152           )),
14153    OpInfo('min',
14154           variant_test_name='reduction_no_dim',
14155           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14156           supports_out=True,
14157           supports_forward_ad=True,
14158           supports_fwgrad_bwgrad=True,
14159           sample_inputs_func=sample_inputs_max_min_reduction_no_dim,
14160           skips=(
14161           )),
14162    OpInfo('quantile',
14163           dtypes=floating_types(),
14164           sample_inputs_func=sample_inputs_reduction_quantile,
14165           supports_forward_ad=True,
14166           supports_fwgrad_bwgrad=True,
14167           # See https://github.com/pytorch/pytorch/issues/66357
14168           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
14169           # does not have a batching rule in core
14170           check_batched_forward_grad=False),
14171    OpInfo('nanquantile',
14172           dtypes=floating_types(),
14173           sample_inputs_func=sample_inputs_reduction_quantile,
14174           supports_forward_ad=True,
14175           supports_fwgrad_bwgrad=True,
14176           # See https://github.com/pytorch/pytorch/issues/66357
14177           # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which
14178           # does not have a batching rule in core
14179           check_batched_forward_grad=False),
14180    BinaryUfuncInfo(
14181        'max',
14182        aliases=('maximum',),
14183        variant_test_name='binary',
14184        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14185        supports_forward_ad=True,
14186        supports_fwgrad_bwgrad=True,
14187        assert_autodiffed=True,
14188        ref=np.maximum,
14189        supports_rhs_python_scalar=False,
14190        skips=(
14191            # Incorrectly attempts to use a scalar for the second argument
14192            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
14193            # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
14194            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
14195        )),
14196    BinaryUfuncInfo(
14197        'maximum',
14198        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14199        supports_forward_ad=True,
14200        supports_fwgrad_bwgrad=True,
14201        ref=np.maximum,
14202        supports_rhs_python_scalar=False,
14203        skips=(
14204            # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat'
14205            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'),
14206        )),
14207    BinaryUfuncInfo(
14208        'min',
14209        aliases=('minimum',),
14210        variant_test_name='binary',
14211        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14212        supports_forward_ad=True,
14213        supports_fwgrad_bwgrad=True,
14214        assert_autodiffed=True,
14215        ref=np.minimum,
14216        supports_rhs_python_scalar=False,
14217        skips=(
14218            # Incorrectly attempts to use a scalar for the second argument
14219            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
14220            # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
14221            DecorateInfo(unittest.expectedFailure,
14222                         'TestBinaryUfuncs',
14223                         'test_type_promotion',
14224                         device_type='cuda'),
14225        )),
14226    BinaryUfuncInfo(
14227        'minimum',
14228        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
14229        supports_forward_ad=True,
14230        supports_fwgrad_bwgrad=True,
14231        ref=np.minimum,
14232        supports_rhs_python_scalar=False,
14233        skips=(
14234            # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat'
14235            DecorateInfo(unittest.expectedFailure,
14236                         'TestBinaryUfuncs',
14237                         'test_type_promotion',
14238                         device_type='cuda'),
14239        ),
14240    ),
14241    BinaryUfuncInfo('logical_and',
14242                    ref=np.logical_and,
14243                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
14244                    supports_autograd=False,
14245                    always_returns_bool=True,
14246                    supports_rhs_python_scalar=False),
14247    BinaryUfuncInfo('logical_or',
14248                    ref=np.logical_or,
14249                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
14250                    supports_autograd=False,
14251                    always_returns_bool=True,
14252                    supports_rhs_python_scalar=False),
14253    BinaryUfuncInfo('logical_xor',
14254                    ref=np.logical_xor,
14255                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
14256                    supports_autograd=False,
14257                    always_returns_bool=True,
14258                    supports_rhs_python_scalar=False,
14259                    skips=(
14260                    )),
14261    BinaryUfuncInfo('bitwise_and',
14262                    ref=np.bitwise_and,
14263                    dtypes=integral_types_and(torch.bool),
14264                    operator_variant=operator.and_,
14265                    inplace_operator_variant=operator.iand,
14266                    supports_autograd=False,
14267                    supports_one_python_scalar=True,
14268                    skips=(
14269                        # RuntimeError: "bitwise_and_cuda" not implemented for 'Half'
14270                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
14271                                     'test_type_promotion', device_type='cuda'),
14272                    )),
14273    BinaryUfuncInfo('bitwise_or',
14274                    ref=np.bitwise_or,
14275                    dtypes=integral_types_and(torch.bool),
14276                    operator_variant=operator.or_,
14277                    inplace_operator_variant=operator.ior,
14278                    supports_autograd=False,
14279                    supports_one_python_scalar=True,
14280                    skips=(
14281                        # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half'
14282                        DecorateInfo(unittest.expectedFailure,
14283                                     'TestBinaryUfuncs',
14284                                     'test_type_promotion',
14285                                     device_type='cuda'),
14286                    )),
14287    BinaryUfuncInfo('bitwise_xor',
14288                    ref=np.bitwise_xor,
14289                    dtypes=integral_types_and(torch.bool),
14290                    operator_variant=operator.xor,
14291                    inplace_operator_variant=operator.ixor,
14292                    supports_autograd=False,
14293                    supports_one_python_scalar=True,
14294                    skips=(
14295                        # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half'
14296                        DecorateInfo(unittest.expectedFailure,
14297                                     'TestBinaryUfuncs',
14298                                     'test_type_promotion',
14299                                     device_type='cuda'),
14300                    )),
14301    BinaryUfuncInfo('heaviside',
14302                    ref=lambda a, b: (
14303                        # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64
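                             # (illustration of the workaround: np.heaviside only provides
                             #  floating-point loops, so int64 inputs are promoted and the result
                             #  comes back as float64; wrapping it in np.int64(...) restores the
                             #  integer dtype that torch.heaviside itself returns)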
14304                        np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b)
14305                    ),
14306                    dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
14307                    supports_autograd=False,
14308                    supports_rhs_python_scalar=False,
14309                    skips=(
14310                        # RuntimeError: heaviside is not yet implemented for tensors with different dtypes.
14311                        DecorateInfo(unittest.expectedFailure,
14312                                     'TestBinaryUfuncs',
14313                                     'test_type_promotion'),
14314                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
14315                        # PyTorch's heaviside does not appear to propagate NaNs
14316                        DecorateInfo(unittest.skip("Skipped!"),
14317                                     'TestBinaryUfuncs',
14318                                     'test_reference_numerics_extremal_values'),
14319                    )),
14320    BinaryUfuncInfo('lcm',
14321                    ref=np.lcm,
14322                    dtypes=integral_types_and(),
14323                    supports_autograd=False,
14324                    supports_rhs_python_scalar=False),
14325    BinaryUfuncInfo('gcd',
14326                    ref=np.gcd,
14327                    dtypes=integral_types_and(),
14328                    supports_autograd=False,
14329                    supports_rhs_python_scalar=False,
14330                    skips=(
14331                        DecorateInfo(unittest.expectedFailure,
14332                                     'TestBinaryUfuncs',
14333                                     'test_reference_numerics_small_values',
14334                                     dtypes=(torch.int8,)),)),
14335    BinaryUfuncInfo('isclose',
14336                    ref=np.isclose,
14337                    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
14338                    sample_inputs_func=sample_inputs_isclose,
14339                    error_inputs_func=error_inputs_isclose,
14340                    supports_autograd=False,
14341                    supports_out=False,
14342                    supports_rhs_python_scalar=False,
14343                    skips=(
14344                        DecorateInfo(unittest.expectedFailure,
14345                                     'TestCommon',
14346                                     'test_numpy_refs', dtypes=(torch.complex128,)),
14347                        # RuntimeError: Short did not match Int
14348                        DecorateInfo(unittest.expectedFailure,
14349                                     'TestBinaryUfuncs',
14350                                     'test_type_promotion'),
14351                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
14352                        DecorateInfo(unittest.skip("Skipped!"),
14353                                     'TestBinaryUfuncs',
14354                                     'test_reference_numerics_extremal_values'),
14355                    )),
14356    # `softmax` supports different dtypes depending on whether the `dtype` argument
14357    # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
14358    # https://github.com/pytorch/pytorch/issues/68752
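         # A minimal sketch of the difference (illustration only; exact error text not asserted):
         # >>> t = torch.arange(3)                                   # int64 input
         # >>> torch.softmax(t, dim=0)                               # no dtype: fails for integer inputs
         # >>> torch.softmax(t, dim=0, dtype=torch.float32).dtype    # dtype given: input is cast first
         # torch.float32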
14359    OpInfo('softmax',
14360           aliases=('special.softmax', 'nn.functional.softmax',),
14361           aten_name='softmax',
14362           aten_backward_name='_softmax_backward_data',
14363           dtypes=floating_types_and(torch.half, torch.bfloat16),
14364           sample_inputs_func=sample_inputs_softmax_variant,
14365           assert_jit_shape_analysis=True,
14366           assert_autodiffed=True,
14367           supports_forward_ad=True,
14368           supports_fwgrad_bwgrad=True,
14369           supports_out=True),
14370    OpInfo('softmax',
14371           aliases=('special.softmax', 'nn.functional.softmax',),
14372           variant_test_name="with_dtype",
14373           aten_name='softmax',
14374           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
14375           sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
14376           assert_autodiffed=True,
14377           supports_forward_ad=True,
14378           supports_fwgrad_bwgrad=True,
14379           supports_out=True),
14380    OpInfo(
14381        '_softmax_backward_data',
14382        op=torch.ops.aten._softmax_backward_data,
14383        aten_name='_softmax_backward_data',
14384        dtypes=floating_types_and(torch.bfloat16, torch.float16),
14385        sample_inputs_func=sample_inputs_softmax_backward_data,
14386        assert_autodiffed=True,
14387        supports_forward_ad=True,
14388        supports_fwgrad_bwgrad=True,
14389        supports_out=False,
14390        skips=(
14391            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'),
14392            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
14393        ),
14394    ),
14395    # `softmin` supports different dtypes depending on whether the `dtype` argument
14396    # is passed or not. Hence two OpInfo entries, one with dtype and the other without.
14397    # https://github.com/pytorch/pytorch/issues/68752
14398    OpInfo('nn.functional.softmin',
14399           aten_name='softmin',
14400           dtypes=floating_types_and(torch.half, torch.bfloat16),
14401           sample_inputs_func=sample_inputs_softmax_variant,
14402           assert_jit_shape_analysis=False,
14403           assert_autodiffed=False,
14404           supports_forward_ad=True,
14405           supports_fwgrad_bwgrad=True,
14406           supports_out=False),
14407    OpInfo('nn.functional.softmin',
14408           variant_test_name="with_dtype",
14409           aten_name='softmin',
14410           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
14411           sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
14412           assert_autodiffed=False,
14413           supports_forward_ad=True,
14414           supports_fwgrad_bwgrad=True,
14415           supports_out=False),
14416    OpInfo(
14417        "nn.functional.cross_entropy",
14418        dtypes=floating_types_and(torch.float16, torch.bfloat16),
14419        sample_inputs_func=sample_inputs_cross_entropy,
14420        supports_out=False,
14421        supports_forward_ad=True,
14422        supports_fwgrad_bwgrad=True,
14423        decorators=(
14424            DecorateInfo(
14425                toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}),
14426                "TestJit",
14427                "test_variant_consistency_jit",
14428                device_type="cpu",
14429            ),
14430        ),
14431        skips=(
14432            # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536
14433            # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked
14434            # 1536 bytes CUDA memory on device 0
14435            DecorateInfo(
14436                unittest.expectedFailure,
14437                "TestJit",
14438                "test_variant_consistency_jit",
14439                device_type="cuda",
14440            ),
14441            DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"),
14442                         dtypes=(torch.half,), device_type="mps"),
14443
14444        )
14445    ),
14446    OpInfo('nn.functional.normalize',
14447           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
14448           sample_inputs_func=sample_inputs_normalize,
14449           supports_forward_ad=True,
14450           supports_fwgrad_bwgrad=True),
14451    OpInfo('aminmax',
14452           ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
14453           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
14454           decorators=(onlyNativeDeviceTypes,),
14455           supports_autograd=False,
14456           sample_inputs_func=sample_inputs_aminmax,
14457           error_inputs_func=error_inputs_aminmax_amax_amin),
14458    OpInfo('as_strided',
14459           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
14460           supports_out=False,
14461           supports_forward_ad=True,
14462           supports_fwgrad_bwgrad=True,
14463           # vmap does not support inplace views
14464           check_inplace_batched_forward_grad=False,
14465           sample_inputs_func=sample_inputs_as_strided,
14466           skips=(
14467               # Note: This xfail is fine -- it's inherent to how as_strided works
14468               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
14469               # AssertionError: False is not true : Scalars failed to compare as equal!
14470               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
14471                            'TestCommon', 'test_variant_consistency_eager'),
14472               # Not close
14473               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
14474                            'TestCommon', 'test_complex_half_reference_testing'),
14475               # Not close
14476               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
14477               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
14478               DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
14479               DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
14480           )),
14481    OpInfo('as_strided',
14482           variant_test_name='partial_views',
14483           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
14484           supports_out=False,
14485           supports_forward_ad=True,
14486           supports_fwgrad_bwgrad=True,
14487           # vmap does not support inplace views
14488           check_inplace_batched_forward_grad=False,
14489           sample_inputs_func=sample_inputs_as_strided_partial_views,
14490           skips=(
14491               # Note: This xfail is fine -- it's inherent to how as_strided works
14492               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
14493               # RuntimeError: This operator is not Composite Compliant: the
14494               # storage_offset of the tensor was modified directly without
14495               # going through the PyTorch dispatcher.
14496               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
14497               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
14498
14499               # These fail because the test changes the input's in-memory layout
14500               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
14501               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
14502               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
14503               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14504               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
14505                            dtypes=(torch.complex64, torch.complex128)),
14506               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
14507               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
14508               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'),
14509               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'),
14510               DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo',
14511                            'test_make_fx_symbolic_exhaustive_inplace'),
14512               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
14513               # These tests fail but are also flaky, so they are skipped rather than xfailed
14514               DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'),
14515               DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon',
14516                            'test_non_standard_bool_values'),
14517               # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a
14518               # storage size of 28 are out of bounds for storage of size 20
14519               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'),
14520               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'),
14521               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'),
14522               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides'),
14523           )),
14524    OpInfo('as_strided_copy',
14525           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
14526           supports_out=True,
14527           supports_forward_ad=True,
14528           supports_fwgrad_bwgrad=True,
14529           # vmap does not support inplace views
14530           check_inplace_batched_forward_grad=False,
14531           sample_inputs_func=sample_inputs_as_strided,
14532           skips=(
14533               # Note: This xfail is fine -- it's inherent to how as_strided works
14534               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
14535               # AssertionError: False is not true : Scalars failed to compare as equal!
14536               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
14537                            'TestCommon', 'test_variant_consistency_eager'),
14538               # Not close
14539               DecorateInfo(unittest.skip("Errors when storage_offset is included"),
14540                            'TestCommon', 'test_complex_half_reference_testing'),
14541               # Not close
14542               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
14543               DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
14544               DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
14545               DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
14546               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
14547           )),
14548    OpInfo('as_strided_scatter',
14549           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
14550           supports_out=False,
14551           supports_forward_ad=True,
14552           supports_fwgrad_bwgrad=True,
14553           # vmap does not support inplace views
14554           check_inplace_batched_forward_grad=False,
14555           sample_inputs_func=sample_inputs_as_strided_scatter,
14556           error_inputs_func=error_inputs_as_strided_scatter,
14557           skips=(
14558               DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'),  # noqa: B950
14559               DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'),  # noqa: B950
14560               DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'),
14561               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'),
14562               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
14563               DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
14564               # AssertionError: Tensor-likes are not close! (new_empty_strided.default)
14565               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)),
14566    OpInfo('native_layer_norm',
14567           aten_name='native_layer_norm',
14568           ref=reference_native_layer_norm,
14569           dtypes=floating_types_and(torch.half, torch.bfloat16),
14570           supports_out=False,
14571           assert_jit_shape_analysis=True,
14572           supports_fwgrad_bwgrad=True,
14573           sample_inputs_func=sample_inputs_native_layer_norm,
14574           error_inputs_func=error_inputs_native_layer_norm,
14575           skips=(
14576               # IndexError: tuple index out of range
14577               DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'),
14578               # Tests fail when weight=None and bias is defined
14579               # https://github.com/pytorch/pytorch/issues/79705
14580               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
14581               # JIT test also tries to compute double backward, which fails
14582               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14583               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
14584               DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}),
14585                            "TestDecomp", "test_comprehensive", device_type="cpu"),
14586           )),
14587    OpInfo('native_batch_norm',
14588           aten_name='native_batch_norm',
14589           dtypes=floating_types_and(torch.float16, torch.bfloat16),
14590           supports_forward_ad=True,
14591           supports_fwgrad_bwgrad=True,
14592           assert_jit_shape_analysis=True,
14593           allow_cow_input_materialize_forward=[3, 4],
14594           allow_cow_input_materialize_backward=[3, 4],
14595           sample_inputs_func=sample_inputs_native_batch_norm,
14596           skips=(
14597               # NotImplementedError: Could not run
14598               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
14599               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
14600               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
14601               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
14602               # Problem with _get_numerical_jacobian
14603               # IndexError: tuple index out of range
14604               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
14605               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
14606               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14607               # https://github.com/pytorch/pytorch/issues/85960
14608               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
14609               # AssertionError: Booleans mismatch: True is not False
14610               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
14611               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
14612               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
14613                            "TestCompositeCompliance", "test_forward_ad"),
14614           )
14615           ),
14616    OpInfo('_native_batch_norm_legit',
14617           aten_name='_native_batch_norm_legit',
14618           dtypes=floating_types_and(torch.float16, torch.bfloat16),
14619           supports_forward_ad=True,
14620           supports_fwgrad_bwgrad=True,
14621           assert_jit_shape_analysis=True,
14622           allow_cow_input_materialize_forward=[3, 4],
14623           allow_cow_input_materialize_backward=[3, 4],
14624           sample_inputs_func=sample_inputs__native_batch_norm_legit,
14625           skips=(
14626               # NotImplementedError: Could not run
14627               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
14628               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
14629               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
14630               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
14631               # Problem with _get_numerical_jacobian
14632               # IndexError: tuple index out of range
14633               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
14634               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
14635               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14636               # https://github.com/pytorch/pytorch/issues/85960
14637               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
14638               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
14639                            "TestCompositeCompliance", "test_forward_ad"),
14640           )
14641           ),
14642    OpInfo('_batch_norm_with_update',
14643           op=torch.ops.aten._batch_norm_with_update,
14644           aten_name='_batch_norm_with_update',
14645           dtypes=floating_types_and(torch.float16, torch.bfloat16),
14646           supports_forward_ad=True,
14647           supports_fwgrad_bwgrad=True,
14648           assert_jit_shape_analysis=True,
14649           allow_cow_input_materialize_forward=[3, 4],
14650           allow_cow_input_materialize_backward=[3, 4],
14651           sample_inputs_func=sample_inputs__batch_norm_with_update,
14652           skips=(
14653               # NotImplementedError: Could not run
14654               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
14655               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
14656               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
14657               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
14658               # Problem with _get_numerical_jacobian
14659               # IndexError: tuple index out of range
14660               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
14661               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
14662               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14663               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}),
14664                            "TestCompositeCompliance", "test_forward_ad"),
14665               # _batch_norm_with_update expects contiguous inputs for cudnn and miopen
14666               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"),
14667               DecorateInfo(unittest.expectedFailure,
14668                            'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"),
14669               # _batch_norm_with_update does not have python bindings
14670               DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
14671               # aten out variants do not accept out= kwarg, only python out variants
14672               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
14673               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
14674           )
14675           ),
14676    OpInfo('nn.functional.cosine_similarity',
14677           aten_name="cosine_similarity",
14678           dtypes=floating_types_and(torch.half, torch.bfloat16),
14679           supports_out=False,
14680           supports_forward_ad=True,
14681           supports_fwgrad_bwgrad=True,
14682           decorators=[
14683               DecorateInfo(
14684                   toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}),
14685                   "TestInductorOpInfo",
14686                   "test_comprehensive",
14687                   device_type="cuda"
14688               ),
14689           ],
14690           sample_inputs_func=sample_inputs_cosine_similarity),
14691    OpInfo('nn.functional.adaptive_avg_pool1d',
14692           dtypes=floating_types_and(torch.half, torch.bfloat16),
14693           supports_out=False,
14694           supports_forward_ad=True,
14695           supports_fwgrad_bwgrad=True,
14696           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14697           error_inputs_func=error_inputs_adaptive_avg_pool1d,
14698           sample_inputs_func=sample_inputs_adaptive_avg_pool1d),
14699    OpInfo('nn.functional.adaptive_avg_pool2d',
14700           dtypes=floating_types_and(torch.half, torch.bfloat16),
14701           decorators=(
14702               # RuntimeError:
14703               # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor):
14704               # Expected a value of type 'List[int]' for argument 'output_size' but
14705               # instead found type 'Tuple[NoneType, int]'. :
14706               #   File "<string>", line 3
14707               # def the_method(i0):
14708               #     return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7))
14709               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
14710               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14711           ),
14712           supports_out=False,
14713           supports_forward_ad=True,
14714           supports_fwgrad_bwgrad=True,
14715           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14716           error_inputs_func=error_inputs_adaptive_avg_pool2d,
14717           sample_inputs_func=sample_inputs_adaptive_avg_pool2d),
14718    OpInfo('nn.functional.adaptive_avg_pool3d',
14719           dtypes=floating_types_and(torch.half, torch.bfloat16),
14720           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
14721           decorators=(
14722               # RuntimeError:
14723               # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor):
14724               # Expected a value of type 'List[int]' for argument 'output_size' but
14725               # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
14726               #   File "<string>", line 3
14727               #
14728               # def the_method(i0):
14729               #     return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None))
14730               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
14731               #
14732               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14733           ),
14734           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
14735           gradcheck_fast_mode=True,
14736           supports_out=False,
14737           supports_forward_ad=True,
14738           supports_fwgrad_bwgrad=True,
14739           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14740           error_inputs_func=error_inputs_adaptive_avg_pool3d,
14741           sample_inputs_func=sample_inputs_adaptive_avg_pool3d),
14742    OpInfo('nn.functional.adaptive_max_pool1d',
14743           dtypes=floating_types_and(torch.half, torch.bfloat16),
14744           supports_out=False,
14745           supports_forward_ad=True,
14746           supports_fwgrad_bwgrad=True,
14747           # got: Batching rule not implemented for aten::flatten.using_ints
14748           check_batched_forward_grad=False,
14749           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14750           error_inputs_func=error_inputs_adaptive_max_pool1d,
14751           sample_inputs_func=sample_inputs_adaptive_max_pool1d),
14752    OpInfo('nn.functional.adaptive_max_pool2d',
14753           dtypes=floating_types_and(torch.half, torch.bfloat16),
14754           decorators=(
14755               # RuntimeError:
14756               # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor):
14757               # Expected a value of type 'List[int]' for argument 'output_size' but
14758               # instead found type 'Tuple[NoneType, int]'. :
14759               #   File "<string>", line 3
14760               # def the_method(i0):
14761               #     return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7))
14762               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
14763               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14764           ),
14765           supports_out=False,
14766           supports_forward_ad=True,
14767           supports_fwgrad_bwgrad=True,
14768           # got: Batching rule not implemented for aten::flatten.using_ints
14769           check_batched_forward_grad=False,
14770           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14771           error_inputs_func=error_inputs_adaptive_max_pool2d,
14772           sample_inputs_func=sample_inputs_adaptive_max_pool2d),
14773    OpInfo('nn.functional.adaptive_max_pool3d',
14774           dtypes=floating_types_and(torch.bfloat16),
14775           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
14776           decorators=(
14777               # RuntimeError:
14778               # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor):
14779               # Expected a value of type 'List[int]' for argument 'output_size' but
14780               # instead found type 'Tuple[NoneType, NoneType, NoneType]'. :
14781               #   File "<string>", line 3
14782               #
14783               # def the_method(i0):
14784               #     return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None))
14785               #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
14786               #
14787               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
14788           ),
14789           supports_out=False,
14790           supports_forward_ad=True,
14791           supports_fwgrad_bwgrad=True,
14792           # got: Batching rule not implemented for aten::flatten.using_ints
14793           check_batched_forward_grad=False,
14794           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14795           error_inputs_func=error_inputs_adaptive_max_pool3d,
14796           sample_inputs_func=sample_inputs_adaptive_max_pool3d),
14797    OpInfo('nn.functional.avg_pool1d',
14798           aten_name='avg_pool1d',
14799           supports_autograd=True,
14800           supports_out=False,
14801           supports_forward_ad=True,
14802           supports_fwgrad_bwgrad=True,
14803           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
14804           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
14805           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14806           error_inputs_func=error_inputs_avg_pool1d,
14807           sample_inputs_func=sample_inputs_avgpool1d),
14808    OpInfo('nn.functional.avg_pool3d',
14809           aten_name='avg_pool3d',
14810           supports_autograd=True,
14811           supports_forward_ad=True,
14812           supports_fwgrad_bwgrad=True,
14813           dtypes=floating_types_and(torch.int64),
14814           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
14815           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14816           error_inputs_func=error_inputs_avg_pool3d,
14817           sample_inputs_func=sample_inputs_avgpool3d,
14818           skips=(
14819               # AssertionError: Tensor-likes are not close!
14820               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
14821           )),
14822    OpInfo(
14823        "nn.functional.binary_cross_entropy_with_logits",
14824        aten_name="binary_cross_entropy_with_logits",
14825        supports_autograd=True,
14826        supports_forward_ad=True,
14827        supports_fwgrad_bwgrad=True,
14828        supports_out=False,
14829        dtypes=floating_types_and(torch.half, torch.bfloat16),
14830        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14831        sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits,
14832        skips=(
14833            DecorateInfo(
14834                unittest.skip("Skipped!"),
14835                'TestJit',
14836                'test_variant_consistency_jit',
14837                dtypes=(torch.float32,)
14838            ),
14839        ),
14840    ),
14841    UnaryUfuncInfo(
14842        'nn.functional.relu',
14843        aten_name="relu",
14844        ref=lambda a: np.where(a <= 0, 0, a),
14845        supports_autograd=True,
14846        supports_sparse=True,
14847        supports_sparse_csr=True,
14848        supports_sparse_csc=True,
14849        supports_sparse_bsr=True,
14850        supports_sparse_bsc=True,
14851        dtypes=all_types_and(torch.half, torch.bfloat16),
14852        sample_inputs_func=sample_inputs_nn_activation_relu,
14853        supports_out=False,
14854        supports_fwgrad_bwgrad=True,
14855        supports_forward_ad=True),
14856    OpInfo('nn.functional.conv_transpose1d',
14857           # `ref` for this function is the backward of the
14858           # corresponding `conv*d`
14859           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d),
14860           aten_name='conv_transpose1d',
14861           aliases=('conv_transpose1d',),
14862           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
14863           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
14864                                                       torch.bfloat16),
14865           sample_inputs_func=sample_inputs_conv_transpose1d,
14866           supports_forward_ad=True,
14867           supports_fwgrad_bwgrad=True,
14868           assert_jit_shape_analysis=True,
14869           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14870           decorators=(
14871               DecorateInfo(
14872                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
14873                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
14874               DecorateInfo(
14875                   toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }),
14876                   'TestCommon', 'test_complex_half_reference_testing'),
14877               DecorateInfo(
14878                   toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
14879                   'TestCommon', 'test_numpy_ref_mps'),
14880               DecorateInfo(
14881                   toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }),
14882                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
14883           ),
14884           skips=(
14885               # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486
14886               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
14887                            dtypes=(torch.complex64,)),
14888               # RuntimeError: UNSUPPORTED DTYPE: complex
14889               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
14890                            dtypes=(torch.complex64, torch.complex128)),
14891               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
14892               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
14893               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
14894                            dtypes=(torch.float,)),
14895               # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
14896               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
14897                            dtypes=(torch.int64,)),
14898           ),
14899           supports_out=False,),
14900    OpInfo('nn.functional.conv_transpose2d',
14901           aten_name='conv_transpose2d',
14902           aliases=('conv_transpose2d',),
14903           # `ref` for this function is the backward of the
14904           # corresponding `conv*d`
14905           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d),
14906           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
14907           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
14908                                                       torch.bfloat16),
14909           sample_inputs_func=sample_inputs_conv_transpose2d,
14910           # Runs very slowly on slow-gradcheck for complex.
14911           gradcheck_fast_mode=True,
14912           supports_forward_ad=True,
14913           supports_fwgrad_bwgrad=True,
14914           assert_jit_shape_analysis=True,
14915           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14916           decorators=[
14917               DecorateInfo(
14918                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }),
14919                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
14920               DecorateInfo(
14921                   toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }),
14922                   'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
14923               DecorateInfo(
14924                   toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }),
14925                   'TestCommon', 'test_complex_half_reference_testing'),
14926               DecorateInfo(
14927                   toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }),
14928                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
14929           skips=(
14930               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
14931               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
14932               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
14933               # RuntimeError: UNSUPPORTED DTYPE: complex
14934               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
14935                            dtypes=(torch.complex64, torch.complex128)),
14936               # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long'
14937               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
14938                            dtypes=(torch.int64,)),
14939               # Reference: https://github.com/pytorch/pytorch/issues/86356
14940               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
14941                            dtypes=(torch.double, torch.cdouble)),
14942               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
14943               # AssertionError: None mismatch: torch.complex64 is not None
14944               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules',
14945                            dtypes=(torch.complex64, torch.complex128)),
14946           ),
14947           supports_out=False,),
14948    OpInfo('nn.functional.conv_transpose3d',
14949           aten_name='conv_transpose3d',
14950           aliases=('conv_transpose3d',),
14951           # `ref` for this function is the backward of the
14952           # corresponding `conv*d`
14953           ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d),
14954           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
14955           dtypesIfCUDA=floating_and_complex_types_and(
14956               torch.float16, torch.chalf, torch.bfloat16),
14957           sample_inputs_func=sample_inputs_conv_transpose3d,
14958           supports_forward_ad=True,
14959           supports_fwgrad_bwgrad=True,
14960           assert_jit_shape_analysis=True,
14961           # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
14962           gradcheck_fast_mode=True,
14963           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
14964           decorators=[
14965               DecorateInfo(
14966                   toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }),
14967                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'),
14968               DecorateInfo(
14969                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06),
14970                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
14971                   'TestCommon', 'test_variant_consistency_eager', device_type='cuda'),
14972               DecorateInfo(
14973                   toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }),
14974                   'TestCompositeCompliance', 'test_operator', device_type='cuda'),
14975               DecorateInfo(
14976                   toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06),
14977                                     torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}),
14978                   'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
14979               DecorateInfo(
14980                   toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }),
14981                   'TestCompositeCompliance', 'test_forward_ad', device_type='cuda',
14982                   active_if=TEST_CUDNN),
14983               DecorateInfo(
14984                   toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}),
14985                   "TestMathBits", "test_conj_view", device_type='cuda'),
14986               DecorateInfo(
14987                   toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
14988                   'TestCommon', 'test_complex_half_reference_testing'),
14989               DecorateInfo(
14990                   toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }),
14991                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
14992           skips=(
14993               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
14994               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch.
14995               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
14996               # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long'
14997               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
14998                            dtypes=(torch.int64,)),
14999               # Reference: https://github.com/pytorch/pytorch/issues/86356
15000               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref',
15001                            dtypes=(torch.double, torch.cdouble)),
15002               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
15003               # RuntimeError: UNSUPPORTED DTYPE: complex
15004               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
15005                            dtypes=(torch.complex64, torch.complex128)),
15006               DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing',
15007                            dtypes=[torch.complex32], active_if=TEST_WITH_ROCM),
15008           ),
15009           supports_out=False,),
15010    OpInfo('nn.functional.conv1d',
15011           aliases=('conv1d',),
15012           aten_name='conv1d',
15013           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
15014           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
15015                                                       torch.bfloat16),
15016           sample_inputs_func=sample_inputs_conv1d,
15017           error_inputs_func=error_inputs_conv1d,
15018           supports_forward_ad=True,
15019           supports_fwgrad_bwgrad=True,
15020           assert_jit_shape_analysis=True,
15021           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15022           decorators=(
15023               DecorateInfo(
15024                   toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}),
15025                   'TestCommon', 'test_complex_half_reference_testing'
15026               ),
15027               DecorateInfo(
15028                   toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
15029                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
15030               ),
15031           ),
15032           skips=(
15033               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
15034               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
15035               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
15036               # Ref: https://github.com/pytorch/pytorch/issues/75309
15037               # AssertionError: None mismatch: torch.complex128 is not None
15038               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
15039                            'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
15040               # Ref: https://github.com/pytorch/pytorch/issues/75309
15041               # RuntimeError: UNSUPPORTED DTYPE: complex
15042               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
15043                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
15044           ),
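                # This op is also exercised by the expanded-weights (per-sample gradient) tests.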
15045           supports_expanded_weight=True,
15046           supports_out=False,),
15047    OpInfo('nn.functional.conv2d',
15048           aliases=('conv2d',),
15049           aten_name='conv2d',
15050           dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16),
15051           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf,
15052                                                       torch.bfloat16),
15053           sample_inputs_func=sample_inputs_conv2d,
15054           error_inputs_func=error_inputs_conv2d,
15055           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15056           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15057           gradcheck_fast_mode=True,
15058           supports_forward_ad=True,
15059           supports_fwgrad_bwgrad=True,
15060           assert_jit_shape_analysis=True,
15061           decorators=(
15062               DecorateInfo(
15063                   toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}),
15064                   'TestCommon', 'test_complex_half_reference_testing',
15065               ),
15066           ),
15067           skips=(
15068               # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
15069               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
15070               DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'),
15071               # Ref: https://github.com/pytorch/pytorch/issues/75309
15072               # AssertionError: None mismatch: torch.complex128 is not None
15073               DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules',
15074                            'test_custom_rules', dtypes=(torch.complex64, torch.complex128)),
15075               # RuntimeError: UNSUPPORTED DTYPE: complex
15076               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
15077                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
15078           ),
15079           supports_expanded_weight=True,
15080           supports_out=False,),
15081    OpInfo('nn.functional.conv3d',
15082           aliases=('conv3d',),
15083           aten_name='conv3d',
15084           dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16),
15085           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16),
15086           sample_inputs_func=sample_inputs_conv3d,
15087           error_inputs_func=error_inputs_conv3d,
15088           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15089           gradcheck_fast_mode=True,
15090           supports_forward_ad=True,
15091           supports_fwgrad_bwgrad=True,
15092           decorators=(
15093               DecorateInfo(
15094                   toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}),
15095                   'TestCommon', 'test_complex_half_reference_testing',
15096               ),
15097               # TF32 convolutions on CUDA are less precise, so loosen float32/complex64 tolerances
15098               DecorateInfo(
15099                   toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3),
15100                                     torch.complex64: tol(atol=5e-3, rtol=1e-3)}),
15101                   'TestCommon', 'test_noncontiguous_samples',
15102               ),
15103               DecorateInfo(
15104                   toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}),
15105                   'TestMathBits', 'test_conj_view',
15106               ),
15107               DecorateInfo(
15108                   toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}),
15109                   'TestOperators', 'test_vjpvmap',
15110               ),
15111           ),
15112           skips=(
15113               # RuntimeError: !lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at
15114               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch.
15115               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
15116               # RuntimeError: UNSUPPORTED DTYPE: complex
15117               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo',
15118                            'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)),
15119               # AssertionError: Tensor-likes are not close!
15120               # breaks slow test runs
15121               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
15122           ),
15123           supports_expanded_weight=True,
15124           supports_out=False,),
15125    OpInfo('nn.functional.group_norm',
15126           aten_name='group_norm',
15127           aliases=('group_norm',),
15128           ref=reference_group_norm,
15129           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15130           supports_out=False,
15131           supports_forward_ad=True,
15132           supports_fwgrad_bwgrad=True,
15133           error_inputs_func=error_inputs_group_norm,
15134           decorators=[
15135               # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
15136               # Consider making it a parameter or input, or detaching the gradient
15137               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15138               DecorateInfo(
15139                   toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}),
15140                   "TestDecomp",
15141                   "test_comprehensive",
15142                   device_type="cpu"
15143               ),
15144           ],
15145           sample_inputs_func=sample_inputs_group_norm,
15146           reference_inputs_func=reference_inputs_group_norm,
15147           supports_expanded_weight=True,),
15148    OpInfo('nn.functional.instance_norm',
15149           # no ref because instance_norm is often numerically unstable (large values or nan)
15150           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15151           supports_out=False,
15152           supports_forward_ad=True,
15153           supports_fwgrad_bwgrad=True,
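                # instance_norm updates running_mean/running_var in place, so allow these
                # copy-on-write inputs to be materialized during the forward pass.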
15154           allow_cow_input_materialize_forward=['running_mean', 'running_var'],
15155           decorators=[
15156               # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
15157               # Consider making it a parameter or input, or detaching the gradient
15158               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15159           ],
15160           sample_inputs_func=sample_inputs_instance_norm,
15161           supports_expanded_weight=True,),
15162    OpInfo('nn.functional.layer_norm',
15163           aten_name='layer_norm',
15164           aten_backward_name='layer_norm_backward',
15165           aliases=('layer_norm',),
15166           ref=reference_layer_norm,
15167           dtypes=floating_types_and(torch.half, torch.bfloat16),
15168           supports_out=False,
15169           supports_forward_ad=True,
15170           supports_fwgrad_bwgrad=True,
15171           assert_jit_shape_analysis=True,
15172           decorators=[
15173               DecorateInfo(
15174                   toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
15175                   'TestCommon', 'test_numpy_refs'
15176               ),
15177               DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'),
15178           ],
15179           sample_inputs_func=sample_inputs_layer_norm,
15180           supports_expanded_weight=True,),
15181    OpInfo('nn.functional.rms_norm',
15182           aten_name='rms_norm',
15183           aliases=('rms_norm',),
15184           ref=reference_rms_norm,
15185           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
15186           supports_out=False,
15187           supports_forward_ad=True,
15188           supports_fwgrad_bwgrad=True,
15189           sample_inputs_func=sample_inputs_rms_norm,
15190           error_inputs_func=error_inputs_rms_norm,),
15191    OpInfo('nn.functional.local_response_norm',
15192           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
15193           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15194           supports_out=False,
15195           supports_forward_ad=True,
15196           supports_fwgrad_bwgrad=True,
15197           decorators=[
15198               # RuntimeError: falseINTERNAL ASSERT FAILED at
15199               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
15200               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15201           ],
15202           sample_inputs_func=sample_inputs_local_response_norm,),
15203    OpInfo('constant_pad_nd',
15204           supports_forward_ad=True,
15205           supports_fwgrad_bwgrad=True,
15206           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
15207           sample_inputs_func=sample_inputs_constant_pad_nd,
15208           supports_out=False,
15209           skips=(
15210               # bool can't be passed to Scalar arguments in JIT tracer because
15211               # BoolType is not a subtype of ScalarType.
15212               DecorateInfo(
15213                   unittest.expectedFailure, 'TestNNCOpInfo',
15214                   'test_nnc_correctness', dtypes=(torch.bool,)),
15215           )),
15216    OpInfo('nn.functional.pad',
15217           variant_test_name='constant',
15218           aten_name='constant_pad_nd',
15219           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15220           gradcheck_fast_mode=True,
15221           supports_forward_ad=True,
15222           supports_fwgrad_bwgrad=True,
15223           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
15224           sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'),
15225           supports_out=False),
15226    OpInfo('nn.functional.pad',
15227           variant_test_name='reflect',
15228           supports_forward_ad=True,
15229           supports_fwgrad_bwgrad=True,
15230           dtypes=all_types_and_complex_and(torch.bfloat16),
15231           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
15232           sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'),
15233           skips=(
15234               # Doesn't have a corresponding aten operator.
15235               # RuntimeError: falseINTERNAL ASSERT FAILED at
15236               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
15237               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15238           ),
15239           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15240           supports_out=False),
15241    OpInfo('nn.functional.pad',
15242           variant_test_name='replicate',
15243           supports_forward_ad=True,
15244           supports_fwgrad_bwgrad=True,
15245           dtypes=all_types_and_complex_and(torch.bfloat16),
15246           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
15247           sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'),
15248           skips=(
15249               # Doesn't have a corresponding aten operator.
15250               # RuntimeError: falseINTERNAL ASSERT FAILED at
15251               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
15252               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15253           ),
15254           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15255           supports_out=False),
15256    OpInfo('nn.functional.pad',
15257           variant_test_name='replicate_negative',
15258           supports_forward_ad=True,
15259           supports_fwgrad_bwgrad=True,
15260           dtypes=all_types_and_complex_and(torch.bfloat16),
15261           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
15262           sample_inputs_func=sample_inputs_nn_pad_replicate_negative,
15263           skips=(
15264               # Doesn't have a corresponding aten operator.
15265               # RuntimeError: falseINTERNAL ASSERT FAILED at
15266               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
15267               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15268               # Some negative padding cases cause a segfault on MPS
15269               DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'),
15270           ),
15271           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15272           supports_out=False),
15273    OpInfo('nn.functional.pad',
15274           variant_test_name='circular',
15275           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
15276           sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'),
15277           supports_forward_ad=True,
15278           supports_fwgrad_bwgrad=True,
15279           check_batched_grad=False,
15280           # https://github.com/pytorch/pytorch/issues/66357
15281           check_batched_forward_grad=False,
15282           skips=(
15283               # Doesn't have a corresponding aten operator.
15284               # RuntimeError: falseINTERNAL ASSERT FAILED at
15285               # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch.
15286               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
15287               # Difference from <type> is larger with decomposition new_empty_strided.default than original on output 0
15288               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),
15289           ),
15290           supports_out=False),
15291    OpInfo('nn.functional.hardswish',
15292           aten_name="hardswish",
15293           aten_backward_name='hardswish_backward',
15294           supports_autograd=True,
15295           assert_autodiffed=True,
15296           sample_inputs_func=sample_inputs_hardswish,
15297           dtypes=floating_types_and(torch.bfloat16, torch.half),
15298           supports_gradgrad=True,
15299           supports_forward_ad=True,
15300           supports_fwgrad_bwgrad=True,
15301           supports_out=False,
15302           autodiff_nonfusible_nodes=["aten::hardswish"]),
15303    OpInfo('nn.functional.unfold',
15304           aten_name='im2col',
15305           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
15306           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool),
15307           sample_inputs_func=sample_inputs_nn_unfold,
15308           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15309           gradcheck_fast_mode=True,
15310           supports_forward_ad=True,
15311           supports_fwgrad_bwgrad=True,
15312           supports_out=False,
15313           skips=(
15314               # NOTE: this failure may not reproduce consistently on different systems
15315               # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185
15316               DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'),
15317           )),
15318    OpInfo('nn.functional.interpolate',
15319           aten_name="interpolate",
15320           variant_test_name='nearest',
15321           supports_autograd=True,
15322           supports_fwgrad_bwgrad=True,
15323           supports_forward_ad=True,
15324           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
15325           sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'),
15326           skips=(
15327               # RuntimeError: false
15328               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15329               # please report a bug to PyTorch.
15330               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15331           ),
15332           supports_out=False),
15333    OpInfo('nn.functional.interpolate',
15334           aten_name="interpolate",
15335           variant_test_name='nearest-exact',
15336           supports_autograd=True,
15337           supports_fwgrad_bwgrad=True,
15338           supports_forward_ad=True,
15339           dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8),
15340           sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'),
15341           skips=(
15342               # RuntimeError: false
15343               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15344               # please report a bug to PyTorch.
15345               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15346               # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled
15347               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
15348               DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
15349               DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'),
15350               # NotImplementedError: The operator 'aten::_upsample_nearest_exact3d.out' is not currently implemented
15351               # for the MPS device.
15352               DecorateInfo(unittest.expectedFailure, 'TestConsistency'),
15353           ),
15354           supports_out=False),
15355    OpInfo('nn.functional.interpolate',
15356           aten_name="interpolate",
15357           variant_test_name='linear',
15358           supports_autograd=True,
15359           supports_fwgrad_bwgrad=True,
15360           supports_forward_ad=True,
15361           dtypes=floating_types_and(torch.half, torch.bfloat16),
15362           sample_inputs_func=partial(sample_inputs_interpolate, 'linear'),
15363           skips=(
15364               # RuntimeError: false
15365               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15366               # please report a bug to PyTorch.
15367               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15368           ),
15369           supports_out=False),
15370    OpInfo('nn.functional.interpolate',
15371           aten_name="interpolate",
15372           variant_test_name='bilinear',
15373           supports_fwgrad_bwgrad=True,
15374           supports_autograd=True,
15375           supports_forward_ad=True,
15376           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
15377           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
15378           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15379           sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'),
15380           reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'),
15381           skips=(
15382               # RuntimeError: false
15383               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15384               # please report a bug to PyTorch.
15385               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15386           ),
15387           supports_out=False),
15388    OpInfo('nn.functional.interpolate',
15389           aten_name="interpolate",
15390           variant_test_name='bicubic',
15391           supports_autograd=True,
15392           supports_forward_ad=True,
15393           supports_fwgrad_bwgrad=True,
15394           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
15395           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
15396           sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'),
15397           reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'),
15398           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15399           skips=(
15400               # RuntimeError: false
15401               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15402               # please report a bug to PyTorch.
15403               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15404           ),
15405           supports_out=False),
15406    OpInfo('nn.functional.interpolate',
15407           aten_name="interpolate",
15408           variant_test_name='trilinear',
15409           supports_autograd=True,
15410           supports_forward_ad=True,
15411           supports_fwgrad_bwgrad=True,
15412           dtypes=floating_types_and(torch.half, torch.bfloat16),
15413           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15414           sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'),
15415           skips=(
15416               # RuntimeError: false
15417               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15418               # please report a bug to PyTorch.
15419               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15420           ),
15421           supports_out=False),
15422    OpInfo('nn.functional.interpolate',
15423           aten_name="interpolate",
15424           variant_test_name='area',
15425           supports_autograd=True,
15426           supports_forward_ad=True,
15427           supports_fwgrad_bwgrad=True,
15428           dtypes=floating_types_and(torch.half, torch.bfloat16),
15429           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
15430           sample_inputs_func=partial(sample_inputs_interpolate, 'area'),
15431           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15432           skips=(
15433               # RuntimeError: false
15434               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15435               # please report a bug to PyTorch.
15436               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15437           ),
15438           supports_out=False),
15439    OpInfo('nn.functional.upsample_bilinear',
15440           supports_autograd=True,
15441           supports_forward_ad=True,
15442           supports_fwgrad_bwgrad=True,
15443           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
15444           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
15445           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15446           sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'),
15447           reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'),
15448           skips=(
15449               # RuntimeError: false
15450               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15451               # please report a bug to PyTorch.
15452               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15453           ),
15454           supports_out=False),
15455    OpInfo('_upsample_bilinear2d_aa',
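                # Exercise the anti-aliased bilinear upsample aten overload directly.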
15456           op=torch.ops.aten._upsample_bilinear2d_aa,
15457           aten_name='_upsample_bilinear2d_aa',
15458           supports_autograd=True,
15459           supports_forward_ad=True,
15460           supports_fwgrad_bwgrad=True,
15461           dtypes=floating_types_and(torch.uint8),
15462           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
15463           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15464           sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'),
15465           supports_out=False,
15466           skips=(
15467               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15468               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
15469               DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
15470               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
15471               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
15472           )),
15473    OpInfo(
15474        "nn.functional.soft_margin_loss",
15475        dtypes=floating_types_and(torch.half, torch.bfloat16),
15476        supports_out=False,
15477        supports_forward_ad=True,
15478        # doesn't support grad on target
15479        sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False),
15480        error_inputs_func=error_inputs_soft_margin_loss,
15481    ),
15482    OpInfo('nn.functional.upsample_nearest',
15483           supports_autograd=True,
15484           supports_forward_ad=True,
15485           supports_fwgrad_bwgrad=True,
15486           dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16),
15487           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15488           sample_inputs_func=partial(sample_inputs_upsample, 'nearest'),
15489           skips=(
15490               # RuntimeError: false
15491               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185,
15492               # please report a bug to PyTorch.
15493               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15494           ),
15495           supports_out=False),
15496    OpInfo(
15497        "nn.functional.margin_ranking_loss",
15498        dtypes=all_types_and(torch.half, torch.bfloat16),
15499        supports_out=False,
15500        sample_inputs_func=sample_inputs_margin_ranking_loss,
15501        error_inputs_func=error_inputs_margin_ranking_loss,
15502        reference_inputs_func=reference_inputs_margin_ranking_loss,
15503        supports_forward_ad=True,
15504        supports_fwgrad_bwgrad=True),
15505    OpInfo(
15506        "nn.functional.multi_margin_loss",
15507        dtypes=floating_types(),
15508        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
15509        supports_out=False,
15510        supports_gradgrad=False,
15511        sample_inputs_func=sample_inputs_multi_margin_loss,
15512        reference_inputs_func=reference_inputs_multi_margin_loss,
15513        error_inputs_func=error_inputs_multi_margin_loss,
15514        decorators=(
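                 # Relax float32 tolerances for the JIT variant-consistency test.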
15515            DecorateInfo(
15516                toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
15517                "TestJit",
15518                "test_variant_consistency_jit",
15519            ),
15520        ),
15521    ),
15522    OpInfo(
15523        "nn.functional.multilabel_margin_loss",
15524        dtypes=floating_types(),
15525        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
15526        supports_out=False,
15527        supports_gradgrad=False,
15528        sample_inputs_func=sample_inputs_multilabel_margin_loss,
15529        reference_inputs_func=reference_inputs_multilabel_margin_loss,
15530        error_inputs_func=error_inputs_multilabel_margin_loss,
15531    ),
15532    OpInfo('nn.functional.leaky_relu',
15533           aliases=None,
15534           aten_name="leaky_relu",
15535           aten_backward_name='leaky_relu_backward',
15536           sample_inputs_func=sample_inputs_leaky_relu,
15537           dtypes=floating_types_and(torch.bfloat16, torch.float16),
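                # Exercise the in-place path by calling the functional form with inplace=True.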
15538           inplace_variant=lambda x, negative_slope=0.01:
15539               torch.nn.functional.leaky_relu(x, negative_slope, inplace=True),
15540           supports_autograd=True,
15541           assert_autodiffed=True,
15542           supports_gradgrad=True,
15543           supports_out=False,
15544           supports_forward_ad=True,
15545           supports_fwgrad_bwgrad=True,
15546           autodiff_nonfusible_nodes=["aten::leaky_relu"]),
15547    OpInfo(
15548        "nn.functional.multilabel_soft_margin_loss",
15549        supports_out=False,
15550        dtypes=floating_types_and(torch.half, torch.bfloat16),
15551        sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
15552        supports_forward_ad=True,
15553        supports_fwgrad_bwgrad=True,
15554        decorators=(
15555            DecorateInfo(
15556                toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
15557                "TestJit",
15558                "test_variant_consistency_jit",
15559            ),
15560            DecorateInfo(
15561                toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}),
15562                "TestInductorOpInfo",
15563                "test_comprehensive",
15564                device_type="cuda"
15565            ),
15566        ),
15567        skips=(
15568            # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096
15569            # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32
15570            # leaked 4096 bytes CUDA memory on device 0
15571            DecorateInfo(
15572                # Skip instead of expectedFailure because this fails
15573                # locally for me but passes in CI.
15574                unittest.skip("Skipped!"),
15575                "TestJit",
15576                "test_variant_consistency_jit",
15577                device_type="cuda",
15578            ),
15579        ),
15580    ),
15581    OpInfo('nn.functional.avg_pool2d',
15582           aten_name='avg_pool2d',
15583           supports_autograd=True,
15584           supports_forward_ad=True,
15585           supports_fwgrad_bwgrad=True,
15586           dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16),
15587           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15588           error_inputs_func=error_inputs_avg_pool2d,
15589           sample_inputs_func=sample_inputs_avgpool2d,
15590           skips=(
15591               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
15592           )),
15593    OpInfo('nn.functional.fractional_max_pool2d',
15594           supports_autograd=True,
15595           supports_out=False,
15596           supports_forward_ad=True,
15597           supports_fwgrad_bwgrad=True,
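                # fractional_max_pool2d samples random pooling regions, so wrap it with
                # wrapper_set_seed to keep the outputs reproducible under test.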
15598           op=lambda input, *args, **kwargs:
15599               wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs),
15600           # vmap does not support random operations
15601           check_batched_forward_grad=False,
15602           dtypes=floating_types_and(torch.bfloat16, torch.float16),
15603           test_neg_view=False,
15604           sample_inputs_func=sample_inputs_fractional_max_pool2d,
15605           decorators=(
15606               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
15607               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
15608               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
15609               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
15610               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
15611           skips=(
15612               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
15613    OpInfo('nn.functional.fractional_max_pool3d',
15614           supports_autograd=True,
15615           supports_out=False,
15616           supports_forward_ad=True,
15617           supports_fwgrad_bwgrad=True,
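                # As with the 2d variant, reseed the RNG so the random pooling regions
                # are reproducible.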
15618           op=lambda input, *args, **kwargs:
15619               wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs),
15620           # vmap does not support random operations
15621           check_batched_forward_grad=False,
15622           dtypes=floating_types_and(torch.bfloat16, torch.float16),
15623           test_neg_view=False,
15624           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15625           sample_inputs_func=sample_inputs_fractional_max_pool3d,
15626           decorators=(
15627               # FIXME: both derivatives are implemented incorrectly
15628               # https://github.com/pytorch/pytorch/issues/69322
15629               # FIXME: AssertionError: False is not true : Tensors failed to compare as equal!
15630               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
15631               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
15632               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270
15633               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')),
15634           skips=(
15635               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)),
15636    OpInfo('nn.functional.max_pool1d',
15637           aten_name='max_pool1d',
15638           supports_autograd=True,
15639           supports_out=False,
15640           supports_forward_ad=True,
15641           supports_fwgrad_bwgrad=True,
15642           # got: Batching rule not implemented for aten::flatten.using_ints
15643           check_batched_forward_grad=False,
15644           # TODO: add shape checks
15645           assert_jit_shape_analysis=False,
15646           dtypes=floating_types_and(torch.bfloat16, torch.float16),
15647           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15648           skips=(
15649               # Pre-existing condition; needs to be fixed
15650               DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo',
15651                            'test_nnc_correctness', dtypes=(torch.bfloat16,)),
15652               # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet.
15653               # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data()
15654               # to actually allocate memory
15655               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
15656           ),
15657           error_inputs_func=error_inputs_max_pool1d,
15658           sample_inputs_func=sample_inputs_max_pool),
15659    OpInfo('nn.functional.max_pool2d',
15660           aten_name='max_pool2d',
15661           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15662           gradcheck_fast_mode=True,
15663           # vmap does not handle non-contiguous (channels_last) inputs well
15664           check_batched_gradgrad=False,
15665           supports_out=False,
15666           supports_forward_ad=True,
15667           supports_fwgrad_bwgrad=True,
15668           # got: Batching rule not implemented for aten::flatten.using_ints
15669           check_batched_forward_grad=False,
15670           assert_jit_shape_analysis=True,
15671           dtypes=all_types_and(torch.float16, torch.bfloat16),
15672           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15673           error_inputs_func=error_inputs_max_pool2d,
15674           sample_inputs_func=sample_inputs_max_pool),
15675    OpInfo('max_pool2d_with_indices_backward',
15676           op=max_pool2d_backward,
15677           # We've defined a custom op, so there's no corresponding aten op
15678           aten_name=None,
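                # There are no method, operator, or in-place forms of this helper, so all
                # variant hooks are disabled.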
15679           method_variant=None,
15680           inplace_variant=None,
15681           operator_variant=None,
15682           inplace_operator_variant=None,
15683           check_batched_gradgrad=False,
15684           supports_out=False,
15685           supports_forward_ad=True,
15686           supports_fwgrad_bwgrad=True,
15687           check_batched_forward_grad=False,
15688           assert_jit_shape_analysis=False,
15689           dtypes=floating_types_and(torch.bfloat16, torch.float16),
15690           sample_inputs_func=sample_inputs_max_pool,
15691           skips=(
15692               # We've defined a custom op here, and we don't handle the case where we receive an out kwarg
15693               DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"),
15694               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
15695               # FX failed to normalize op - add the op to the op_skip list.
15696               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
15697               # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected)
15698               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')
15699           )),
15700    OpInfo('nn.functional.max_pool3d',
15701           aten_name='max_pool3d',
15702           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15703           gradcheck_fast_mode=True,
15704           supports_out=False,
15705           supports_forward_ad=True,
15706           supports_fwgrad_bwgrad=True,
15707           # got: Batching rule not implemented for aten::flatten.using_ints
15708           check_batched_forward_grad=False,
15709           # TODO: add shape checks
15710           assert_jit_shape_analysis=False,
15711           dtypes=all_types_and(torch.bfloat16, torch.float16),
15712           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15713           # TODO: investigate nondeterminism
15714           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15715           error_inputs_func=error_inputs_max_pool3d,
15716           sample_inputs_func=sample_inputs_max_pool),
15717    OpInfo('nn.functional.max_unpool1d',
15718           aten_name='max_unpool1d',
15719           supports_autograd=True,
15720           supports_forward_ad=True,
15721           supports_fwgrad_bwgrad=True,
15722           supports_out=False,
15723           assert_jit_shape_analysis=False,
15724           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15725           sample_inputs_func=sample_inputs_max_unpool,
15726           skips=(
15727               # Gradients are tested in `variant_test_name=grad` below.
15728               # We skip these tests because there is non-determinism in the backward
15729               # (a gather) when there are writes into the same memory location: if
15730               # several indices point to the same memory, gradcheck is oblivious to
15731               # this and cannot perturb them all at once (see
15732               # sample_inputs_max_unpool_grad for details; a sketch follows this entry).
15733               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
15734               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
15735               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
15736                            active_if=(not IS_MACOS)),
15737               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad',
15738                            device_type='cpu'),
15739           )),
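    # A minimal, comment-only sketch (hypothetical values) of the non-determinism described in
    # the skips above:
    #   idx = torch.tensor([[[0, 0]]])                       # two entries target the same output slot
    #   x = torch.tensor([[[1., 2.]]], requires_grad=True)
    #   out = torch.nn.functional.max_unpool1d(x, idx, kernel_size=2)
    # Both x[..., 0] and x[..., 1] are scattered into out[..., 0]; which write "wins" is
    # unspecified, so finite-difference perturbations of the aliased inputs cannot produce a
    # stable Jacobian for gradcheck.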
15740    OpInfo('nn.functional.max_unpool1d',
15741           variant_test_name='grad',
15742           aten_name='max_unpool1d',
15743           supports_autograd=True,
15744           supports_forward_ad=True,
15745           supports_fwgrad_bwgrad=True,
15746           supports_out=False,
15747           assert_jit_shape_analysis=False,
15748           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15749           sample_inputs_func=sample_inputs_max_unpool_grad),
15750    OpInfo('nn.functional.max_unpool2d',
15751           aten_name='max_unpool2d',
15752           supports_autograd=True,
15753           supports_forward_ad=True,
15754           supports_fwgrad_bwgrad=True,
15755           supports_out=False,
15756           assert_jit_shape_analysis=False,
15757           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15758           sample_inputs_func=sample_inputs_max_unpool,
15759           skips=(
15760               # Gradients are tested in `variant_test_name=grad` below.
15761               # We skip these tests because there is non-determinism in the backward
15762               # (a gather) when there are writes into the same memory location: if
15763               # several indices point to the same memory, gradcheck is oblivious to
15764               # this and cannot perturb them all at once
15765               # (see sample_inputs_max_unpool_grad to find out more).
15766               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
15767                            active_if=(not IS_MACOS)),
15768               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
15769               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
15770               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
15771           )),
15772    OpInfo('nn.functional.max_unpool2d',
15773           variant_test_name='grad',
15774           aten_name='max_unpool2d',
15775           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15776           gradcheck_fast_mode=True,
15777           supports_forward_ad=True,
15778           supports_fwgrad_bwgrad=True,
15779           # Vmap is not happy with non-contiguous (channels_last) inputs
15780           check_batched_grad=False,
15781           supports_out=False,
15782           assert_jit_shape_analysis=False,
15783           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15784           sample_inputs_func=sample_inputs_max_unpool_grad),
15785    OpInfo('nn.functional.max_unpool3d',
15786           aten_name='max_unpool3d',
15787           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15788           gradcheck_fast_mode=True,
15789           supports_forward_ad=True,
15790           supports_fwgrad_bwgrad=True,
15791           supports_out=False,
15792           assert_jit_shape_analysis=False,
15793           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15794           sample_inputs_func=sample_inputs_max_unpool,
15795           skips=(
15796               # Gradients are tested in `variant_test_name=grad` below.
15797               # We skip these tests because there is non-determinism in the backward
15798               # (a gather) when there are writes into the same memory location: if
15799               # several indices point to the same memory, gradcheck is oblivious to
15800               # this and cannot perturb them all at once
15801               # (see sample_inputs_max_unpool_grad to find out more).
15802               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD',
15803                            active_if=(not IS_MACOS)),
15804               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
15805               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
15806               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
15807           )),
15808    OpInfo('nn.functional.max_unpool3d',
15809           variant_test_name='grad',
15810           aten_name='max_unpool3d',
15811           supports_autograd=True,
15812           supports_forward_ad=True,
15813           supports_fwgrad_bwgrad=True,
15814           supports_out=False,
15815           assert_jit_shape_analysis=False,
15816           dtypes=floating_types_and(torch.float16, torch.bfloat16),
15817           sample_inputs_func=sample_inputs_max_unpool_grad),
15818    OpInfo('nn.functional.linear',
15819           aten_name='linear',
15820           supports_autograd=True,
15821           supports_gradgrad=True,
15822           sample_inputs_func=sample_inputs_linear,
15823           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
15824           dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16),
15825           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
15826           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
15827           # linear calls mm under the hood which is nondeterministic on CUDA
15828           # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
15829           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
15830           supports_forward_ad=True,
15831           supports_fwgrad_bwgrad=True,
15832           # See https://github.com/pytorch/pytorch/issues/66357
15833           check_batched_forward_grad=False,
15834           supports_expanded_weight=True,
15835           decorators=(
15836               # Strides are not the same!
15837               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
15838           )),
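    # About gradcheck_nondet_tol in the entry above: linear dispatches to mm/addmm, whose CUDA
    # kernels are not bitwise deterministic, so gradcheck is run with a small nondeterminism
    # tolerance (gradcheck's nondet_tol) rather than requiring identical results across
    # re-evaluations.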
15839    OpInfo('nn.functional.bilinear',
15840           aten_name='bilinear',
15841           supports_autograd=True,
15842           sample_inputs_func=sample_inputs_bilinear,
15843           dtypes=all_types_and(torch.float16, torch.bfloat16),
15844           dtypesIfCUDA=floating_types_and(torch.float16,
15845                                           *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []),
15846           decorators=(
15847               DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}),
15848                            'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
15849           ),
15850           skips=(
15851               # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3
15852               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
15853               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)),
15854           ),
15855           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15856           gradcheck_fast_mode=True,
15857           supports_forward_ad=True,
15858           supports_fwgrad_bwgrad=True,
15859           supports_out=False),
15860    OpInfo('nn.functional.glu',
15861           aten_name='glu',
15862           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
15863           gradcheck_fast_mode=True,
15864           sample_inputs_func=sample_inputs_glu,
15865           dtypes=floating_types_and(torch.bfloat16, torch.float16),
15866           supports_forward_ad=True,
15867           supports_fwgrad_bwgrad=True,
15868           supports_out=False),
15869    UnaryUfuncInfo(
15870        'nn.functional.elu',
15871        aten_backward_name='elu_backward',
15872        ref=lambda x, alpha=1.0, inplace=False:
15873            np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)),
15874        dtypes=floating_types_and(torch.bfloat16, torch.float16),
15875        supports_forward_ad=True,
15876        supports_fwgrad_bwgrad=True,
15877        supports_autograd=True,
15878        assert_autodiffed=False,
15879        supports_gradgrad=True,
15880        supports_out=False,
15881        sample_kwargs=lambda device, dtype, input:
15882            ({'alpha': 0.8}, {'alpha': 0.8}),
15883        inplace_variant=lambda x, alpha=1.0:
15884            torch.nn.functional.elu(x, alpha, inplace=True),
15885        decorators=[
15886            DecorateInfo(
15887                toleranceOverride({
15888                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
15889                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
15890                }),
15891                'TestUnaryUfuncs', device_type='cuda',
15892            ), ],
15893    ),
15894    # Marked as a UnaryUfuncInfo because prelu has rather odd broadcasting semantics in its
15895    # second (weight) argument; see the illustrative sketch after this entry.
15896    UnaryUfuncInfo(
15897        'nn.functional.prelu',
15898        aten_backward_name='_prelu_kernel_backward',
15899        ref=lambda x, weight:
15900            np.maximum(0., x) + np.minimum(0., x) *
15901            (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])),
15902        dtypes=floating_types_and(torch.bfloat16, torch.float16),
15903        supports_forward_ad=True,
15904        supports_fwgrad_bwgrad=True,
15905        supports_autograd=True,
15906        assert_autodiffed=False,
15907        supports_gradgrad=True,
15908        supports_out=False,
15909        # test_reference_numerics only tests the case when the weight tensor is a scalar
15910        sample_kwargs=sample_kwargs_prelu_scalar_weight,
15911        error_inputs_func=error_inputs_prelu,
15912        sample_inputs_func=sample_inputs_prelu,
15913        reference_inputs_func=reference_inputs_prelu,
15914        decorators=[
15915            # RuntimeError: Cannot insert a Tensor that requires grad as a constant.
15916            # Consider making it a parameter or input, or detaching the gradient
15917            # https://github.com/pytorch/pytorch/issues/68752
15918            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ],
15919    ),
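    # A hedged illustration of the broadcasting mentioned above (shapes are hypothetical):
    # for a 4-D input of shape (N, C, H, W) a channel-wise weight of shape (C,) behaves as if it
    # were reshaped to (1, C, 1, 1) before the elementwise multiply, e.g.
    #   x = torch.randn(2, 3, 4, 5)
    #   w = torch.randn(3)
    #   expected = torch.clamp(x, min=0) + torch.clamp(x, max=0) * w.reshape(1, 3, 1, 1)
    #   torch.allclose(torch.nn.functional.prelu(x, w), expected)  # True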
15920    UnaryUfuncInfo(
15921        'nn.functional.celu',
15922        ref=lambda x, alpha=1.0, inplace=False:
15923            np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)),
15924        dtypes=floating_types_and(torch.bfloat16, torch.float16),
15925        supports_forward_ad=True,
15926        supports_fwgrad_bwgrad=True,
15927        supports_autograd=True,
15928        assert_autodiffed=False,
15929        supports_gradgrad=True,
15930        supports_out=False,
15931        sample_kwargs=lambda device, dtype, input:
15932            ({'alpha': 0.8}, {'alpha': 0.8}),
15933        inplace_variant=lambda x, alpha=1.0:
15934            torch.nn.functional.celu(x, alpha, inplace=True),
15935        decorators=[
15936            DecorateInfo(
15937                toleranceOverride({
15938                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
15939                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
15940                }),
15941                'TestUnaryUfuncs', device_type='cuda',
15942            ), ],
15943    ),
15944    UnaryUfuncInfo(
15945        'nn.functional.rrelu',
15946        aten_backward_name='rrelu_with_noise_backward',
15947        op=lambda input, *args, **kwargs:
15948            wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs),
15949        inplace_variant=lambda input, *args, **kwargs:
15950            wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs),
15951        dtypes=floating_types_and(torch.bfloat16),
15952        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
15953        gradcheck_wrapper=wrapper_set_seed,
15954        supports_forward_ad=True,
15955        supports_fwgrad_bwgrad=True,
15956        supports_out=False,
15957        sample_kwargs=lambda device, dtype, input:
15958            (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)),
15959        sample_inputs_func=sample_inputs_rrelu,
15960        error_inputs_func=error_inputs_rrelu,
15961        decorators=(
15962            DecorateInfo(
15963                toleranceOverride({
15964                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
15965                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
15966                }),
15967                'TestUnaryUfuncs', device_type='cuda',
15968            ),),
15969        skips=(
15970            # lambda impl
15971            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
15972            # lambda impl
15973            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
15974            # In-place operations do not play well with forward AD
15975            # https://github.com/pytorch/pytorch/issues/77447
15976            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients',
15977                         'test_inplace_forward_mode_AD'),
15978            # The noise vector that's generated in these tests is not the same elementwise
15979            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
15980            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
15981            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
15982            DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
15983            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
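    # Note on the rrelu entry above: rrelu draws its noise from the global RNG, so both the op
    # and inplace variants go through wrapper_set_seed, which (per its use throughout this file)
    # reseeds the RNG before each call so repeated invocations are comparable. Tests that slice
    # or re-layout the input still observe a different noise pattern elementwise, hence the
    # "Different noise" skips.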
15984    UnaryUfuncInfo(
15985        'nn.functional.selu',
15986        ref=lambda x, inplace=False:
15987            1.0507009873554804934193349852946 * (
15988                np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1))
15989            ),
15990        dtypes=floating_types_and(torch.bfloat16, torch.float16),
15991        supports_forward_ad=True,  # depends on 'elu'
15992        supports_fwgrad_bwgrad=True,
15993        supports_autograd=True,
15994        assert_autodiffed=False,
15995        supports_gradgrad=True,
15996        supports_out=False,
15997        inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True),
15998        decorators=[
15999            DecorateInfo(
16000                toleranceOverride({
16001                    torch.float16: tol(atol=1e-2, rtol=1.8e-2),
16002                    torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2)
16003                }),
16004                'TestUnaryUfuncs', device_type='cuda',
16005            ), ],
16006    ),
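    # The two long literals in the selu reference above are the standard SELU scale and alpha
    # constants, so the reference is simply scale * (max(0, x) + min(0, alpha * (exp(x) - 1))),
    # i.e. a scaled ELU; this is why the comment on supports_forward_ad says it depends on 'elu'.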
16007    OpInfo(
16008        'torch._scaled_mm',
16009        sample_inputs_func=sample_inputs_scaled_mm,
16010        dtypes=empty_types(),
16011        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
16012        supports_out=True,
16013        supports_forward_ad=False,
16014        supports_autograd=False,
16015        decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
16016        skips=(
16017            # Sample inputs aren't really parametrized on dtype
16018            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
16019                         device_type='cuda'),
16020            # "mul_cuda" not implemented for float8_e4m3fn
16021            # https://github.com/pytorch/pytorch/issues/107256
16022            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
16023                         dtypes=(torch.float8_e4m3fn,)),
16024        )
16025    ),
16026    OpInfo(
16027        'torch.ops.aten._safe_softmax.default',
16028        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
16029        sample_inputs_func=sample_inputs_safe_softmax,
16030        assert_jit_shape_analysis=True,
16031        assert_autodiffed=True,
16032        supports_forward_ad=True,
16033        supports_fwgrad_bwgrad=True,
16034        supports_out=False,
16035        supports_cow_input_no_materialize_backward=False,
16036        decorators=[],
16037        skips=(
16038            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
16039        ),
16040    ),
16041    OpInfo(
16042        'nn.functional.scaled_dot_product_attention',
16043        op=lambda *args, **kwargs:
16044               wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
16045        sample_inputs_func=sample_inputs_scaled_dot_product_attention,
16046        dtypes=floating_types_and(torch.float16, torch.bfloat16),
16047        supports_out=False,
16048        supports_forward_ad=False,
16049        supports_fwgrad_bwgrad=True,
16050        check_batched_forward_grad=False,
16051        decorators=[DecorateInfo(toleranceOverride(
16052            {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ],
16053        skips=(
16054            # When the attn mask is a composite tensor, this fails backward by returning None
16055            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
16056            # This is only failing on Linux Bionic 3.10 CUDA 11.6
16057            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
16058                         device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)),
16059            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples',
16060                         dtypes=(torch.float32,)),
16061            # AssertionError: JIT Test does not execute any logic
16062            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
16063            # Forward works for dtype=float64 which is the math path
16064            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
16065            # Not implemented for Forward AD
16066            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
16067                         device_type='cpu'),
16068            # Not implemented for backward derivative
16069            DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad',
16070                         device_type='cpu'),
16071            # CPU and CUDA have inconsistencies for intermediate outputs
16072            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
16073                         device_type='cpu'),
16074            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
16075                         device_type='cpu'),
16076            # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false
16077            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward',
16078                         device_type='cpu'),
16079            # OpInfo was implemented with a lambda
16080            DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
16081            # TODO Need to understand what this is testing and why it doesn't work
16082            DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
16083            DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),
16084            # TODO skip this for now since we can't skip on runtime arch support
16085            DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'),
16086            # skip for sm < 80
16087            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
16088                         device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater),
16089            # FIXME
16090            DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'),
16091                         'TestCompositeCompliance', 'test_cow_input',
16092                         device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32),
16093                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),
16094            DecorateInfo(unittest.skip('test_fake_crossref_backward_amp does not work with efficient attention on ROCM'),
16095                         'TestFakeTensor', 'test_fake_crossref_backward_amp',
16096                         device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32),
16097                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),
16098            DecorateInfo(unittest.skip('test_fake_crossref_backward_no_amp does not work with efficient attention on ROCM'),
16099                         'TestFakeTensor', 'test_fake_crossref_backward_no_amp',
16100                         device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32),
16101                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),
16102            # for element 1, was torch.Size([4, 4, 0]) but real shape was torch.Size([16, 3, 0])
16103            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda",
16104                         dtypes=[torch.float16, torch.bfloat16, torch.float32],
16105                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16106            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda",
16107                         dtypes=[torch.float16, torch.bfloat16, torch.float32],
16108                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16109            # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11])
16110            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides",
16111                         device_type="cuda", dtypes=[torch.float32],
16112                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),),
16113    ),
16114    OpInfo(
16115        'torch.ops.aten._flash_attention_forward',
16116        sample_inputs_func=sample_inputs_flash_attention_forward,
16117        dtypes=empty_types(),
16118        dtypesIfCUDA=custom_types(torch.float16)
16119        if not SM80OrLater
16120        else custom_types(torch.float16, torch.bfloat16),
16121        supports_out=False,
16122        supports_autograd=True,
16123        supports_fwgrad_bwgrad=False,
16124        supports_forward_ad=False,
16125        check_batched_forward_grad=False,
16126        decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")],
16127        skips=(
16128            # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11])
16129            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda",
16130                         dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16131            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda",
16132                         dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16133            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda",
16134                         dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16135            # Checking the scalar value of the philox seed and offset
16137            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
16138            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
16139            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
16140            # None Mismatch Tensor
16141            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
16142        )
16143    ),
16144    OpInfo(
16145        'torch.ops.aten._efficient_attention_forward',
16146        sample_inputs_func=sample_inputs_efficient_attention_forward,
16147        dtypes=empty_types(),
16148        dtypesIfCUDA=custom_types(torch.float16, torch.float32)
16149        if not SM80OrLater
16150        else custom_types(torch.float16, torch.float32, torch.bfloat16),
16151        supports_out=False,
16152        supports_autograd=True,
16153        supports_fwgrad_bwgrad=False,
16154        supports_forward_ad=False,
16155        check_batched_forward_grad=False,
16156        # TODO: Skip because it produces a CUDA illegal memory access for some reason
16157        skip_cow_input_backward=True,
16158        # FIXME: mask_type == 2 (LowerRight)
16159        decorators=[
16160            skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"),
16161            skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")],
16162        skips=(
16163            # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11])
16164            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda",
16165                         dtypes=[torch.float16, torch.bfloat16, torch.float32],
16166                         active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16167            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda",
16168                         dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16169            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda",
16170                         dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
16171            # Checking the scalar value of the philox seed and offset
16172            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'),
16173            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'),
16174            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),
16175            # None Mismatch Tensor
16176            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'),
16177        )
16178    ),
16179    UnaryUfuncInfo(
16180        'nn.functional.silu',
16181        aten_backward_name='silu_backward',
16182        ref=lambda x, inplace=False: x / (1 + np.exp(-x)),
16183        dtypes=floating_types_and(torch.bfloat16, torch.float16),
16184        supports_forward_ad=True,
16185        supports_autograd=True,
16186        supports_fwgrad_bwgrad=True,
16187        assert_autodiffed=True,
16188        supports_out=False,
16189        inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
16190        decorators=[
16191            DecorateInfo(
16192                toleranceOverride({
16193                    torch.float16: tol(atol=1e-3, rtol=1e-3),
16194                    torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
16195                }),
16196                'TestUnaryUfuncs', device_type='cuda',
16197            ), ],
16198        skips=(
16199            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
16200                         dtypes=(torch.cfloat,), device_type='cpu'),
16201        ),
16202        autodiff_nonfusible_nodes=["aten::silu"],
16203    ),
16204    # TODO: combine this with the nn.functional.silu OpInfo when
16205    # complex autodiff for silu is supported or when
16206    # the forward bug is fixed
16207    # Note: silu errors when given inputs that require grad
16208    #   but whose dtype doesn't support grad.
16209    #   This is why the dtypes list above passes test_dtypes:
16210    #   it gets lucky and fails in forward,
16211    #   because test_dtypes sets requires_grad to True.
16212    #   THIS IS A BUG
16213    UnaryUfuncInfo(
16214        'nn.functional.silu',
16215        variant_test_name='complex',
16216        ref=lambda x, inplace=False:
16217            x / (1 + np.exp(-x)),
16218        dtypes=complex_types(),
16219        dtypesIfCUDA=complex_types(),
16220        supports_forward_ad=False,
16221        supports_autograd=False,
16222        assert_autodiffed=False,
16223        supports_out=False,
16224        inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True),
16225        decorators=[
16226            DecorateInfo(
16227                toleranceOverride({
16228                    torch.float16: tol(atol=1e-3, rtol=1e-3),
16229                    torch.bfloat16: tol(atol=1e-4, rtol=1e-4)
16230                }),
16231                'TestUnaryUfuncs', device_type='cuda',
16232            ), ],
16233        skips=(
16234            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
16235                         dtypes=(torch.cfloat,)),
16236            # FIXME: intentionally misreports dtypes
16237            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
16238            # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j)
16239            DecorateInfo(unittest.skip("Skipped!"),
16240                         'TestUnaryUfuncs', 'test_reference_numerics_large',
16241                         dtypes=(torch.complex64, torch.cdouble)),
16242            DecorateInfo(unittest.skip("Skipped!"),
16243                         'TestUnaryUfuncs', 'test_reference_numerics_small',
16244                         dtypes=(torch.complex64,)),
16245            DecorateInfo(unittest.skip("Skipped!"),
16246                         'TestUnaryUfuncs', 'test_reference_numerics_extremal',
16247                         dtypes=(torch.complex64,)))),
16248    UnaryUfuncInfo(
16249        'nn.functional.hardsigmoid',
16250        aten_backward_name='hardsigmoid_backward',
16251        ref=reference_hardsigmoid,
16252        dtypes=floating_types_and(torch.bfloat16, torch.float16),
16253        supports_autograd=True,
16254        assert_autodiffed=False,
16255        supports_gradgrad=False,
16256        supports_forward_ad=True,
16257        supports_out=False,
16258        inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True),
16259        decorators=[
16260            DecorateInfo(
16261                toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ],
16262        skips=[
16263            # We still want to test that the first derivative works even though the second derivative isn't supported
16264            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"),
16265            # produces 0 instead of nan on ROCM
16266            DecorateInfo(unittest.expectedFailure,
16267                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
16268                         device_type='cuda',
16269                         active_if=(TEST_WITH_ROCM)), ]
16270    ),
16271    UnaryUfuncInfo(
16272        'nn.functional.logsigmoid',
16273        aten_name="log_sigmoid",
16274        aten_backward_name='log_sigmoid_backward',
16275        ref=reference_logsigmoid,
16276        dtypes=floating_types_and(torch.half, torch.bfloat16),
16277        supports_autograd=True,
16278        assert_autodiffed=False,
16279        supports_forward_ad=True,
16280        supports_fwgrad_bwgrad=True,
16281        supports_gradgrad=True,
16282        # autodiff_nonfusible_nodes=["aten::log_sigmoid"],
16283        decorators=[
16284            DecorateInfo(
16285                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
16286                'TestUnaryUfuncs', 'test_reference_numerics_small'),
16287            DecorateInfo(
16288                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
16289                'TestUnaryUfuncs', 'test_reference_numerics_large'),
16290            DecorateInfo(
16291                precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}),
16292                'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
16293        ],
16294        skips=(
16295            # Resized a non-empty tensor but did not warn about it.
16296            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'),
16297        ),
16298    ),
16299    UnaryUfuncInfo(
16300        'nn.functional.mish',
16301        aten_backward_name='mish_backward',
16302        ref=lambda x: x * np.tanh(reference_softplus(x)),
16303        dtypes=floating_types_and(torch.bfloat16, torch.float16),
16304        supports_forward_ad=True,
16305        supports_fwgrad_bwgrad=True,
16306        supports_autograd=True,
16307        assert_autodiffed=False,
16308        supports_gradgrad=True,
16309        supports_out=False,
16310        inplace_variant=partial(torch.nn.functional.mish, inplace=True),
16311        decorators=[
16312            DecorateInfo(
16313                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ],
16314    ),
16315    UnaryUfuncInfo(
16316        'nn.functional.softsign',
16317        ref=lambda x: x / (np.abs(x) + 1),
16318        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
16319        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
16320        supports_forward_ad=True,
16321        supports_fwgrad_bwgrad=True,
16322        supports_autograd=True,
16323        assert_autodiffed=False,
16324        supports_gradgrad=True,
16325        supports_out=False,
16326        decorators=[
16327            DecorateInfo(
16328                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ],
16329        skips=(
16330            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
16331                         dtypes=(torch.int, torch.int8)),),
16332    ),
16333    UnaryUfuncInfo(
16334        'nn.functional.tanhshrink',
16335        ref=lambda x: x - np.tanh(x),
16336        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
16337        supports_forward_ad=True,
16338        supports_fwgrad_bwgrad=True,
16339        supports_autograd=True,
16340        assert_autodiffed=False,
16341        supports_gradgrad=True,
16342        supports_out=False,
16343        decorators=[
16344            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
16345                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
16346            DecorateInfo(
16347                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',),
16348            DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05),
16349                                            torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}),
16350                         'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
16351        ],
16352        skips=(
16353            # in each case, pytorch will produce a nan while numpy will not
16354            DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
16355                         'TestUnaryUfuncs', "test_reference_numerics_large",
16356                         dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)),
16357            DecorateInfo(unittest.skip("Fails on some jobs works on others!"),
16358                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
16359                         dtypes=(torch.complex64, torch.complex128), device_type='cpu',
16360                         active_if=(IS_MACOS or IS_WINDOWS)),
16361        ),
16362        # tan(j * pi/2 * odd_number) is nan which also makes tanhshrink nan.
16363        reference_numerics_filter=NumericsFilter(
16364            condition=lambda x: (close_to_int(x / (math.pi * 0.5j))
16365                                 if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
16366            safe_val=0)
16367    ),
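    # About the reference_numerics_filter above: tanh(z) has poles at z = i * pi/2 * (2k + 1),
    # where cosh(z) == 0, so x - tanh(x) is nan/inf near those points. The filter conservatively
    # flags any complex x for which x / (i * pi/2) is close to an integer and substitutes
    # safe_val=0 so the reference comparison is not polluted by those poles.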
16368    UnaryUfuncInfo(
16369        'nn.functional.threshold',
16370        ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
16371        dtypes=all_types_and(torch.half, torch.bfloat16),
16372        inplace_variant=lambda x, threshold, value:
16373            torch.nn.functional.threshold(x, threshold, value, inplace=True),
16374        supports_forward_ad=True,
16375        supports_fwgrad_bwgrad=True,
16376        assert_autodiffed=False,
16377        supports_gradgrad=True,
16378        supports_out=False,
16379        sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'),
16380                                                    'value': -9},
16381                                                    {'threshold': float.fromhex('0x1.3ap-3'),
16382                                                    'value': -9}),
16383        # TODO(whc) should not need sample_inputs_func, but without it
16384        # kwargs aren't being hooked up properly
16385        sample_inputs_func=sample_inputs_threshold,
16386    ),
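    # The kwargs in the threshold entry above use a hex float: float.fromhex('0x1.3ap-3')
    # == (1 + 0x3a / 0x100) * 2**-3 == 0.1533203125, an exactly representable binary value,
    # presumably so the x <= threshold comparison in the reference never straddles a rounding
    # boundary. E.g. threshold(tensor([0.1, 0.2]), 0.1533203125, -9) -> tensor([-9.0, 0.2])
    # (hypothetical values).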
16387    OpInfo(
16388        "nn.functional.triplet_margin_loss",
16389        sample_inputs_func=sample_inputs_triplet_margin_loss,
16390        error_inputs_func=error_inputs_triplet_margin_loss,
16391        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
16392        supports_out=False,
16393        supports_forward_ad=True,
16394        supports_fwgrad_bwgrad=True,
16395    ),
16396    OpInfo(
16397        "nn.functional.triplet_margin_with_distance_loss",
16398        sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
16399        error_inputs_func=error_inputs_triplet_margin_loss,
16400        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
16401        supports_out=False,
16402        supports_forward_ad=True,
16403        supports_fwgrad_bwgrad=True,
16404        skips=(
16405            # This test cannot handle a callable passed to `distance_function`. If we used
16406            # `distance_function=None`, the test would pass.
16407            DecorateInfo(
16408                unittest.expectedFailure,
16409                "TestJit",
16410                "test_variant_consistency_jit",
16411            ),
16412            DecorateInfo(
16413                unittest.expectedFailure,
16414                "TestNormalizeOperators",
16415                "test_normalize_operator_exhaustive",
16416            ),
16417        ),
16418    ),
16419    BinaryUfuncInfo('nextafter',
16420                    dtypes=floating_types_and(torch.bfloat16, torch.half),
16421                    dtypesIfCUDA=floating_types_and(torch.bfloat16),
16422                    supports_autograd=False,
16423                    supports_rhs_python_scalar=False),
16424    OpInfo(
16425        "to",
16426        op=lambda x, *args, **kwargs: x.to(*args, **kwargs),
16427        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
16428        supports_forward_ad=True,
16429        supports_fwgrad_bwgrad=True,
16430        supports_out=False,
16431        sample_inputs_func=sample_inputs_to,
16432        skips=(
16433            # RuntimeError: undefined value cpu
16434            DecorateInfo(
16435                unittest.skip("Skipped!"),
16436                "TestJit",
16437                "test_variant_consistency_jit",
16438                device_type="cpu",
16439            ),
16440            # NotImplementedError: Cannot copy out of meta tensor; no data!
16441            DecorateInfo(
16442                unittest.skip("Skipped!"),
16443                "TestMeta",
16444                "test_meta_outplace",
16445            ),
16446            # https://github.com/pytorch/pytorch/issues/84335
16447            DecorateInfo(
16448                unittest.skip("Skipped!"),
16449                "TestProxyTensorOpInfo",
16450                "test_make_fx_symbolic_exhaustive",
16451            ),
16452            DecorateInfo(
16453                unittest.skip("Skipped!"),
16454                "TestNormalizeOperators",
16455                "test_normalize_operator_exhaustive",
16456            ),
16457        ),
16458    ),
16459    OpInfo('topk',
16460           dtypes=all_types_and(torch.bfloat16, torch.float16),
16461           supports_forward_ad=True,
16462           supports_fwgrad_bwgrad=True,
16463           assert_jit_shape_analysis=True,
16464           sample_inputs_func=sample_inputs_topk),
16465    # Multiple variants for batch_norm to test with cuDNN enabled and disabled
16466    # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details
16467    OpInfo('nn.functional.batch_norm',
16468           aten_name='batch_norm',
16469           dtypes=floating_types_and(torch.float16, torch.bfloat16),
16470           supports_out=False,
16471           supports_forward_ad=True,
16472           supports_fwgrad_bwgrad=True,
16473           assert_jit_shape_analysis=True,
16474           allow_cow_input_materialize_forward=[1, 2],
16475           allow_cow_input_materialize_backward=[1, 2],
16476           sample_inputs_func=sample_inputs_batch_norm,
16477           skips=(
16478               # see https://github.com/pytorch/pytorch/issues/71286
16479               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
16480               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
16481                            device_type='cpu', dtypes=(torch.bfloat16, torch.float16)),
16482               DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}),
16483                            'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"),
16484           )),
16485    # This variant tests batch_norm with cuDNN disabled only on CUDA devices
16486    OpInfo('nn.functional.batch_norm',
16487           variant_test_name='without_cudnn',
16488           aten_name='batch_norm',
16489           dtypes=empty_types(),
16490           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
16491           supports_out=False,
16492           supports_forward_ad=True,
16493           supports_fwgrad_bwgrad=True,
16494           allow_cow_input_materialize_forward=[1, 2],
16495           allow_cow_input_materialize_backward=[1, 2],
16496           decorators=[onlyCUDA, disablecuDNN],
16497           skips=(
16498               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}),
16499                            'TestJit', 'test_variant_consistency_jit'),
16500           ),
16501           sample_inputs_func=sample_inputs_batch_norm),
16502    OpInfo(
16503        "nn.functional.binary_cross_entropy",
16504        aten_backward_name='binary_cross_entropy_backward',
16505        sample_inputs_func=sample_inputs_binary_cross_entropy,
16506        dtypes=floating_types_and(torch.float16, torch.bfloat16),
16507        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
16508        supports_out=False,
16509        gradcheck_fast_mode=False,
16510        supports_autograd=True,
16511        supports_forward_ad=True,
16512        supports_fwgrad_bwgrad=True,
16513        decorators=(
16514            # RuntimeError: expected int at position 0, but got: Tensor
16515            DecorateInfo(
16516                unittest.skip("Skipped!"),
16517                "TestCudaFuserOpInfo",
16518            ),
16519            # RuntimeError: expected int at position 0, but got: Tensor
16520            DecorateInfo(
16521                unittest.skip("Skipped!"),
16522                "TestNNCOpInfo",
16523                "test_nnc_correctness",
16524            ),
16525            # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783
16526            DecorateInfo(
16527                unittest.skip("Skipped!"),
16528                "TestCompositeCompliance",
16529                "test_cow_input",
16530                device_type='cuda',
16531            ),
16532            DecorateInfo(
16533                toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
16534                "TestJit",
16535                "test_variant_consistency_jit",
16536            ),
16537            # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5]
16538            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'),
16539            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
16540            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
16541        ),
16542        skips=(
16543            # RuntimeError: expected int at position 0, but got: Tensor
16544            DecorateInfo(
16545                unittest.expectedFailure,
16546                "TestJit",
16547                "test_variant_consistency_jit",
16548            ),
16549        ),
16550    ),
16551    # We have to add two OpInfo entries for `igamma` and `igammac`. The first is the
16552    # standard entry; the second is to run gradcheck tests on the second argument.
16553    BinaryUfuncInfo('igamma',
16554                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
16555                    aliases=('torch.special.gammainc',),
16556                    dtypesIfCUDA=floating_types(),
16557                    # TODO: FIXME
16558                    supports_rhs_python_scalar=False,
16559                    supports_autograd=False,
16560                    skips=(
16561                        # FIXME: incorrectly tries to pass a rhs scalar
16562                        DecorateInfo(unittest.expectedFailure, 'TestJit',
16563                                     'test_jit_alias_remapping'),
16564                    )),
16565    # TODO: FIXME, ideally by implementing grad for both inputs
16566    # BinaryUfuncInfo('igamma',
16567    #                 variant_test_name='grad_other',
16568    #                 # Since autograd formula is implemented only for other and
16569    #                 # gradcheck test verifies the formula for input in SampleInput,
16570    #                 # we permute the arguments.
16571    #                 op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs),
16572    #                 inplace_variant=None,
16573    #                 method_variant=None,
16574    #                 supports_rhs_python_scalar=False,
16575    #                 rhs_make_tensor_kwargs=dict(requires_grad=False),
16576    #                 dtypes=floating_types_and(torch.bfloat16, torch.float16),
16577    #                 backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
16578    #                 dtypesIfCUDA=floating_types(),
16579    #                 backward_dtypesIfCUDA=floating_types(),
16580    #                 supports_inplace_autograd=False,
16581    #                 skips=(
16582    #                     # Derivative wrt first tensor not implemented
16583    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
16584    #                                  "test_floating_inputs_are_differentiable"),
16585    #                     # test does not work with passing lambda for op
16586    #                     # AssertionError: False is not true : Tensors failed to compare as equal!
16587    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
16588    #                     # test fails as we permute the arguments for the function variant
16589    #                     # but not for the inplace or method variants.
16590    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
16591    #                     # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float
16592    #                     DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
16593    #                 )),
16594    BinaryUfuncInfo('igammac',
16595                    dtypes=floating_types_and(torch.bfloat16, torch.float16),
16596                    aliases=('torch.special.gammaincc',),
16597                    dtypesIfCUDA=floating_types(),
16598                    supports_autograd=False,
16599                    supports_rhs_python_scalar=False,
16600                    skips=(
16601                        # FIXME: incorrectly tries to pass a rhs scalar
16602                        DecorateInfo(unittest.expectedFailure, 'TestJit',
16603                                     'test_jit_alias_remapping'),
16604                    )),
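    # Math note for the two entries above: igamma and igammac are the regularized lower and
    # upper incomplete gamma functions (aliased as torch.special.gammainc / gammaincc), so for
    # valid inputs igamma(a, x) + igammac(a, x) == 1.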
16605    # TODO: FIXME, ideally by implementing grad for both inputs
16606    # BinaryUfuncInfo('igammac',
16607    #                 variant_test_name='grad_other',
16608    #                 # Since autograd formula is implemented only for other and
16609    #                 # gradcheck test verifies the formula for input in SampleInput,
16610    #                 # we permute the arguments
16611    #                 op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs),
16612    #                 inplace_variant=None,
16613    #                 method_variant=None,
16614    #                 supports_rhs_python_scalar=False,
16615    #                 rhs_make_tensor_kwargs=dict(requires_grad=False),
16616    #                 dtypes=floating_types_and(torch.bfloat16, torch.float16),
16617    #                 backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
16618    #                 dtypesIfCUDA=floating_types(),
16619    #                 backward_dtypesIfCUDA=floating_types(),
16620    #                 supports_inplace_autograd=False,
16621    #                 decorators=[
16622    #                     # Derivative wrt first tensor not implemented
16623    #                     DecorateInfo(unittest.expectedFailure, "TestCommon",
16624    #                                  "test_floating_inputs_are_differentiable"),
16625    #                 ],
16626    #                 skips=(
16627    #                     # test does not work with passing lambda for op
16628    #                     # AssertionError: False is not true : Tensors failed to compare as equal!
16629    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
16630    #                     # test fails as we permute the arguments for the function variant
16631    #                     # but not for the inplace or method variants.
16632    #                     DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
16633    #                     # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float
16634    #                     DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'),
16635    #                 )),
16636    UnaryUfuncInfo('nn.functional.softshrink',
16637                   aten_name="softshrink",
16638                   aten_backward_name='softshrink_backward',
16639                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
16640                   supports_forward_ad=True,
16641                   supports_fwgrad_bwgrad=True,
16642                   assert_autodiffed=False,
16643                   sample_inputs_func=sample_inputs_softshrink,
16644                   error_inputs_func=error_inputs_softshrink),
16645    UnaryUfuncInfo('nn.functional.hardshrink',
16646                   aten_name="hardshrink",
16647                   aten_backward_name='hardshrink_backward',
16648                   dtypes=floating_types_and(torch.bfloat16, torch.float16),
16649                   assert_autodiffed=True,
16650                   sample_inputs_func=sample_inputs_hardshrink,
16651                   supports_forward_ad=True,
16652                   supports_fwgrad_bwgrad=True,
16653                   autodiff_nonfusible_nodes=["aten::hardshrink"]),
16654    UnaryUfuncInfo('nn.functional.hardtanh',
16655                   aten_name="hardtanh",
16656                   aten_backward_name='hardtanh_backward',
16657                   dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16),
16658                   backward_dtypes=all_types_and(torch.half, torch.bfloat16),
16659                   backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
16660                   assert_autodiffed=True,
16661                   sample_inputs_func=sample_inputs_hardtanh,
16662                   error_inputs_func=error_inputs_hardtanh,
16663                   supports_out=False,
16664                   supports_forward_ad=True,
16665                   supports_fwgrad_bwgrad=True,
16666                   autodiff_nonfusible_nodes=["aten::hardtanh"]),
16667    OpInfo('nn.functional.gelu',
16668           aten_name="gelu",
16669           aten_backward_name='gelu_backward',
16670           ref=reference_gelu if TEST_SCIPY else None,
16671           error_inputs_func=error_inputs_gelu,
16672           supports_autograd=True,
16673           assert_autodiffed=True,
16674           sample_inputs_func=sample_inputs_gelu,
16675           dtypes=floating_types_and(torch.bfloat16, torch.half),
16676           supports_gradgrad=True,
16677           supports_forward_ad=True,
16678           supports_fwgrad_bwgrad=True,
16679           autodiff_nonfusible_nodes=["aten::gelu"],
16680           skips=(
16681               # AssertionError: Tensor-likes are not close!
16682               # May not replicate in CI
16683               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
16684               DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
16685           )),
16686    UnaryUfuncInfo('nn.functional.relu6',
16687                   aten_name="relu6",
16688                   dtypes=all_types_and(torch.half, torch.bfloat16),
16689                   backward_dtypes=floating_types_and(torch.half, torch.bfloat16),
16690                   assert_autodiffed=True,
16691                   supports_out=False,
16692                   supports_forward_ad=True,
16693                   supports_fwgrad_bwgrad=True,
16694                   autodiff_nonfusible_nodes=["aten::relu6"]),
16695    OpInfo('mm',
16696           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
16697           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
16698           assert_autodiffed=True,
16699           supports_forward_ad=True,
16700           supports_fwgrad_bwgrad=True,
16701           sample_inputs_func=sample_inputs_mm,
16702           skips=(
16703               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
16704               DecorateInfo(
16705                   unittest.skip("Skipped!"),
16706                   'TestSchemaCheckModeOpInfo',
16707                   'test_schema_correctness',
16708                   dtypes=(torch.complex64, torch.complex128)),
16709           )),
16710    OpInfo('mode',
16711           op=torch.mode,
16712           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
16713           supports_forward_ad=True,
16714           supports_fwgrad_bwgrad=True,
16715           skips=(
16716               # Resized a non-empty tensor but did not warn about it
16717               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
16718               # FIXME:
16719               # Expected 2114 but got 1123.
16720               # Absolute difference: 991 (up to 0.001 allowed)
16721               # Relative difference: 0.46877956480605487 (up to 0.001 allowed)
16722               DecorateInfo(
16723                   unittest.skip("Skipped!"),
16724                   "TestCommon",
16725                   "test_compare_cpu",
16726                   dtypes=(torch.float32,),
16727                   device_type="cuda",
16728               ),
16729           ),
16730           sample_inputs_func=sample_inputs_mode,),
16731    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
16732                         domain=(1, None),
16733                         skips=skips_mvlgamma(),
16734                         sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
16735    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3',
16736                         domain=(2, None),
16737                         skips=skips_mvlgamma(),
16738                         sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
16739    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5',
16740                         domain=(3, None),
16741                         skips=skips_mvlgamma(),
16742                         sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
16743    BinaryUfuncInfo('ne',
16744                    ref=np.not_equal,
16745                    aliases=('not_equal',),
16746                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
16747                    always_returns_bool=True,
16748                    supports_autograd=False,
16749                    skips=(
16750                    )),
16751    OpInfo('narrow',
16752           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
16753           supports_out=False,
16754           supports_forward_ad=True,
16755           supports_fwgrad_bwgrad=True,
16756           sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True),
16757           reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True),
16758           error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False),
16759           skips=(
16760               # Use of .item()
16761               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
16762               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
16763               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
16764               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
16765               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
16766           )),
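    # Note (illustrative, assumed): narrow also accepts a 0-dim Tensor as `start`, and reading
    # it out with .item() is what the "Use of .item()" xfails above refer to.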
16767    OpInfo('narrow_copy',
16768           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
16769           supports_out=True,
16770           supports_forward_ad=False,
16771           supports_fwgrad_bwgrad=False,
16772           supports_autograd=False,
16773           # https://github.com/pytorch/pytorch/issues/86931
16774           sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False),
16775           reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False),
16776           error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False),
16777           skips=(
16778               # https://github.com/pytorch/pytorch/issues/84577
16779               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
16780               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
16781               # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend
16782               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace',
16783                            device_type='cuda'),
16784               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace',
16785                            device_type='cuda'),
16786               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
16787                            device_type='cuda'),
16788               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
16789           )),
16790    OpInfo('view_copy',
16791           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
16792           ref=lambda x, newshape: np.reshape(x, newshape).copy(),
16793           supports_out=True,
16794           supports_forward_ad=True,
16795           supports_fwgrad_bwgrad=True,
16796           supports_autograd=True,
16797           sample_inputs_func=sample_inputs_view_reshape,
16798           error_inputs_func=error_inputs_view_reshape,
16799           skips=(
16800               # RuntimeError: view size is not compatible with input tensor's size and stride
16801               # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
16802               DecorateInfo(
16803                   unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"
16804               ),
16805           )),
16806    UnaryUfuncInfo('neg',
16807                   aliases=('negative', ),
16808                   ref=np.negative,
16809                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
16810                   error_inputs_func=error_inputs_neg,
16811                   supports_forward_ad=True,
16812                   supports_fwgrad_bwgrad=True,
16813                   supports_sparse=True,
16814                   supports_sparse_csr=True,
16815                   supports_sparse_csc=True,
16816                   supports_sparse_bsr=True,
16817                   supports_sparse_bsc=True,
16818                   assert_autodiffed=True),
16819    OpInfo('dist',
16820           op=torch.dist,
16821           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
16822           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
16823           gradcheck_fast_mode=True,
16824           supports_out=False,
16825           supports_forward_ad=True,
16826           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
16827           # Could not allocate memory to change Tensor SizesAndStrides!
16828           check_batched_forward_grad=False,
16829           supports_fwgrad_bwgrad=True,
16830           sample_inputs_func=sample_inputs_dist),
16831    OpInfo('outer',
16832           op=torch.outer,
16833           aliases=('ger', ),
16834           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
16835           supports_forward_ad=True,
16836           supports_fwgrad_bwgrad=True,
16837           # See https://github.com/pytorch/pytorch/pull/78358
16838           check_batched_forward_grad=False,
16839           sample_inputs_func=sample_inputs_outer,),
16840    OpInfo('ormqr',
16841           op=torch.ormqr,
16842           dtypes=floating_and_complex_types(),
16843           # https://github.com/pytorch/pytorch/issues/80411
16844           gradcheck_fast_mode=True,
16845           supports_forward_ad=False,
16846           supports_fwgrad_bwgrad=False,
16847           sample_inputs_func=sample_inputs_ormqr,
16848           error_inputs_func=error_inputs_ormqr,
16849           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack],
16850           skips=(
16851               # Strides are not the same!
16852               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
16853           )),
16854    OpInfo('permute',
16855           ref=np.transpose,
16856           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
16857           supports_out=False,
16858           assert_autodiffed=True,
16859           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
16860           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
16861           assert_jit_shape_analysis=True,
16862           supports_forward_ad=True,
16863           supports_fwgrad_bwgrad=True,
16864           supports_varargs=True,
16865           sample_inputs_func=sample_inputs_permute,
16866           reference_inputs_func=reference_inputs_permute),
16867    BinaryUfuncInfo('pow',
16868                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
16869                    dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf),
16870                    ref=np.power,
16871                    # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled
16872                    # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently
16873                    # unsupported on CPU.
16874                    backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
16875                    backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
16876                    # https://github.com/pytorch/pytorch/issues/80411
16877                    gradcheck_fast_mode=True,
16878                    supports_inplace_autograd=False,
16879                    supports_forward_ad=True,
16880                    supports_fwgrad_bwgrad=True,
16881                    assert_autodiffed=True,
16882                    supports_one_python_scalar=True,
16883                    # Integer types do not support negative exponents
16884                    rhs_make_tensor_kwargs=dict(low=0),
16885                    # Raising negative real numbers to fractional powers is not supported
16886                    lhs_make_tensor_kwargs=dict(low=0),
16887                    decorators=(
16888                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}),
16889                                     'TestBinaryUfuncs', 'test_reference_numerics'),
16890                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
16891                                                        torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
16892                                     'TestBinaryUfuncs', 'test_scalar_support'),
16893                    ),
16894                    skips=(
16895                        # Integer dtypes are skipped because raising them to negative powers raises an error
16896                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
16897                                     dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]),
16898                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
16899                                     dtypes=[torch.int16, torch.int32, torch.int64]),
16900                        # FIXME Complex values error with: Greatest absolute difference: nan at index
16901                        # Ref: https://github.com/pytorch/pytorch/issues/76853
16902                        # For `chalf`, the reference computation in `numpy` is done in `cfloat`.
16903                        # The `chalf` output saturates to `inf` sooner than the reference due to its smaller range,
16904                        # which leads to this test failing.
16905                        DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick',
16906                                     dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
16907                        # FIXME:
16908                        # Mismatched elements: 1 / 500 (0.2%)
16909                        # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed)
16910                        # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed)
16911                        DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive',
16912                                     dtypes=(torch.complex32,)),
16913                        DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing',
16914                                     dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
16915                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing',
16916                                     dtypes=(torch.complex32,)),
16917                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig',
16918                                     dtypes=(torch.complex32,)),
16919                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics',
16920                                     dtypes=(torch.complex32,)),
16921                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
16922                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
16923                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
16924                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
16925                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
16926                                     dtypes=(torch.complex32, torch.complex64, torch.complex128)),
16927                    )),
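    # Illustrative sketch (assumed behavior, not executed by the suite) of the restrictions in
    # the `pow` entry above:
    #   torch.pow(torch.tensor(2), -1)       # errors: integers to negative integer powers are not allowed
    #   torch.pow(torch.tensor(-2.0), 0.5)   # nan: negative real base with a fractional exponent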
16928    BinaryUfuncInfo('float_power',
16929                    ref=np.float_power,
16930                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
16931                    promotes_int_to_float=True,
16932                    # https://github.com/pytorch/pytorch/issues/80411
16933                    gradcheck_fast_mode=True,
16934                    supports_forward_ad=True,
16935                    supports_fwgrad_bwgrad=True,
16936                    supports_one_python_scalar=True,
16937                    # Integer types do not support negative exponents
16938                    rhs_make_tensor_kwargs=dict(low=0),
16939                    # Raising negative real numbers to fractional powers is not supported
16940                    lhs_make_tensor_kwargs=dict(low=0),
16941                    decorators=(
16942                        DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
16943                                                        torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
16944                                     'TestBinaryUfuncs', 'test_scalar_support'),
16945                    ),
16946                    skips=(
16947                        # FIXME
16948                        # AssertionError: Object comparison failed: torch.float64 != torch.float32
16949                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
16950                        # -3.43399e+38 is outside the range of representable values of type 'float'
16951                        DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
16952                        # Complex values error with: Greatest absolute difference: nan at index
16953                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
16954                                     dtypes=[torch.complex64, torch.complex128]),
16955                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values',
16956                                     dtypes=[torch.complex64, torch.complex128]),
16957                        DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
16958                                     dtypes=[torch.complex64, torch.complex128]),
16959                        # Inplace always promotes to double and thus other floating dtypes are not supported
16960                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
16961                                     dtypes=[torch.bfloat16, torch.float16, torch.float32]),
16962                    )),
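    # Illustrative sketch (assumed behavior) of the inplace-promotion xfail above:
    #   t = torch.ones(2, dtype=torch.float32)
    #   t.float_power_(2)   # would error: float_power promotes to double, which cannot be
    #                       # written back into a float32 tensor in place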
16963    OpInfo('qr',
16964           op=torch.qr,
16965           dtypes=floating_and_complex_types(),
16966           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
16967           supports_forward_ad=True,
16968           supports_fwgrad_bwgrad=True,
16969           # In-place ops
16970           check_batched_gradgrad=False,
16971           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]),
16972    UnaryUfuncInfo('rad2deg',
16973                   ref=np.degrees,
16974                   decorators=(precisionOverride({torch.bfloat16: 7e-1,
16975                                                  torch.float16: 7e-1}),),
16976                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
16977                   supports_forward_ad=True,
16978                   supports_fwgrad_bwgrad=True,
16979                   supports_sparse=True,
16980                   supports_sparse_csr=True,
16981                   supports_sparse_csc=True,
16982                   supports_sparse_bsr=True,
16983                   supports_sparse_bsc=True,
16984                   promotes_int_to_float=True),
16985    UnaryUfuncInfo('real',
16986                   ref=np.real,
16987                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
16988                   supports_out=False,
16989                   supports_forward_ad=True,
16990                   supports_fwgrad_bwgrad=True,
16991                   # See https://github.com/pytorch/pytorch/issues/66357
16992                   check_batched_forward_grad=False,
16993                   skips=(
16994                       # Skip since real and imag don't have out variants.
16995                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
16996                   )),
16997    OpInfo(
16998        "roll",
16999        ref=np.roll,
17000        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
17001        error_inputs_func=error_inputs_roll,
17002        supports_out=False,
17003        supports_forward_ad=True,
17004        supports_fwgrad_bwgrad=True,
17005        sample_inputs_func=sample_inputs_roll,
17006        decorators=(onlyNativeDeviceTypes,),
17007    ),
17008    OpInfo(
17009        "rot90",
17010        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
17011        error_inputs_func=error_inputs_rot90,
17012        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
17013        gradcheck_fast_mode=True,
17014        supports_out=False,
17015        supports_forward_ad=True,
17016        supports_fwgrad_bwgrad=True,
17017        sample_inputs_func=sample_inputs_rot90,
17018    ),
17019    # To test reference numerics against multiple values of the `decimals` argument,
17020    # we create multiple OpInfo entries, each corresponding to a different value of `decimals`.
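    # For example (illustrative only), the `decimals_3` variant below effectively compares
    #   torch.round(torch.tensor([1.23456]), decimals=3)   # ~= tensor([1.2350])
    # against np.round with the same `decimals` forwarded through `op_kwargs`/`sample_kwargs`.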
17021    UnaryUfuncInfo('round',
17022                   ref=np.round,
17023                   aliases=('special.round',),
17024                   dtypes=all_types_and(torch.half, torch.bfloat16),
17025                   supports_forward_ad=True,
17026                   supports_fwgrad_bwgrad=True,
17027                   skips=(
17028                       DecorateInfo(unittest.expectedFailure,
17029                                    'TestNNCOpInfo',
17030                                    'test_nnc_correctness',
17031                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
17032                       DecorateInfo(unittest.skip("Skipped!"),
17033                                    'TestNNCOpInfo',
17034                                    'test_nnc_correctness',
17035                                    dtypes=(torch.bfloat16,)),
17036                   ),
17037                   supports_sparse=True,
17038                   supports_sparse_csr=True,
17039                   supports_sparse_csc=True,
17040                   supports_sparse_bsr=True,
17041                   supports_sparse_bsc=True,
17042                   assert_autodiffed=True,
17043                   ),
17044    UnaryUfuncInfo('round',
17045                   ref=np.round,
17046                   variant_test_name='decimals_0',
17047                   aliases=('special.round',),
17048                   dtypes=floating_types_and(torch.half, torch.bfloat16),
17049                   sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}),
17050                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}),
17051                   supports_forward_ad=True,
17052                   supports_fwgrad_bwgrad=True,
17053                   assert_autodiffed=False,
17054                   supports_sparse_csr=False),
17055    UnaryUfuncInfo('round',
17056                   ref=np.round,
17057                   variant_test_name='decimals_3',
17058                   aliases=('special.round',),
17059                   dtypes=floating_types_and(torch.bfloat16),
17060                   dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
17061                   sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}),
17062                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}),
17063                   skips=(
17064                       # test_ops already covers this overload via the `decimals_0` OpInfo entry
17065                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
17066                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
17067                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
17068                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
17069                       DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
17070                       DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
17071                                    "TestUnaryUfuncs", "test_reference_numerics_extremal",
17072                                    device_type="cuda"),
17073                       DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
17074                                    "TestUnaryUfuncs", "test_reference_numerics_normal",
17075                                    device_type="cuda"),
17076                   ),
17077                   supports_forward_ad=True,
17078                   supports_fwgrad_bwgrad=True,
17079                   assert_autodiffed=False,
17080                   supports_sparse_csr=False),
17081    UnaryUfuncInfo('round',
17082                   ref=np.round,
17083                   variant_test_name='decimals_neg_3',
17084                   aliases=('special.round',),
17085                   dtypes=floating_types_and(torch.bfloat16),
17086                   dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
17087                   sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}),
17088                   sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}),
17089                   skips=(
17090                       # test_ops already covers this overload via the `decimals_0` OpInfo entry
17091                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
17092                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
17093                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
17094                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
17095                       DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'),
17096                   ),
17097                   supports_forward_ad=True,
17098                   supports_fwgrad_bwgrad=True,
17099                   assert_autodiffed=False,
17100                   supports_sparse_csr=False),
17101    UnaryUfuncInfo('sin',
17102                   ref=np.sin,
17103                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17104                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17105                   assert_autodiffed=True,
17106                   handles_large_floats=False,
17107                   supports_sparse=True,
17108                   supports_sparse_csr=True,
17109                   supports_sparse_csc=True,
17110                   supports_sparse_bsr=True,
17111                   supports_sparse_bsc=True,
17112                   supports_forward_ad=True,
17113                   supports_fwgrad_bwgrad=True,
17114                   promotes_int_to_float=True,
17115                   skips=(
17116                       # Fails on CUDA but passes on ROCm
17117                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17118                                    dtypes=(torch.cdouble,), device_type='cuda'),
17119                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17120                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
17121                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17122                                    dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS),
17123                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17124                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17125                   ),
17126                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
17127    UnaryUfuncInfo('sinc',
17128                   ref=np_sinc_with_fp16_as_fp32,
17129                   aliases=('special.sinc',),
17130                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17131                   handles_large_floats=False,
17132                   supports_forward_ad=True,
17133                   supports_fwgrad_bwgrad=True,
17134                   promotes_int_to_float=True),
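    # Note (illustrative): torch.sinc is the normalized sinc, sin(pi*x)/(pi*x) with sinc(0) == 1;
    # as the helper name suggests, np_sinc_with_fp16_as_fp32 evaluates fp16 inputs in fp32 so the
    # reference comparison is not limited by half-precision accuracy (an assumption about intent).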
17135    UnaryUfuncInfo('sinh',
17136                   ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh),
17137                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17138                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17139                   assert_autodiffed=True,
17140                   supports_forward_ad=True,
17141                   supports_fwgrad_bwgrad=True,
17142                   supports_sparse=True,
17143                   supports_sparse_csr=True,
17144                   supports_sparse_csc=True,
17145                   supports_sparse_bsr=True,
17146                   supports_sparse_bsc=True,
17147                   promotes_int_to_float=True,
17148                   decorators=(precisionOverride({torch.float16: 1e-2}),),
17149                   skips=(
17150                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17151                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17152                                    active_if=(IS_MACOS or IS_WINDOWS)),
17153                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17154                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17155                                    active_if=(IS_MACOS or IS_WINDOWS)),
17156                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17157                                    dtypes=(torch.cdouble,)),
17158                       # Reference: https://github.com/pytorch/pytorch/issues/48641
17159                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17160                                    device_type='cpu', dtypes=[torch.int8]),
17161                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17162                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17163                   )),
17164    UnaryUfuncInfo('sign',
17165                   ref=reference_sign,
17166                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
17167                   dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half),
17168                   supports_forward_ad=True,
17169                   supports_fwgrad_bwgrad=True,
17170                   supports_sparse=True,
17171                   supports_sparse_csr=True,
17172                   supports_sparse_csc=True,
17173                   supports_sparse_bsr=True,
17174                   supports_sparse_bsc=True,
17175                   skips=(
17176                       # Reference: https://github.com/pytorch/pytorch/issues/41245
17177                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17178                                    dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
17179                   )),
17180    UnaryUfuncInfo('sgn',
17181                   ref=reference_sgn,
17182                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
17183                   backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
17184                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf),
17185                   supports_forward_ad=True,
17186                   supports_fwgrad_bwgrad=True,
17187                   supports_sparse=True,
17188                   supports_sparse_csr=True,
17189                   supports_sparse_csc=True,
17190                   supports_sparse_bsr=True,
17191                   supports_sparse_bsc=True,
17192                   skips=(
17193                       # Reference: https://github.com/pytorch/pytorch/issues/41245
17194                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17195                                    dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]),
17196                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17197                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17198                   )),
17199    OpInfo('split',
17200           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17201           sample_inputs_func=partial(sample_inputs_split, list_args=False),
17202           supports_forward_ad=True,
17203           supports_fwgrad_bwgrad=True,
17204           supports_out=False,
17205           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
17206           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
17207           assert_autodiffed=True),
17208    OpInfo('split',
17209           # Cannot declare this aten_name because of
17210           # test_variant_consistency_jit_split_list_args_cpu_float32
17211           decomp_aten_name='split_with_sizes',
17212           variant_test_name='list_args',
17213           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
17214           sample_inputs_func=partial(sample_inputs_split, list_args=True),
17215           supports_forward_ad=True,
17216           supports_fwgrad_bwgrad=True,
17217           supports_out=False),
17218    # `unsafe_split` supports only `int` for split_size argument
17219    OpInfo('unsafe_split',
17220           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17221           sample_inputs_func=partial(sample_inputs_split, list_args=False),
17222           supports_forward_ad=True,
17223           supports_fwgrad_bwgrad=True,
17224           supports_out=False,
17225           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
17226           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
17227           assert_autodiffed=True,
17228           check_batched_forward_grad=False),
17229    OpInfo('split_with_sizes',
17230           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17231           sample_inputs_func=sample_inputs_split_with_sizes,
17232           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
17233           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
17234           supports_out=False,
17235           supports_forward_ad=True,
17236           supports_fwgrad_bwgrad=True,
17237           assert_autodiffed=True),
17238    OpInfo('split_with_sizes_copy',
17239           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17240           sample_inputs_func=sample_inputs_split_with_sizes,
17241           supports_out=True,
17242           supports_forward_ad=True,
17243           supports_fwgrad_bwgrad=True,
17244           skips=(
17245               # No error raised
17246               DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"),
17247           )),
17248    BinaryUfuncInfo('__radd__',
17249                    op=torch.Tensor.__radd__,
17250                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
17251                    supports_out=False,
17252                    skips=(
17253                        DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
17254                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17255
17256                    ),
17257                    assert_autodiffed=True,
17258                    supports_forward_ad=True,
17259                    supports_fwgrad_bwgrad=True,
17260                    autodiff_nonfusible_nodes=['aten::add'],),
17261    BinaryUfuncInfo('__rdiv__',
17262                    op=torch.Tensor.__rdiv__,
17263                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
17264                    promotes_int_to_float=True,
17265                    lhs_make_tensor_kwargs={'exclude_zero': True},
17266                    # Runs very slowly on slow gradcheck - alternatively reduce input sizes
17267                    gradcheck_fast_mode=True,
17268                    supports_out=False,
17269                    skips=(
17270                        # https://github.com/pytorch/pytorch/issues/76806
17271                        DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
17272                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17273                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17274                    ),
17275                    supports_forward_ad=True,
17276                    supports_fwgrad_bwgrad=True,
17277                    assert_autodiffed=True,
17278                    autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],),
17279    BinaryUfuncInfo('__rmul__',
17280                    op=torch.Tensor.__rmul__,
17281                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool),
17282                    supports_out=False,
17283                    skips=(
17284                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17285                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17286                    ),
17287                    assert_autodiffed=True,
17288                    supports_forward_ad=True,
17289                    supports_fwgrad_bwgrad=True,
17290                    autodiff_nonfusible_nodes=['aten::mul'],),
17291    BinaryUfuncInfo('__rand__',
17292                    op=torch.Tensor.__rand__,
17293                    dtypes=integral_types_and(torch.bool),
17294                    supports_out=False,
17295                    supports_autograd=False,
17296                    supports_forward_ad=True,
17297                    skips=(
17298                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17299                    )),
17300    BinaryUfuncInfo('__ror__',
17301                    op=torch.Tensor.__ror__,
17302                    dtypes=integral_types_and(torch.bool),
17303                    supports_out=False,
17304                    supports_autograd=False,
17305                    supports_forward_ad=True,
17306                    skips=(
17307                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17308                    )),
17309    BinaryUfuncInfo('__rxor__',
17310                    op=torch.Tensor.__rxor__,
17311                    dtypes=integral_types_and(torch.bool),
17312                    supports_out=False,
17313                    supports_autograd=False,
17314                    supports_forward_ad=True,
17315                    skips=(
17316                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17317                    )),
17318    OpInfo('__rmatmul__',
17319           op=torch.Tensor.__rmatmul__,
17320           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
17321           dtypesIfCUDA=floating_and_complex_types_and(torch.float16,
17322                                                       *[torch.bfloat16]
17323                                                       if SM53OrLater or TEST_WITH_ROCM else []),
17324           assert_autodiffed=True,
17325           sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True),
17326           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
17327           gradcheck_fast_mode=True,
17328           supports_out=False,
17329           supports_forward_ad=True,
17330           supports_fwgrad_bwgrad=True,
17331           check_batched_forward_grad=False,
17332           decorators=(
17333               # NVIDIA only guarantees that bfloat16 is supported by bmm if SM >= 5.3
17334               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater),
17335               DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}),
17336                            'TestMathBits', 'test_conj_view'),
17337               DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}),
17338                            'TestCommon', 'test_noncontiguous_samples'),
17339               DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}),
17340                            "TestDecomp", "test_comprehensive", device_type="cuda",
17341                            active_if=TEST_WITH_ROCM),
17342           ),
17343           skips=(
17344               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17345               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17346               # https://github.com/pytorch/pytorch/issues/67470
17347               DecorateInfo(unittest.skip("67470!"),
17348                            'TestCommon', 'test_noncontiguous_samples',
17349                            device_type='cpu', dtypes=(torch.long,)),
17350               # Fails on XLA.
17351               # AssertionError: False is not true : Tensors failed to compare as equal
17352               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
17353               # https://github.com/pytorch/pytorch/issues/71774
17354               DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness',
17355                            device_type='cpu', dtypes=(torch.long,)),
17356           )),
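    # Note (illustrative): `t.__rmatmul__(m)` evaluates `m @ t`, so the samples above are built
    # with is_rmatmul=True, presumably to account for the reflected operand order.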
17357    BinaryUfuncInfo('__rmod__',
17358                    op=torch.Tensor.__rmod__,
17359                    dtypes=floating_types_and(torch.bfloat16, torch.half,),
17360                    dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half),
17361                    # https://github.com/pytorch/pytorch/issues/80411
17362                    gradcheck_fast_mode=True,
17363                    supports_out=False,
17364                    supports_forward_ad=True,
17365                    supports_fwgrad_bwgrad=True,
17366                    supports_one_python_scalar=True,
17367                    skips=(
17368                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17369                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17370                    ),
17371                    # Support autograd after torch.remainder(Tensor, Tensor) supports
17372                    # autograd of the second argument.
17373                    # https://github.com/pytorch/pytorch/pull/58476/files#r637167630
17374                    # supports_autograd=False,
17375                    assert_autodiffed=True,
17376                    autodiff_nonfusible_nodes=['aten::remainder'],),
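    # Note (illustrative): `t.__rmod__(s)` evaluates `s % t`, i.e. torch.remainder(s, t), which
    # is why the comment above concerns autograd for remainder's second argument.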
17377    BinaryUfuncInfo('__rpow__',
17378                    op=torch.Tensor.__rpow__,
17379                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
17380                    # Reference: https://github.com/pytorch/pytorch/issues/54774
17381                    # "log2" "_vml_cpu" not implemented for Half
17382                    backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
17383                    supports_out=False,
17384                    supports_forward_ad=True,
17385                    supports_fwgrad_bwgrad=True,
17386                    supports_one_python_scalar=True,
17387                    skips=(
17388                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17389                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17390                        # TODO: FIXME tolerance is too high
17391                        DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'),
17392                        DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'),
17393                    ),
17394                    assert_autodiffed=True,
17395                    autodiff_nonfusible_nodes=['aten::pow'],),
17396    BinaryUfuncInfo('__rsub__',
17397                    op=torch.Tensor.__rsub__,
17398                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
17399                    supports_forward_ad=True,
17400                    supports_fwgrad_bwgrad=True,
17401                    supports_out=False,
17402                    supports_one_python_scalar=True,
17403                    skips=(
17404                        DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17405                        DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',),
17406                    ),
17407                    assert_autodiffed=True,
17408                    autodiff_nonfusible_nodes=['aten::rsub'],),
17409    BinaryUfuncInfo('rsub',
17410                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.half),
17411                    supports_forward_ad=True,
17412                    supports_fwgrad_bwgrad=True,
17413                    supports_out=False,
17414                    supports_inplace_autograd=False,
17415                    assert_autodiffed=None,
17416                    sample_inputs_func=sample_inputs_add_sub),
17417    OpInfo('select',
17418           aten_backward_name='select_backward',
17419           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17420           sample_inputs_func=sample_inputs_select,
17421           assert_jit_shape_analysis=True,
17422           supports_forward_ad=True,
17423           supports_fwgrad_bwgrad=True,
17424           supports_out=False),
17425    OpInfo('select_scatter',
17426           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
17427           sample_inputs_func=sample_inputs_select_scatter,
17428           supports_forward_ad=True,
17429           supports_fwgrad_bwgrad=True,
17430           supports_out=False),
17431    OpInfo('slice',
17432           op=torch.ops.aten.slice.Tensor,
17433           dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf),
17434           sample_inputs_func=sample_inputs_slice,
17435           gradcheck_fast_mode=True,
17436           supports_forward_ad=True,
17437           supports_fwgrad_bwgrad=True,
17438           supports_scripting=False,
17439           supports_inplace_autograd=False,
17440           supports_out=False),
17441    OpInfo('slice_scatter',
17442           dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool),
17443           sample_inputs_func=sample_inputs_slice_scatter,
17444           # https://github.com/pytorch/pytorch/issues/80411
17445           gradcheck_fast_mode=True,
17446           supports_forward_ad=True,
17447           supports_fwgrad_bwgrad=True,
17448           supports_out=True),
17449    UnaryUfuncInfo('signbit',
17450                   ref=np.signbit,
17451                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half),
17452                   supports_sparse=True,
17453                   supports_sparse_csr=True,
17454                   supports_sparse_csc=True,
17455                   supports_sparse_bsr=True,
17456                   supports_sparse_bsc=True,
17457                   supports_autograd=False,),
17458    UnaryUfuncInfo('tan',
17459                   ref=np.tan,
17460                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17461                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17462                   decorators=(DecorateInfo(
17463                               toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}),
17464                               'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17465                               device_type='cuda'),),
17466                   assert_autodiffed=True,
17467                   supports_forward_ad=True,
17468                   supports_fwgrad_bwgrad=True,
17469                   supports_sparse=True,
17470                   supports_sparse_csr=True,
17471                   supports_sparse_csc=True,
17472                   supports_sparse_bsr=True,
17473                   supports_sparse_bsc=True,
17474                   promotes_int_to_float=True,
17475                   skips=(
17476                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17477                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17478                                    active_if=(IS_MACOS or IS_WINDOWS)),
17479                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17480                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17481                                    active_if=(IS_MACOS or IS_WINDOWS)),
17482                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17483                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17484                       # FIXME:
17485                       # Mismatched elements: 2 / 400 (0.5%)
17486                       # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed)
17487                       # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed)
17488                       DecorateInfo(
17489                           unittest.skip("Skipped!"),
17490                           "TestInductorOpInfo",
17491                           "test_comprehensive",
17492                           dtypes=(torch.float16,),
17493                           device_type="cuda",
17494                       ),
17495                   ),
17496                   # tan(pi/2 * odd_number) is nan
17497                   reference_numerics_filter=NumericsFilter(
17498                       condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)),
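    # Note (illustrative): the filter above swaps inputs near integer multiples of pi/2 (the odd
    # ones being tan's poles) for the safe value pi, so reference-numerics checks avoid comparing
    # huge/ill-conditioned values there.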
17499    UnaryUfuncInfo('tanh',
17500                   ref=np.tanh,
17501                   aten_backward_name='tanh_backward',
17502                   aliases=('nn.functional.tanh',),
17503                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),
17504                               DecorateInfo(
17505                                   toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}),
17506                                   'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17507                                   device_type='cuda'),),
17508                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17509                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17510                   assert_autodiffed=True,
17511                   assert_jit_shape_analysis=True,
17512                   supports_forward_ad=True,
17513                   supports_fwgrad_bwgrad=True,
17514                   supports_sparse=True,
17515                   supports_sparse_csr=True,
17516                   supports_sparse_csc=True,
17517                   supports_sparse_bsr=True,
17518                   supports_sparse_bsc=True,
17519                   promotes_int_to_float=True,
17520                   skips=(
17521                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17522                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17523                                    active_if=(IS_MACOS or IS_WINDOWS)),
17524                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17525                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
17526                                    active_if=(IS_MACOS or IS_WINDOWS)),
17527                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17528                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17529                   ),
17530                   # tan(j * pi/2 * odd_number) is nan
17531                   reference_numerics_filter=NumericsFilter(
17532                       condition=lambda x: (close_to_int(x / (math.pi * 0.5j))
17533                                            if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
17534                       safe_val=0)),
17535    OpInfo('tensor_split',
17536           ref=np.array_split,
17537           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
17538           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
17539           supports_out=False,
17540           supports_forward_ad=True,
17541           supports_fwgrad_bwgrad=True,
17542           skips=(
17543               # Pre-existing condition; needs to be fixed
17544               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
17545               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
17546               DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
17547           ),
17548           sample_inputs_func=sample_inputs_tensor_split,),
17549    OpInfo('hsplit',
17550           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
17551           supports_out=False,
17552           supports_forward_ad=True,
17553           supports_fwgrad_bwgrad=True,
17554           # See https://github.com/pytorch/pytorch/pull/78358
17555           check_batched_forward_grad=False,
17556           sample_inputs_func=sample_inputs_hsplit,
17557           error_inputs_func=error_inputs_hsplit,),
17558    OpInfo('vsplit',
17559           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
17560           supports_out=False,
17561           supports_forward_ad=True,
17562           supports_fwgrad_bwgrad=True,
17563           # See https://github.com/pytorch/pytorch/pull/78358
17564           check_batched_forward_grad=False,
17565           sample_inputs_func=sample_inputs_vsplit,
17566           error_inputs_func=error_inputs_vsplit,),
17567    OpInfo('dsplit',
17568           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16),
17569           supports_out=False,
17570           supports_forward_ad=True,
17571           supports_fwgrad_bwgrad=True,
17572           # See https://github.com/pytorch/pytorch/pull/78358
17573           check_batched_forward_grad=False,
17574           sample_inputs_func=sample_inputs_dsplit,
17575           error_inputs_func=error_inputs_dsplit,),
17576    OpInfo('triangular_solve',
17577           op=torch.triangular_solve,
17578           dtypes=floating_and_complex_types(),
17579           sample_inputs_func=sample_inputs_legacy_solve,
17580           check_batched_gradgrad=False,
17581           supports_forward_ad=True,
17582           supports_fwgrad_bwgrad=True,
17583           gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
17584           decorators=[
17585               skipCUDAIfNoMagma,
17586               skipCPUIfNoLapack,
17587               DecorateInfo(
17588                   toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}),
17589                   'TestConsistency', 'test_output_match', device_type='cpu',
17590               ),
17591           ],
17592           skips=(
17593               # AssertionError: Scalars are not equal!
17594               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
17595               # Gradcheck fails
17596               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
17597                            dtypes=floating_and_complex_types()),
17598               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
17599                            device_type='mps', dtypes=[torch.float32]),
17600               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
17601                            device_type='mps', dtypes=[torch.float32]),
17602               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
17603                            device_type='mps', dtypes=[torch.float32]),
17604           )),
17605    UnaryUfuncInfo('trunc',
17606                   aliases=('fix', ),
17607                   ref=np.trunc,
17608                   dtypes=all_types_and(torch.half, torch.bfloat16),
17609                   supports_forward_ad=True,
17610                   supports_fwgrad_bwgrad=True,
17611                   supports_sparse=True,
17612                   skips=(
17613                       DecorateInfo(unittest.expectedFailure,
17614                                    'TestNNCOpInfo',
17615                                    'test_nnc_correctness',
17616                                    dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
17617                   ),
17618                   supports_sparse_csr=True,
17619                   supports_sparse_csc=True,
17620                   supports_sparse_bsr=True,
17621                   supports_sparse_bsc=True,
17622                   assert_autodiffed=True),
17623    UnaryUfuncInfo('exp2',
17624                   aliases=('special.exp2', ),
17625                   ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
17626                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17627                   supports_forward_ad=True,
17628                   supports_fwgrad_bwgrad=True,
17629                   promotes_int_to_float=True,
17630                   skips=(
17631                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17632                                    dtypes=[torch.cdouble]),
17633                       # Reference: https://github.com/pytorch/pytorch/issues/48010
17634                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17635                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
17636                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17637                                    device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
17638                   )),
17639    UnaryUfuncInfo('expm1',
17640                   aliases=('special.expm1', ),
17641                   ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
17642                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17643                   supports_forward_ad=True,
17644                   supports_fwgrad_bwgrad=True,
17645                   supports_sparse=True,
17646                   supports_sparse_csr=True,
17647                   supports_sparse_csc=True,
17648                   supports_sparse_bsr=True,
17649                   supports_sparse_bsc=True,
17650                   promotes_int_to_float=True,
17651                   assert_autodiffed=True,
17652                   skips=(
17653                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17654                                    device_type='cuda', dtypes=[torch.complex128]),
17655                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17656                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17657                   )),
17658    UnaryUfuncInfo('nan_to_num',
17659                   ref=np.nan_to_num,
17660                   dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16),
17661                   dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16),
17662                   supports_forward_ad=True,
17663                   supports_fwgrad_bwgrad=True,
17664                   supports_sparse=True,
17665                   skips=(
17666                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17667                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17668                   ),
17669                   # numpy_kwargs are passed via sample_kwargs because numpy doesn't support
17670                   # BFloat16, so it runs the reference comparison in float instead.
17671                   # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556
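                   # (sample_kwargs returns a pair of kwarg dicts: the first goes to the torch op,
                   # the second to the numpy reference)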
17672                   sample_kwargs=lambda device, dtype, input: ({},
17673                                                               {'posinf': torch.finfo(torch.bfloat16).max,
17674                                                                'neginf': torch.finfo(torch.bfloat16).min})
17675                   if dtype is torch.bfloat16 else ({}, {})),
17676    UnaryUfuncInfo('reciprocal',
17677                   ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal),
17678                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17679                   assert_autodiffed=True,
17680                   supports_forward_ad=True,
17681                   supports_fwgrad_bwgrad=True,
17682                   promotes_int_to_float=True,
17683                   skips=(
17684                       # Reference: https://github.com/pytorch/pytorch/issues/45690
17685                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17686                                    dtypes=[torch.cfloat, torch.cdouble]),
17687                   )),
17688    UnaryUfuncInfo('rsqrt',
17689                   ref=lambda x: np.reciprocal(np.sqrt(x)),
17690                   domain=(0, None),
17691                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17692                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17693                   decorators=(precisionOverride({torch.half: 5e-2}),),
17694                   assert_autodiffed=True,
17695                   supports_forward_ad=True,
17696                   supports_fwgrad_bwgrad=True,
17697                   promotes_int_to_float=True,
17698                   skips=(
17699                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17700                                    dtypes=(torch.cfloat, torch.cdouble)),
17701                       # AssertionError: Tensor-likes are not close!
17702                       # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
17703                       # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
17704                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large',
17705                                    dtypes=(torch.chalf,)),
17706                   )),
17707    UnaryUfuncInfo('sqrt',
17708                   ref=np.sqrt,
17709                   supports_sparse=True,
17710                   domain=(0, None),
17711                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
17712                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
17713                   assert_autodiffed=True,
17714                   supports_forward_ad=True,
17715                   supports_sparse_csr=True,
17716                   supports_sparse_csc=True,
17717                   supports_sparse_bsr=True,
17718                   supports_sparse_bsc=True,
17719                   supports_fwgrad_bwgrad=True,
17720                   promotes_int_to_float=True,
17721                   decorators=(
17722                       precisionOverride({torch.bfloat16: 7e-2}),
17723                       DecorateInfo(
17724                           toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
17725                           'TestUnaryUfuncs', 'test_reference_numerics_large'),
17726                   ),
17727                   skips=(
17728                       # Reference: https://github.com/pytorch/pytorch/issues/47358
17729                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17730                                    device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
17731                                    active_if=IS_MACOS),
17732                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
17733                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
17734                   )),
17735    UnaryUfuncInfo('square',
17736                   ref=np.square,
17737                   dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
17738                   decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),),
17739                   supports_forward_ad=True,
17740                   supports_fwgrad_bwgrad=True,
17741                   skips=(
17742                       # Reference: https://github.com/pytorch/pytorch/issues/52549
17743                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
17744                                    dtypes=[torch.cfloat, torch.cdouble]),
17745                       # >>> t = torch.tensor(complex(-0.01, float("inf")))
17746                       # >>> np.square(t.numpy())
17747                       # (-inf-infj)
17748                       # >>> t.square()
17749                       # tensor(-inf-infj)
17750                       # >>> t.cuda().square()
17751                       # tensor(inf+nanj, device='cuda:0')
17752                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
17753                                    device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
17754                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace',
17755                                    dtypes=[torch.bool]),
17756                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace',
17757                                    dtypes=[torch.bool]),
17758                       DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace',
17759                                    dtypes=[torch.bool]),
17760                   ),),
17761    OpInfo('lerp',
17762           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
17763           dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16),
17764           sample_inputs_func=sample_inputs_lerp,
17765           supports_forward_ad=True,
17766           supports_fwgrad_bwgrad=True,
17767           assert_autodiffed=True),
17768    UnaryUfuncInfo('angle',
17769                   ref=np.angle,
17770                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
17771                   dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool),
17772                   decorators=(precisionOverride({torch.float16: 1e-2,
17773                                                  torch.bfloat16: 1e-2}),),
17774                   backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
17775                   backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf),
17776                   supports_forward_ad=True,
17777                   supports_fwgrad_bwgrad=True,
17778                   supports_sparse_csr=True,
17779                   supports_sparse_csc=True,
17780                   supports_sparse_bsr=True,
17781                   supports_sparse_bsc=True,
17782                   supports_complex_to_float=True,
17783                   skips=(
17784                       # Ref: https://github.com/pytorch/pytorch/issues/78413
17785                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small',
17786                                    dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),),
17787                   )),
17788    UnaryUfuncInfo('isfinite',
17789                   ref=np.isfinite,
17790                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
17791                   supports_out=False,
17792                   supports_autograd=False),
17793    UnaryUfuncInfo('isinf',
17794                   ref=np.isinf,
17795                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
17796                   supports_out=False,
17797                   supports_sparse=True,
17798                   supports_sparse_csr=True,
17799                   supports_sparse_csc=True,
17800                   supports_sparse_bsr=True,
17801                   supports_sparse_bsc=True,
17802                   supports_autograd=False),
17803    UnaryUfuncInfo('isposinf',
17804                   ref=np.isposinf,
17805                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
17806                   supports_sparse=True,
17807                   supports_sparse_csr=True,
17808                   supports_sparse_csc=True,
17809                   supports_sparse_bsr=True,
17810                   supports_sparse_bsc=True,
17811                   supports_autograd=False),
17812    UnaryUfuncInfo('isneginf',
17813                   ref=np.isneginf,
17814                   dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
17815                   supports_sparse=True,
17816                   supports_sparse_csr=True,
17817                   supports_sparse_csc=True,
17818                   supports_sparse_bsr=True,
17819                   supports_sparse_bsc=True,
17820                   supports_autograd=False),
17821    UnaryUfuncInfo('isreal',
17822                   ref=np.isreal,
17823                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
17824                   supports_out=False,
17825                   supports_autograd=False),
17826    UnaryUfuncInfo('isnan',
17827                   ref=np.isnan,
17828                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
17829                   supports_out=False,
17830                   supports_sparse=True,
17831                   supports_sparse_csr=True,
17832                   supports_sparse_csc=True,
17833                   supports_sparse_bsr=True,
17834                   supports_sparse_bsc=True,
17835                   supports_autograd=False),
17836    OpInfo('einsum',
17837           # we need this lambda because SampleInput expects the tensor input as the first argument
17838           # TODO(@heitorschueroff) update SampleInput to handle such cases
17839           op=lambda tensors, equation: torch.einsum(equation, tensors),
17840           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
17841           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
17842           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
17843           supports_out=False,
17844           supports_forward_ad=True,
17845           supports_fwgrad_bwgrad=True,
17846           check_batched_forward_grad=False,
17847           # See https://github.com/pytorch/pytorch/issues/66357
17848           sample_inputs_func=sample_inputs_einsum,
17849           skips=(
17850               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
17851               # the test does not work when a lambda is passed as the op;
17852               # there's a test `test_einsum` in `test_jit.py` that handles this case
17853               # AssertionError: JIT Test does not execute any logic
17854               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
17855           )),
17856    OpInfo('svd',
17857           op=torch.svd,
17858           dtypes=floating_and_complex_types(),
17859           sample_inputs_func=sample_inputs_svd,
17860           # Runs very slowly on slow-gradcheck - alternatively reduce input sizes
17861           gradcheck_fast_mode=True,
17862           supports_forward_ad=True,
17863           supports_fwgrad_bwgrad=True,
17864           check_batched_forward_grad=False,
17865           # We're using at::allclose, which does not have a batching rule
17866           check_batched_grad=False,
17867           check_batched_gradgrad=False,
17868           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
17869           skips=(
17870               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
17871               DecorateInfo(
17872                   unittest.skip("Skipped!"),
17873                   'TestSchemaCheckModeOpInfo',
17874                   'test_schema_correctness',
17875                   dtypes=(torch.complex64, torch.complex128)),
17876               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
17877                            device_type='mps', dtypes=[torch.float32]),
17878               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
17879                            device_type='mps', dtypes=[torch.float32]),
17880               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
17881                            device_type='mps', dtypes=[torch.float32]),
17882           )),
17883    OpInfo('svd_lowrank',
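           # wrapper_set_seed fixes the RNG seed so the randomized algorithm is reproducible;
           # the inner lambda builds a low-rank input as a @ b.mT before calling torch.svd_lowrank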
17884           op=lambda *args, **kwargs: wrapper_set_seed(
17885               lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs),
17886               *args, **kwargs
17887           ),
17888           dtypes=floating_and_complex_types(),
17889           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
17890           gradcheck_fast_mode=True,
17891           supports_out=False,
17892           # Due to the use of randomness
17893           check_batched_grad=False,
17894           check_batched_gradgrad=False,
17895           check_batched_forward_grad=False,
17896           supports_fwgrad_bwgrad=True,
17897           supports_forward_ad=True,
17898           sample_inputs_func=sample_inputs_svd_lowrank,
17899           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
17900                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03),
17901                                                       torch.complex64: tol(atol=1e-02, rtol=1e-02)}),
17902                                    'TestCommon', 'test_noncontiguous_samples'),
17903                       # FIXME This should be the following, but the toleranceOverride does not seem to do anything!
17904                       # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}),
17905                       #              'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
17906                       DecorateInfo(unittest.skip("See comment above"),
17907                                    'TestFwdGradients',
17908                                    'test_fn_fwgrad_bwgrad',
17909                                    dtypes=[torch.complex128]),
17910                       ],
17911           skips=(
17912               # the test does not work when a lambda is passed as the op
17913               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
17914               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
17915               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
17916               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
17917               DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
17918                            dtypes=(torch.complex64, torch.complex128)),
17919               DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'),
17920           )),
17921    OpInfo('pca_lowrank',
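           # same pattern as svd_lowrank above: fix the seed and build the low-rank input as a @ b.mT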
17922           op=lambda *args, **kwargs: wrapper_set_seed(
17923               lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs),
17924               *args, **kwargs
17925           ),
17926           dtypes=floating_and_complex_types(),
17927           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
17928           gradcheck_fast_mode=True,
17929           supports_out=False,
17930           check_batched_forward_grad=False,
17931           check_batched_grad=False,
17932           check_batched_gradgrad=False,
17933           supports_forward_ad=True,
17934           supports_fwgrad_bwgrad=True,
17935           sample_inputs_func=sample_inputs_pca_lowrank,
17936           decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off,
17937                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03),
17938                                                       torch.complex64: tol(atol=4e-02, rtol=4e-02)}),
17939                                    'TestCommon', 'test_noncontiguous_samples'),
17940                       DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}),
17941                                    'TestOperators', 'test_grad'),
17942                       # FIXME This should be the following, but the toleranceOverride does not seem to do anything!
17943                       # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}),
17944                       #              'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
17945                       DecorateInfo(unittest.skip("See comment above"),
17946                                    'TestFwdGradients',
17947                                    'test_fn_fwgrad_bwgrad',
17948                                    dtypes=[torch.complex128]),
17949                       DecorateInfo(
17950                           toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}),
17951                           'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'),
17952                       ],
17953           skips=(
17954               # the test does not work when a lambda is passed as the op
17955               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
17956               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
17957               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
17958               DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
17959                            dtypes=(torch.complex64, torch.complex128)),
17960               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
17961           )),
17962    BinaryUfuncInfo('polar',
17963                    dtypes=floating_types(),
17964                    # polar is undefined when the 'abs' input contains values < 0
17965                    supports_forward_ad=True,
17966                    lhs_make_tensor_kwargs=dict(low=0),
17967                    supports_rhs_python_scalar=False,
17968                    skips=(
17969                        # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument
17970                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'),
17971                        DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
17972                        # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0
17973                        # Numerical:
17974                        #  tensor([[0.]], dtype=torch.float64)
17975                        # Analytical:
17976                        # tensor([[-0.0047]], dtype=torch.float64, grad_fn=<CopySlices>)
17977                        DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
17978                    )),
17979    # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries.
17980    # To test reference numerics against multiple values of argument `n`,
17981    # we make multiple OpInfo entries, each corresponding to a different value of n (currently 0 to 4).
17982    # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
17983    UnaryUfuncInfo('polygamma',
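                   # torch.polygamma takes n before the tensor, so this lambda reorders the
                   # arguments and lets OpInfo pass the tensor as the first input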
17984                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
17985                   variant_test_name='polygamma_n_0',
17986                   ref=reference_polygamma if TEST_SCIPY else None,
17987                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
17988                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
17989                   supports_forward_ad=True,
17990                   supports_fwgrad_bwgrad=True,
17991                   promotes_int_to_float=True,
17992                   sample_inputs_func=sample_inputs_polygamma,
17993                   skips=(
17994                       DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
17995                   ),
17996                   sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}),
17997                   # polygamma functions have singularities wherever x is a non-positive integer
17998                   reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
17999                                                            safe_val=1)),
18000    *(UnaryUfuncInfo('polygamma',
18001                     op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
18002                     variant_test_name=f'polygamma_n_{n_}',
18003                     ref=reference_polygamma if TEST_SCIPY else None,
18004                     dtypes=all_types_and(torch.bool, torch.bfloat16),
18005                     dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
18006                     supports_forward_ad=True,
18007                     supports_fwgrad_bwgrad=True,
18008                     promotes_int_to_float=True,
18009                     sample_inputs_func=sample_inputs_polygamma,
18010                     decorators=(
18011                         DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'),
18012                         DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1),
18013                                                         torch.float32: tol(atol=1e-4, rtol=1e-2)}),
18014                                      'TestUnaryUfuncs', 'test_reference_numerics_normal',
18015                                      active_if=IS_WINDOWS),
18016                     ),
18017                     skips=(
18018                         # Redundant tests
18019                         DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
18020                         DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
18021                         DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
18022                         DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
18023                         DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
18024                         # Mismatch: https://github.com/pytorch/pytorch/issues/55357
18025                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
18026                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
18027                     ),
18028                     sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}),
18029                     # polygamma functions have singularities wherever x is a non-positive integer
18030                     reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
18031                                                              safe_val=1))
18032      for n_ in (1, 2, 3, 4)),
18033    OpInfo('ravel',
18034           ref=np.ravel,
18035           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18036           supports_out=False,
18037           supports_forward_ad=True,
18038           supports_fwgrad_bwgrad=True,
18039           # See https://github.com/pytorch/pytorch/pull/78358
18040           check_batched_forward_grad=False,
18041           sample_inputs_func=sample_inputs_ravel,
18042           ),
18043    OpInfo('unravel_index',
18044           ref=np.unravel_index,
18045           dtypes=integral_types_and(),
18046           supports_out=False,
18047           supports_autograd=False,
18048           sample_inputs_func=sample_inputs_unravel_index,
18049           ),
18050    OpInfo('reshape',
18051           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18052           sample_inputs_func=sample_inputs_view_reshape,
18053           reference_inputs_func=reference_inputs_view_reshape,
18054           error_inputs_func=error_inputs_view_reshape,
18055           supports_out=False,
18056           supports_forward_ad=True,
18057           supports_fwgrad_bwgrad=True,
18058           ),
18059    OpInfo('reshape_as',
18060           op=lambda x, other: x.reshape_as(other),
18061           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18062           sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
18063           reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
18064           error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
18065           supports_out=False,
18066           supports_forward_ad=True,
18067           supports_fwgrad_bwgrad=True,
18068           skips=(
18069               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18070           )),
18071    OpInfo('view',
18072           op=lambda x, shape: x.view(shape),
18073           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
18074           supports_out=False,
18075           supports_forward_ad=True,
18076           supports_fwgrad_bwgrad=True,
18077           assert_jit_shape_analysis=True,
18078           sample_inputs_func=sample_inputs_view_reshape,
18079           reference_inputs_func=reference_inputs_view_reshape,
18080           error_inputs_func=error_inputs_view_reshape,
18081           skips=(
18082               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18083               # RuntimeError: view size is not compatible with input tensor's size and stride
18084               # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
18085               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
18086           )),
18087    OpInfo('view_as',
18088           op=lambda x, other: x.view_as(other),
18089           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
18090           supports_out=False,
18091           supports_forward_ad=True,
18092           supports_fwgrad_bwgrad=True,
18093           sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True),
18094           reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True),
18095           error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True),
18096           skips=(
18097               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18098               # RuntimeError: view size is not compatible with input tensor's size and stride
18099               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides")
18100           )),
18101    OpInfo('atleast_1d',
18102           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
18103           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
18104           gradcheck_fast_mode=True,
18105           supports_out=False,
18106           supports_forward_ad=True,
18107           supports_fwgrad_bwgrad=True,
18108           # See https://github.com/pytorch/pytorch/pull/78358
18109           check_batched_forward_grad=False,
18110           sample_inputs_func=sample_inputs_atleast1d2d3d,
18111           skips=(
18112               # JIT does not support variadic tensors.
18113               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
18114               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
18115               # please report a bug to PyTorch.
18116               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18117               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
18118           ),
18119           ),
18120    OpInfo('atleast_2d',
18121           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
18122           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
18123           gradcheck_fast_mode=True,
18124           supports_out=False,
18125           supports_forward_ad=True,
18126           supports_fwgrad_bwgrad=True,
18127           # See https://github.com/pytorch/pytorch/pull/78358
18128           check_batched_forward_grad=False,
18129           skips=(
18130               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18131               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
18132           ),
18133           sample_inputs_func=sample_inputs_atleast1d2d3d,
18134           ),
18135    OpInfo('atleast_3d',
18136           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
18137           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
18138           gradcheck_fast_mode=True,
18139           supports_out=False,
18140           supports_forward_ad=True,
18141           supports_fwgrad_bwgrad=True,
18142           # See https://github.com/pytorch/pytorch/pull/78358
18143           check_batched_forward_grad=False,
18144           skips=(
18145               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18146               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
18147           ),
18148           sample_inputs_func=sample_inputs_atleast1d2d3d,
18149           ),
18150    OpInfo('flatten',
18151           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18152           ref=reference_flatten,
18153           supports_out=False,
18154           supports_forward_ad=True,
18155           supports_fwgrad_bwgrad=True,
18156           # See https://github.com/pytorch/pytorch/pull/78358
18157           check_batched_forward_grad=False,
18158           sample_inputs_func=sample_inputs_flatten,
18159           reference_inputs_func=reference_inputs_flatten,
18160           ),
18161    OpInfo('unflatten',
18162           op=torch.unflatten,
18163           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18164           supports_out=False,
18165           supports_forward_ad=True,
18166           supports_fwgrad_bwgrad=True,
18167           sample_inputs_func=sample_inputs_unflatten,
18168           ),
18169    OpInfo('column_stack',
18170           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18171           supports_forward_ad=True,
18172           supports_fwgrad_bwgrad=True,
18173           # See https://github.com/pytorch/pytorch/pull/78358
18174           check_batched_forward_grad=False,
18175           sample_inputs_func=sample_inputs_column_stack,),
18176    OpInfo('pinverse',
18177           op=torch.pinverse,
18178           dtypes=floating_and_complex_types(),
18179           check_batched_grad=False,
18180           check_batched_gradgrad=False,
18181           supports_forward_ad=True,
18182           supports_fwgrad_bwgrad=True,
18183           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
18184           supports_out=False,
18185           sample_inputs_func=sample_inputs_linalg_invertible,
18186           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
18187           skips=(
18188               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
18189                            device_type='mps', dtypes=[torch.float32]),
18190               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
18191                            device_type='mps', dtypes=[torch.float32]),
18192           )),
18193    OpInfo('gather',
18194           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
18195           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
18196           sample_inputs_func=sample_inputs_gather,
18197           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
18198           supports_forward_ad=True,
18199           supports_fwgrad_bwgrad=True,
18200           error_inputs_func=error_inputs_gather,
18201           ),
18202    OpInfo('index_fill',
18203           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
18204           supports_out=False,
18205           supports_forward_ad=True,
18206           supports_fwgrad_bwgrad=True,
18207           # https://github.com/pytorch/pytorch/issues/66357
18208           check_batched_forward_grad=False,
18209           skips=(
18210               # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal!
18211               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
18212               # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal!
18213               DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'),
18214           ),
18215           sample_inputs_func=sample_inputs_index,
18216           reference_inputs_func=partial(sample_inputs_index, reference=True)),
18217    OpInfo('index_copy',
18218           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
18219           supports_out=True,
18220           supports_forward_ad=True,
18221           supports_fwgrad_bwgrad=True,
18222           # https://github.com/pytorch/pytorch/issues/66357
18223           check_batched_forward_grad=False,
18224           sample_inputs_func=sample_inputs_index,
18225           reference_inputs_func=partial(sample_inputs_index, reference=True),
18226           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
18227    OpInfo('index_select',
18228           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18229           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
18230           sample_inputs_func=sample_inputs_index,
18231           reference_inputs_func=partial(sample_inputs_index, reference=True),
18232           error_inputs_func=error_inputs_index_select,
18233           supports_forward_ad=True,
18234           supports_fwgrad_bwgrad=True,
18235           assert_jit_shape_analysis=True,
18236           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
18237    OpInfo('index_add',
18238           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18239           supports_out=True,
18240           supports_forward_ad=True,
18241           supports_fwgrad_bwgrad=True,
18242           # https://github.com/pytorch/pytorch/issues/66357
18243           check_batched_forward_grad=False,
18244           sample_inputs_func=sample_inputs_index,
18245           reference_inputs_func=partial(sample_inputs_index, reference=True),
18246           error_inputs_func=error_inputs_index_add,
18247           skips=(
18248               # boolean alpha not handled properly
18249               DecorateInfo(unittest.expectedFailure,
18250                            'TestNNCOpInfo',
18251                            'test_nnc_correctness',
18252                            dtypes=(torch.bool,)),
18253           ),
18254           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
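    # index_reduce gets one OpInfo entry per reduction mode, distinguished by variant_test_name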
18255    *(OpInfo('index_reduce',
18256             variant_test_name=reduction_type,
18257             dtypes=all_types_and(torch.float16, torch.bfloat16),
18258             skips=(
18259                 DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}),
18260                              'TestInductorOpInfo', 'test_comprehensive'),
18261             ),
18262             supports_out=True,
18263             sample_inputs_func=sample_inputs_index_reduce,
18264             ) for reduction_type in ('mean', 'prod', 'amin', 'amax')),
18265    OpInfo('_unsafe_masked_index',
18266           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
18267           supports_out=False,
18268           supports_inplace_autograd=False,
18269           supports_scripting=False,
18270           supports_forward_ad=True,
18271           supports_fwgrad_bwgrad=True,
18272           sample_inputs_func=sample_inputs__unsafe_masked_index,
18273           skips=(
18274               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18275               DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward',
18276                            dtypes=(torch.float64,), active_if=IS_WINDOWS),
18277           ),),
18278    OpInfo('_unsafe_masked_index_put_accumulate',
18279           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
18280           supports_out=False,
18281           supports_inplace_autograd=False,
18282           supports_scripting=False,
18283           supports_forward_ad=True,
18284           supports_fwgrad_bwgrad=True,
18285           decorators=(
18286               DecorateInfo(
18287                   toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}),
18288                   'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'
18289               ),
18290           ),
18291           sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate,
18292           skips=(
18293               DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward',
18294                            dtypes=(torch.float64,), active_if=IS_WINDOWS),
18295           ),),
18296    OpInfo('__getitem__',
18297           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18298           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
18299           gradcheck_fast_mode=True,
18300           supports_out=False,
18301           supports_forward_ad=True,
18302           supports_fwgrad_bwgrad=True,
18303           supports_inplace_autograd=False,
18304           supports_scripting=False,
18305           op=torch.Tensor.__getitem__,
18306           skips=(
18307               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18308               # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 104448
18309               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),),
18310           sample_inputs_func=sample_inputs_getitem),
18311    OpInfo('index_put',
18312           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
18313           supports_out=False,
18314           supports_inplace_autograd=True,
18315           supports_forward_ad=True,
18316           supports_fwgrad_bwgrad=True,
18317           # https://github.com/pytorch/pytorch/issues/66357
18318           check_batched_forward_grad=False,
18319           test_neg_view=False,
18320           sample_inputs_func=sample_inputs_index_put,
18321           skips=(
18322               DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64],
18323                            device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)),
18324           )),
18325    OpInfo('sort',
18326           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
18327           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
18328           sample_inputs_func=sample_inputs_sort,
18329           supports_forward_ad=True,
18330           supports_fwgrad_bwgrad=True,
18331           skips=(
18332           )),
18333    OpInfo('unique',
18334           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
18335           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.uint16, torch.uint32, torch.uint64),
18336           sample_inputs_func=sample_inputs_unique,
18337           supports_out=False,
18338           supports_autograd=False,
18339           skips=(
18340               # lambda impl
18341               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18342               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18343               DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'),
18344           )),
18345    OpInfo('unique_consecutive',
18346           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
18347           dtypesIfCUDA=all_types_and(torch.bool, torch.float16),
18348           sample_inputs_func=sample_inputs_unique_consecutive,
18349           supports_out=False,
18350           supports_autograd=False,
18351           skips=(
18352               # lambda impl
18353               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18354               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18355           )),
18356    OpInfo('put',
18357           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
18358           supports_out=False,
18359           supports_forward_ad=True,
18360           supports_fwgrad_bwgrad=True,
18361           check_batched_forward_grad=False,
18362           check_batched_gradgrad=False,  # vmap complains of the sizes
18363           sample_inputs_func=sample_inputs_put),
18364    OpInfo('take',
18365           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
18366           check_batched_grad=False,  # vmap complains of the sizes
18367           supports_forward_ad=True,
18368           supports_fwgrad_bwgrad=True,
18369           sample_inputs_func=sample_inputs_take,
18370           error_inputs_func=error_inputs_take),
18371    OpInfo('scatter',
18372           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18373           supports_forward_ad=True,
18374           supports_fwgrad_bwgrad=True,
18375           sample_inputs_func=sample_inputs_scatter,
18376           error_inputs_func=error_inputs_scatter_and_scatter_add),
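    # The entries below cover the Tensor dtype-conversion methods (.bfloat16(), .bool(), .byte(),
    # .char(), .double(), .float(), .half(), .int(), .long(), .short(), .cdouble(), .cfloat());
    # most of them disable autograd or skip the gradient suites because the autograd tests cannot
    # handle operators that change dtype.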
18377    UnaryUfuncInfo(
18378        'bfloat16',
18379        op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs),
18380        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18381        supports_out=False,
18382        sample_inputs_func=sample_inputs_conversion,
18383        skips=(
18384            # autograd tests don't handle operators that change dtype
18385            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
18386            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
18387            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18388            # RuntimeError: attribute lookup is not defined on builtin
18389            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18390            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18391        )),
18392    UnaryUfuncInfo(
18393        'bool',
18394        op=lambda x, *args, **kwargs: x.bool(*args, **kwargs),
18395        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18396        supports_out=False,
18397        sample_inputs_func=sample_inputs_conversion,
18398        supports_autograd=False,
18399        skips=(
18400            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18401            # RuntimeError: attribute lookup is not defined on builtin
18402            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18403        )),
18404    UnaryUfuncInfo(
18405        'byte',
18406        op=lambda x, *args, **kwargs: x.byte(*args, **kwargs),
18407        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18408        supports_out=False,
18409        sample_inputs_func=sample_inputs_byte,
18410        # The autograd test runner cannot handle functions that change dtype
18411        supports_autograd=False,
18412        skips=(
18413            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18414            # RuntimeError: attribute lookup is not defined on builtin
18415            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18416            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
18417        )),
18418    UnaryUfuncInfo(
18419        'char',
18420        op=lambda x, *args, **kwargs: x.char(*args, **kwargs),
18421        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18422        supports_out=False,
18423        sample_inputs_func=sample_inputs_conversion,
18424        # The autograd test runner cannot handle functions that change dtype
18425        supports_autograd=False,
18426        skips=(
18427            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18428            # RuntimeError: attribute lookup is not defined on builtin
18429            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18430            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
18431        )),
18432    UnaryUfuncInfo(
18433        'double',
18434        op=lambda x, *args, **kwargs: x.double(*args, **kwargs),
18435        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18436        supports_out=False,
18437        sample_inputs_func=sample_inputs_conversion,
18438        supports_forward_ad=True,
18439        supports_fwgrad_bwgrad=True,
18440        skips=(
18441            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18442            # RuntimeError: attribute lookup is not defined on builtin
18443            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18444        )),
18445    UnaryUfuncInfo(
18446        'float',
18447        op=lambda x, *args, **kwargs: x.float(*args, **kwargs),
18448        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18449        supports_out=False,
18450        sample_inputs_func=sample_inputs_conversion,
18451        skips=(
18452            # autograd tests don't handle operators that change dtype
18453            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
18454            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
18455            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18456            # RuntimeError: attribute lookup is not defined on builtin
18457            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18458        )),
18459    UnaryUfuncInfo(
18460        'half',
18461        op=lambda x, *args, **kwargs: x.half(*args, **kwargs),
18462        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18463        supports_out=False,
18464        sample_inputs_func=sample_inputs_conversion,
18465        supports_autograd=True,
18466        skips=(
18467            # autograd tests don't handle operators that change dtype
18468            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
18469            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
18470            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18471            # RuntimeError: attribute lookup is not defined on builtin
18472            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18473        )),
18474    UnaryUfuncInfo(
18475        'int',
18476        op=lambda x, *args, **kwargs: x.int(*args, **kwargs),
18477        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18478        supports_out=False,
18479        sample_inputs_func=sample_inputs_conversion,
18480        supports_autograd=False,
18481        skips=(
18482            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18483            # RuntimeError: attribute lookup is not defined on builtin
18484            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18485            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
18486        )),
18487    UnaryUfuncInfo(
18488        'long',
18489        op=lambda x, *args, **kwargs: x.long(*args, **kwargs),
18490        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18491        supports_out=False,
18492        sample_inputs_func=sample_inputs_conversion,
18493        supports_autograd=False,
18494        skips=(
18495            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18496            # RuntimeError: attribute lookup is not defined on builtin
18497            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18498            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
18499        )),
18500    UnaryUfuncInfo(
18501        'short',
18502        op=lambda x, *args, **kwargs: x.short(*args, **kwargs),
18503        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18504        supports_out=False,
18505        sample_inputs_func=sample_inputs_conversion,
18506        supports_autograd=False,
18507        skips=(
18508            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18509            # RuntimeError: attribute lookup is not defined on builtin
18510            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18511            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
18512        )),
18513    UnaryUfuncInfo(
18514        'cdouble',
18515        op=torch.Tensor.cdouble,
18516        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18517        supports_out=False,
18518        sample_inputs_func=sample_inputs_conversion,
18519        supports_forward_ad=True,
18520        supports_fwgrad_bwgrad=True,
18521        skips=(
18522            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18523            # RuntimeError: attribute lookup is not defined on builtin
18524            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18525            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18526        )),
18527    UnaryUfuncInfo(
18528        'cfloat',
18529        op=torch.Tensor.cfloat,
18530        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18531        supports_out=False,
18532        sample_inputs_func=sample_inputs_conversion,
18533        skips=(
18534            # autograd tests don't handle operators that change dtype
18535            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
18536            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
18537            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18538            # RuntimeError: attribute lookup is not defined on builtin
18539            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18540            DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18541        )),
18542    UnaryUfuncInfo(
18543        'chalf',
18544        op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs),
18545        dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18546        supports_out=False,
18547        sample_inputs_func=sample_inputs_conversion,
18548        skips=(
18549            # autograd tests don't handle operators that change dtype
18550            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'),
18551            DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'),
18552            # use of lambda doesn't work with test_normalize_operator_exhaustive
18553            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
18554            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
18555            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager',
18556                         device_type='cpu'),
18557            # TypeError: 'int' object is not iterable
18558            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18559            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
18560            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view',
18561                         device_type='cpu'),
18562            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
18563            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view',
18564                         device_type='cpu'),
18565            # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
18566            # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf'
18567            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18568        )
18569    ),
18570    OpInfo('empty_like',
18571           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18572           supports_out=False,
18573           sample_inputs_func=sample_inputs_like_fns,
18574           reference_inputs_func=reference_inputs_like_fns,
18575           supports_autograd=False,
18576           skips=(
18577               # Empty tensor data is garbage so it's hard to make comparisons with it.
18578               DecorateInfo(unittest.skip("Skipped!"),
18579                            "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18580               # Empty tensor data is garbage so it's hard to make comparisons with it.
18581               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
18582               # Empty tensor data is garbage so it's hard to make comparisons with it.
18583               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18584               # Empty tensor data is garbage so it's hard to make comparisons with it.
18585               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18586               # Empty tensor data is garbage so it's hard to make comparisons with it.
18587               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18588               # Empty tensor data is garbage so it's hard to make comparisons with it.
18589               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
18590               # Empty tensor data is garbage so it's hard to make comparisons with it.
18591               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18592               # Empty tensor data is garbage so it's hard to make comparisons with it.
18593               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
18594               # Empty tensor data is garbage so it's hard to make comparisons with it.
18595               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'),
18596               # Empty tensor data is garbage so it's hard to make comparisons with it.
18597               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
18598               DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance',
18599                            'test_operator'),
18600               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18601           )),
18602    OpInfo('zeros_like',
18603           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18604           supports_out=False,
18605           sample_inputs_func=sample_inputs_like_fns,
18606           supports_autograd=False,
18607           error_inputs_sparse_func=error_inputs_sparse_like_fns,
18608           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
18609           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
18610           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
18611           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
18612           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
18613           skips=(
18614           )),
18615    OpInfo('ones_like',
18616           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18617           supports_out=False,
18618           sample_inputs_func=sample_inputs_like_fns,
18619           supports_autograd=False,
18620           skips=(
18621           )),
18622    OpInfo('randn',
18623           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
18624           op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs),
18625           supports_out=True,
18626           sample_inputs_func=sample_inputs_randn,
18627           supports_autograd=False,
18628           skips=(
18629               # Tests that assume input is a tensor or sequence of tensors
18630               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
18631               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
18632               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
18633               # CPU randn generates different values based on the strides of out tensor
18634               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
18635               # randn fails to warn when resizing its out tensor
18636               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18637               # FX failed to normalize op - add the op to the op_skip list.
18638               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
18639               # Tests that assume input tensor has a meaningful effect on output tensor
18640               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
18641               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
18642               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
18643               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18644               # AssertionError: JIT Test does not execute any logic
18645               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18646               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
18647           )),
18648    OpInfo('randn_like',
18649           dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32),
18650           op=lambda inp, *args, **kwargs:
18651               wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
18652           supports_out=False,
18653           sample_inputs_func=sample_inputs_like_fns,
18654           supports_autograd=False,
18655           error_inputs_sparse_func=error_inputs_sparse_like_fns,
18656           sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo),
18657           sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr),
18658           sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc),
18659           sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr),
18660           sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc),
18661           skips=(
18662               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18663               # AssertionError: JIT Test does not execute any logic
18664               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18665               DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"),
18666                            'TestCommon', 'test_complex_half_reference_testing'),
18667               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18668           )),
18669    OpInfo('rand_like',
18670           dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128),
18671           op=lambda inp, *args, **kwargs:
18672               wrapper_set_seed(torch.rand_like, inp, *args, **kwargs),
18673           supports_out=False,
18674           sample_inputs_func=sample_inputs_like_fns,
18675           supports_autograd=False,
18676           skips=(
18677               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18678               # AssertionError: JIT Test does not execute any logic
18679               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18680               DecorateInfo(unittest.skip("Expected: rand_like is not comparable between dtypes"),
18681                            'TestCommon', 'test_complex_half_reference_testing'),
18682               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18683           )),
18684    OpInfo('randint',
18685           dtypes=all_types_and(torch.half, torch.bfloat16),
18686           op=lambda *args, **kwargs:
18687               wrapper_set_seed(torch.randint, *args, **kwargs),
18688           supports_out=False,
18689           sample_inputs_func=sample_inputs_randint,
18690           supports_autograd=False,
18691           skips=(
18692               # Tests that assume input is a tensor or sequence of tensors
18693               DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
18694               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
18695               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
18696               # CPU randint generates different values based on the strides of out tensor
18697               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
18698               # randint fails to warn when resizing its out tensor
18699               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18700               # FX failed to normalize op - add the op to the op_skip list.
18701               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
18702               # Tests that assume input tensor has a meaningful effect on output tensor
18703               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
18704               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
18705               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
18706               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18707               # AssertionError: JIT Test does not execute any logic
18708               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18709               # Might need to skip until ROCm 5.5
18710               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices',
18711                            dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM),
18712           )),
18713    OpInfo('randint_like',
18714           dtypes=all_types_and(torch.half, torch.bfloat16),
18715           op=lambda inp, *args, **kwargs:
18716               wrapper_set_seed(torch.randint_like, inp, *args, **kwargs),
18717           supports_out=False,
18718           sample_inputs_func=sample_inputs_randint_like,
18719           supports_autograd=False,
18720           skips=(
18721               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18722               # AssertionError: JIT Test does not execute any logic
18723               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18724               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18725           )),
18726    OpInfo('full_like',
18727           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18728           supports_out=False,
18729           sample_inputs_func=sample_inputs_full_like,
18730           supports_autograd=False,
18731           skips=(
18732           )),
18733    OpInfo('new_zeros',
18734           op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs),
18735           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18736           supports_out=False,
18737           sample_inputs_func=sample_inputs_new_fns,
18738           skips=(
18739               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18740           ),
18741           supports_autograd=False),
18742    OpInfo('new_ones',
18743           op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs),
18744           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18745           supports_out=False,
18746           sample_inputs_func=sample_inputs_new_fns,
18747           skips=(
18748               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18749           ),
18750           supports_autograd=False),
18751    OpInfo('ones',
18752           op=torch.ones,
18753           supports_autograd=False,
18754           supports_varargs=True,
18755           is_factory_function=True,
18756           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18757           supports_out=True,
18758           sample_inputs_func=sample_inputs_ones_zeros,
18759           skips=(
18760               # Tests that assume input is a tensor or sequence of tensors
18761               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
18762               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
18763               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
18764               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18765
18766               # Same failure as arange: cannot find linspace in captured graph
18767               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18768
18769               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
18770               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18771           )),
18772    OpInfo('zeros',
18773           op=torch.zeros,
18774           supports_autograd=False,
18775           is_factory_function=True,
18776           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18777           supports_out=True,
18778           sample_inputs_func=sample_inputs_ones_zeros,
18779           skips=(
18780               # Tests that assume input is a tensor or sequence of tensors
18781               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
18782               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
18783               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
18784               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18785
18786               # Same failure as arange: cannot find linspace in captured graph
18787               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18788
18789               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
18790               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18791           )),
18792    OpInfo('full',
18793           op=torch.full,
18794           supports_autograd=False,
18795           is_factory_function=True,
18796           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18797           supports_out=True,
18798           sample_inputs_func=sample_inputs_full,
18799           skips=(
18800               # Tests that assume input is a tensor or sequence of tensors
18801               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
18802               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
18803               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
18804               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
18805               # Same failure as arange: cannot find linspace in captured graph
18806               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18807               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
18808               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18809               # RuntimeError: UNSUPPORTED DTYPE: bool
18810               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)),
18811           )),
18812    OpInfo('new_empty',
18813           op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs),
18814           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18815           supports_out=False,
18816           sample_inputs_func=sample_inputs_new_fns,
18817           skips=(
18818               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18819               # Empty tensor data is garbage so it's hard to make comparisons with it.
18820               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18821               # Empty tensor data is garbage so it's hard to make comparisons with it.
18822               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
18823               # Empty tensor data is garbage so it's hard to make comparisons with it.
18824               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
18825               # Empty tensor data is garbage so it's hard to make comparisons with it.
18826               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18827               # Empty tensor data is garbage so it's hard to make comparisons with it.
18828               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18829               # Empty tensor data is garbage so it's hard to make comparisons with it.
18830               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
18831               # Empty tensor data is garbage so it's hard to make comparisons with it.
18832               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18833               # Empty tensor data is garbage so it's hard to make comparisons with it.
18834               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
18835               # Empty tensor data is garbage so it's hard to make comparisons with it.
18836               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
18837               DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance',
18838                            'test_operator'),
18839               DecorateInfo(unittest.skip("Expected: new_empty is not comparable"),
18840                            'TestCommon', 'test_complex_half_reference_testing'),
18841               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18842           ),
18843           supports_autograd=False),
18844    OpInfo('new_empty_strided',
18845           op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs),
18846           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18847           supports_out=False,
18848           sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True),
18849           supports_autograd=False,
18850           skips=(
18851               # FX failed to normalize op
18852               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18853               # Lazy tensor failures
18854               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'),
18855               DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
18856               # Empty tensor data is garbage so it's hard to make comparisons with it.
18857               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18858                            'TestCommon', 'test_variant_consistency_eager'),
18859               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18860                            'TestCommon', 'test_noncontiguous_samples'),
18861               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18862                            'TestMathBits', 'test_conj_view'),
18863               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18864                            'TestMathBits', 'test_neg_view'),
18865               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18866                            'TestMathBits', 'test_neg_conj_view'),
18867               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18868                            'TestCommon', 'test_non_standard_bool_values'),
18869               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18870                            'TestCommon', 'test_complex_half_reference_testing'),
18871               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18872                            'TestCompositeCompliance', 'test_operator'),
18873               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18874                            'TestDecomp', 'test_comprehensive'),
18875               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18876                            'TestDecomp', 'test_quick'),
18877               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18878                            'TestJit', 'test_variant_consistency_jit'),
18879               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18880                            'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'),
18881               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18882                            'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'),
18883               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18884                            'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
18885               DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"),
18886                            'TestNNCOpInfo', 'test_nnc_correctness'),
18887               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18888           )),
18889    OpInfo('empty_strided',
18890           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs),
18891           dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half),
18892           supports_out=False,
18893           supports_autograd=False,
18894           sample_inputs_func=sample_inputs_empty_strided,
18895           skips=(
18896               # FX failed to normalize op - add the op to the op_skip list.
18897               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
18898               # AssertionError: JIT Test does not execute any logic
18899               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18900               # Empty tensor data is garbage so it's hard to make comparisons with it.
18901               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
18902               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
18903               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
18904               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
18905               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18906               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18907               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
18908               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'),
18909               # Lazy tensor failures
18910               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'),
18911               # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single
18912               # memory location. Please clone() the tensor before performing the operation.
18913               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'),
18914               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
18915               DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
18916           )),
18917    OpInfo('empty',
18918           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18919           sample_inputs_func=sample_inputs_empty,
18920           supports_autograd=False,
18921           skips=(
18922               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18923               # Empty tensor data is garbage so it's hard to make comparisons with it.
18924               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18925               # Empty tensor data is garbage so it's hard to make comparisons with it.
18926               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
18927               # Empty tensor data is garbage so it's hard to make comparisons with it.
18928               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
18929               # Empty tensor data is garbage so it's hard to make comparisons with it.
18930               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18931               # Empty tensor data is garbage so it's hard to make comparisons with it.
18932               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18933               # Empty tensor data is garbage so it's hard to make comparisons with it.
18934               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
18935               # Empty tensor data is garbage so it's hard to make comparisons with it.
18936               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
18937               # Empty tensor data is garbage so it's hard to make comparisons with it.
18938               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
18939               # Empty tensor data is garbage so it's hard to make comparisons with it.
18940               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
18941               DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance',
18942                            'test_operator'),
18943               # requires_grad doesn't exist in the jit schema
18944               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
18945               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
18946                            'TestCommon',
18947                            'test_out'),
18948               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
18949                            'TestCommon',
18950                            'test_out_warning'),
18951               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
18952                            'TestLazyOpInfo'),
18953               DecorateInfo(unittest.skip("Expected: empty is not comparable"),
18954                            'TestCommon', 'test_complex_half_reference_testing'),
18955               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
18956           )),
18957    OpInfo('eye',
18958           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18959           sample_inputs_func=sample_inputs_eye,
18960           error_inputs_func=error_inputs_eye,
18961           supports_out=True,
18962           supports_autograd=False,
18963           skips=(
18964               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
18965               # TODO: same as this?
18966               # https://github.com/pytorch/pytorch/issues/81774
18967               # also see: arange, new_full
18968               # fails to match any schemas despite working in the interpreter
18969               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
18970               # fails to match any schemas despite working in the interpreter
18971               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
18972               # skip these tests since we have non tensor input
18973               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
18974               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
18975               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18976               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
18977               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18978               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
18979               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
18980           )),
18981    OpInfo('empty_permuted',
18982           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
18983           sample_inputs_func=sample_inputs_empty_permuted,
18984           error_inputs_func=error_inputs_empty_permuted,
18985           supports_out=False,
18986           supports_autograd=False,
18987           skips=(
18988               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
18989               # Empty tensor data is garbage so it's hard to make comparisons with it.
18990               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
18991               # Empty tensor data is garbage so it's hard to make comparisons with it.
18992               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
18993               # Empty tensor data is garbage so it's hard to make comparisons with it.
18994               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
18995               # Empty tensor data is garbage so it's hard to make comparisons with it.
18996               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
18997               # Empty tensor data is garbage so it's hard to make comparisons with it.
18998               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
18999               # Empty tensor data is garbage so it's hard to make comparisons with it.
19000               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
19001               # Empty tensor data is garbage so it's hard to make comparisons with it.
19002               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
19003               # Empty tensor data is garbage so it's hard to make comparisons with it.
19004               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
19005               # Empty tensor data is garbage so it's hard to make comparisons with it.
19006               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
19007               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance',
19008                            'test_operator'),
19009               # requires_grad doesn't exist in the jit schema
19010               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
19011               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
19012                            'TestCommon',
19013                            'test_out'),
19014               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
19015                            'TestCommon',
19016                            'test_out_warning'),
19017               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
19018                            'TestLazyOpInfo'),
19019               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
19020                            'TestCommon', 'test_complex_half_reference_testing'),
19021               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
19022           )),
19023    OpInfo('scalar_tensor',
19024           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
19025           sample_inputs_func=sample_inputs_scalar_tensor,
19026           supports_autograd=False,
19027           supports_out=False,
19028           skips=(
19029               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
19030               # fails to match any schemas despite working in the interpreter
19031               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
19032               # fails to match any schemas despite working in the interpreter
19033               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19034               # skip these tests since we have non tensor input
19035               DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"),
19036               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
19037               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
19038               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
19039               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
19040           )),
19041    OpInfo('new_full',
19042           op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs),
19043           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
19044           supports_out=False,
19045           sample_inputs_func=sample_inputs_new_full,
19046           skips=(
19047               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19048           ),
19049           supports_autograd=False),
19050    OpInfo('multinomial',
19051           op=lambda inp, *args, **kwargs:
19052               wrapper_set_seed(torch.multinomial, inp, *args, **kwargs),
19053           method_variant=lambda inp, *args, **kwargs:
19054               wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs),
19055           dtypes=floating_types_and(torch.bfloat16, torch.half),
19056           supports_out=True,
19057           sample_inputs_func=sample_inputs_multinomial,
19058           error_inputs_func=error_inputs_multinomial,
19059           skips=(
19060               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19061               # Strides are not the same!
19062               # This may not be reproducible in CI
19063               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
19064               # AssertionError: JIT Test does not execute any logic
19065               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19066               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
19067               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
19068               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
19069           supports_autograd=False),
19070    OpInfo('normal',
19071           op=lambda inp, *args, **kwargs:
19072               wrapper_set_seed(torch.normal, inp, *args, **kwargs),
19073           # The inplace variant (Tensor.normal_) is different from torch.normal
19074           inplace_variant=None,
19075           dtypes=floating_types_and(torch.bfloat16, torch.half),
19076           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
19077           supports_out=True,
19078           sample_inputs_func=sample_inputs_normal_tensor_first,
19079           skips=(
19080               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19081               # Tensor-likes are not close!
19082               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
19083               # AssertionError: JIT Test does not execute any logic
19084               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19085               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
19086               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
19087               # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
19088               # Computed gradient is incorrect -- would be an expectedFailure, but gradgrad somehow passes
19089               DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
19090               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
19091               # RuntimeError: Difference from {dtype} is larger with decomposition
19092               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
19093               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
19094               # The inplace variant (Tensor.normal_) is different from torch.normal
19095               # inplace variant Tensor.normal_ is decomposed using randn_like()
19096               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))),
19097    OpInfo('normal',
19098           # This has its own variant because OpInfos assume the first arg is a Tensor but it is not here; see the note after this entry
19099           variant_test_name='number_mean',
19100           op=lambda std, mean, *args, **kwargs:
19101               wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
19102           # The inplace variant (Tensor.normal_) is different from torch.normal
19103           inplace_variant=None,
19104           dtypes=floating_types_and(torch.bfloat16, torch.half),
19105           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
19106           supports_out=True,
19107           sample_inputs_func=sample_inputs_normal_tensor_second,
19108           skips=(
19109               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19110               # AssertionError: JIT Test does not execute any logic
19111               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19112               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
19113               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
19114               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
19115               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
19116               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
19117               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
19118               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
19119               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
19120               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'),
19121               DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'),
19122               DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'),
19123               # AssertionError
19124               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
19125               # AssertionError
19126               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
19127               # AssertionError in CUDA variant
19128               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'),
19129               DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))),
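    # Note [normal with a scalar mean]
    # For the 'number_mean' variant above, the sample inputs put the std tensor
    # first (the OpInfo machinery expects a Tensor as the first argument) and the
    # scalar mean second; the op lambda swaps them back into torch.normal(mean, std).
    # Illustrative sketch (real samples come from sample_inputs_normal_tensor_second):
    #
    #   std = make_tensor((S,), dtype=torch.float32, device='cpu', low=0.1, high=1.0)
    #   out = wrapper_set_seed(torch.normal, 2.0, std)  # scalar mean, tensor std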
19130    OpInfo('bernoulli',
19131           op=lambda inp, *args, **kwargs:
19132               wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
19133           # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
19134           inplace_variant=None,
19135           method_variant=lambda inp, *args, **kwargs:
19136               wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
19137           dtypes=floating_types_and(torch.bfloat16, torch.half),
19138           supports_out=True,
19139           supports_forward_ad=True,
19140           supports_fwgrad_bwgrad=True,
19141           sample_inputs_func=sample_inputs_bernoulli,
19142           error_inputs_func=error_inputs_bernoulli,
19143           skips=(
19144               # vmap: We do not yet support calling random operations inside of vmap
19145               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
19146               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19147               # AssertionError: JIT Test does not execute any logic
19148               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19149               # Expected RuntimeError when doing an unsafe cast from a result of
19150               # dtype torch.float32 into an out= with dtype torch.long
19151               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
19152               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
19153               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
19154               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
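    # Note [bernoulli inplace variant]
    # inplace_variant=None above because Tensor.bernoulli_ takes a probability
    # (scalar or tensor) and fills self in place, whereas torch.bernoulli samples
    # using the probabilities stored in its input tensor, so the two are not
    # interchangeable for the inplace consistency checks. Illustrative sketch:
    #
    #   p = torch.full((3,), 0.5)
    #   draws = torch.bernoulli(p)   # samples with per-element probability p
    #   p.bernoulli_(0.25)           # refills p with Bernoulli(0.25) samples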
19155    OpInfo('scatter_add',
19156           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
19157           sample_inputs_func=sample_inputs_scatter_add,
19158           error_inputs_func=error_inputs_scatter_and_scatter_add,
19159           supports_forward_ad=True,
19160           supports_fwgrad_bwgrad=True,
19161           ),
19162    OpInfo('stack',
19163           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19164           sample_inputs_func=sample_inputs_stack,
19165           assert_autodiffed=True,
19166           supports_forward_ad=True,
19167           supports_fwgrad_bwgrad=True,
19168           skips=(
19169               # https://github.com/pytorch/pytorch/issues/77046
19170               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
19171               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
19172           ),
19173           ),
19174    OpInfo('_chunk_cat',
19175           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19176           sample_inputs_func=sample_inputs_chunk_cat,
19177           error_inputs_func=error_inputs_chunk_cat,
19178           supports_autograd=False,
19179           supports_out=True,
19180           ),
19181    OpInfo('hstack',
19182           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19183           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
19184           error_inputs_func=error_inputs_hstack_dstack_vstack,
19185           supports_forward_ad=True,
19186           supports_fwgrad_bwgrad=True,
19187           ),
19188    BinaryUfuncInfo('hypot',
19189                    dtypes=floating_types_and(torch.bfloat16, torch.half),
19190                    dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
19191                    supports_forward_ad=True,
19192                    supports_fwgrad_bwgrad=True,
19193                    supports_rhs_python_scalar=False),
19194    OpInfo('histogram',
19195           dtypes=floating_types(),
19196           dtypesIfCUDA=_dispatch_dtypes(),  # histogram is only implemented on CPU
19197           sample_inputs_func=sample_inputs_histogram,
19198           supports_autograd=False,
19199           skips=(
19200               # JIT tests don't work with Tensor keyword arguments
19201               # https://github.com/pytorch/pytorch/issues/58507
19202               # RuntimeError:
19203               # undefined value tensor:
19204               #   File "<string>", line 3
19205               # def the_method(i0):
19206               #     return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False)
19207               #                                          ~~~~~~ <--- HERE
19208               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19209               # Not Implemented on XLA.
19210               DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'),
19211           )),
19212    OpInfo('histogramdd',
19213           dtypes=floating_types(),
19214           dtypesIfCUDA=_dispatch_dtypes(),  # histogramdd is only implemented on CPU
19215           sample_inputs_func=sample_inputs_histogramdd,
19216           error_inputs_func=error_inputs_histogramdd,
19217           supports_autograd=False,
19218           skips=(
19219               # Not implemented on CUDA
19220               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'),
19221               # JIT tests don't work with Tensor keyword arguments
19222               # https://github.com/pytorch/pytorch/issues/58507
19223               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19224           )),
19225    OpInfo('histc',
19226           dtypes=floating_types_and(torch.bfloat16, torch.float16),
19227           dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64),
19228           sample_inputs_func=sample_inputs_histc,
19229           supports_out=True,
19230           supports_autograd=False,
19231           skips=(
19232               # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor
19233               # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast
19234               # from a result of dtype torch.float32 into an out= with dtype torch.long"
19235               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'),
19236           )),
19237    OpInfo('bincount',
19238           dtypes=integral_types_and(),
19239           sample_inputs_func=sample_inputs_bincount,
19240           supports_out=False,
19241           supports_autograd=False,
19242           skips=(
19243               # JIT tests don't work with Tensor keyword arguments
19244               # https://github.com/pytorch/pytorch/issues/58507
19245               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
19246           )),
19247    OpInfo('bucketize',
19248           dtypes=all_types_and(torch.float16, torch.bfloat16),
19249           dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
19250           sample_inputs_func=sample_inputs_bucketize,
19251           reference_inputs_func=reference_inputs_bucketize,
19252           error_inputs_func=error_inputs_bucketize,
19253           supports_autograd=False,
19254           skips=(
19255               # JIT tests don't work with Tensor keyword arguments
19256               DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
19257           )),
19258    OpInfo('searchsorted',
19259           dtypes=all_types_and(torch.bfloat16, torch.float16),
19260           dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
19261           sample_inputs_func=sample_inputs_searchsorted,
19262           supports_autograd=False,
19263           ref=reference_searchsorted,
19264           skips=(
19265               # JIT tests don't work with Tensor keyword arguments
19266               # https://github.com/pytorch/pytorch/issues/58507
19267               DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'),
19268           )),
19269    OpInfo('cat',
19270           ref=_cat_np,
19271           aliases=('concat', 'concatenate'),
19272           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32),
19273           sample_inputs_func=sample_inputs_cat_concat,
19274           reference_inputs_func=reference_inputs_cat,
19275           error_inputs_func=error_inputs_cat,
19276           # https://github.com/pytorch/pytorch/issues/80411
19277           gradcheck_fast_mode=True,
19278           supports_forward_ad=True,
19279           supports_fwgrad_bwgrad=True,
19280           # See https://github.com/pytorch/pytorch/issues/66357
19281           check_batched_forward_grad=False,
19282           assert_autodiffed=True,
19283           skips=(
19284               # https://github.com/pytorch/pytorch/issues/89353
19285               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
19286               # RuntimeError: Arguments for call not valid.
19287               #               Expected a value of type 'List[Tensor]' for argument
19288               #               'tensors' but instead found type 'Tensor (inferred)'.
19289               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),
19290               # see https://github.com/pytorch/pytorch/issues/71286
19291               DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
19292               # see https://github.com/pytorch/pytorch/issues/99806
19293               # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0.
19294               DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'),
19295           )),
19296    OpInfo('unbind',
19297           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19298           ref=reference_unbind,
19299           sample_inputs_func=sample_inputs_unbind,
19300           error_inputs_func=error_inputs_unbind,
19301           supports_forward_ad=True,
19302           supports_fwgrad_bwgrad=True,
19303           supports_gradgrad=True,
19304           supports_out=False,
19305           ),
19306    OpInfo('vstack',
19307           aliases=('row_stack',),
19308           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19309           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
19310           error_inputs_func=error_inputs_hstack_dstack_vstack,
19311           supports_forward_ad=True,
19312           supports_fwgrad_bwgrad=True,
19313           skips=(
19314               # RuntimeError: _fn() Expected a value of type
19315               #   'Tensor (inferred)' for argument 't0' but instead found type 'tuple'.
19316               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)),
19317    OpInfo('dstack',
19318           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19319           sample_inputs_func=sample_inputs_hstack_dstack_vstack,
19320           error_inputs_func=error_inputs_hstack_dstack_vstack,
19321           supports_forward_ad=True,
19322           supports_fwgrad_bwgrad=True,
19323           # See https://github.com/pytorch/pytorch/pull/78358
19324           check_batched_forward_grad=False,
19325           ),
19326    OpInfo('unfold',
19327           op=lambda x, *args: x.unfold(*args),
19328           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19329           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19330           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
19331           gradcheck_fast_mode=True,
19332           supports_out=False,
19333           supports_forward_ad=True,
19334           supports_fwgrad_bwgrad=True,
19335           check_batched_gradgrad=False,
19336           # See https://github.com/pytorch/pytorch/issues/66357
19337           check_batched_forward_grad=False,
19338           skips=(
19339               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19340               # Skip operator schema test because this is a functional and not an operator
19341               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
19342           ),
19343           sample_inputs_func=sample_inputs_unfold),
19344    OpInfo('unfold_copy',
19345           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19346           backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19347           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
19348           gradcheck_fast_mode=True,
19349           supports_out=True,
19350           supports_forward_ad=True,
19351           supports_fwgrad_bwgrad=True,
19352           check_batched_gradgrad=False,
19353           # See https://github.com/pytorch/pytorch/issues/66357
19354           check_batched_forward_grad=False,
19355           sample_inputs_func=sample_inputs_unfold),
19356    OpInfo('msort',
19357           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
19358           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
19359           check_batched_gradgrad=False,
19360           supports_forward_ad=True,
19361           supports_fwgrad_bwgrad=True,
19362           sample_inputs_func=sample_inputs_msort,
19363           skips=(
19364           )),
19365    OpInfo('movedim',
19366           aliases=('moveaxis',),
19367           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19368           supports_out=False,
19369           supports_forward_ad=True,
19370           supports_fwgrad_bwgrad=True,
19371           # See https://github.com/pytorch/pytorch/pull/78358
19372           check_batched_forward_grad=False,
19373           sample_inputs_func=sample_movedim_moveaxis,
19374           reference_inputs_func=reference_movedim_moveaxis,
19375           error_inputs_func=error_movedim_moveaxis),
19376    OpInfo('renorm',
19377           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19378           sample_inputs_func=sample_inputs_renorm,
19379           error_inputs_func=error_inputs_renorm,
19380           supports_forward_ad=True,
19381           supports_fwgrad_bwgrad=True,
19382           skips=(
19383               # RuntimeError: Difference from float64 is larger with decomposition
19384               # linalg_vector_norm.default than original on output 0.
19385               # Original max diff: 2.560596747969157e-07,
19386               # Decomp max diff: 1.8187482915266173e-06
19387               DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive',
19388                            device_type='cpu', dtypes=(torch.float16,)),
19389           )),
19390    ShapeFuncInfo('repeat',
19391                  op=lambda x, dims: x.repeat(dims),
19392                  ref=np.tile,
19393                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19394                  # https://github.com/pytorch/pytorch/issues/80411
19395                  gradcheck_fast_mode=True,
19396                  supports_out=False,
19397                  supports_forward_ad=True,
19398                  supports_fwgrad_bwgrad=True,
19399                  sample_inputs_func=sample_repeat_tile,
19400                  skips=(
19401                      DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19402                  )),
19403    OpInfo('squeeze',
19404           ref=_squeeze_ref,
19405           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19406           supports_out=False,
19407           assert_autodiffed=True,
19408           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
19409           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
19410           assert_jit_shape_analysis=True,
19411           supports_forward_ad=True,
19412           supports_fwgrad_bwgrad=True,
19413           # vmap does not support inplace views
19414           check_inplace_batched_forward_grad=False,
19415           # https://github.com/pytorch/pytorch/issues/66357
19416           check_batched_forward_grad=False,
19417           sample_inputs_func=sample_inputs_squeeze),
19418    OpInfo('squeeze',
19419           ref=_squeeze_ref,
19420           variant_test_name="multiple",
19421           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19422           supports_out=False,
19423           assert_autodiffed=True,
19424           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
19425           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
19426           supports_forward_ad=True,
19427           supports_fwgrad_bwgrad=True,
19428           # vmap does not support inplace views
19429           check_inplace_batched_forward_grad=False,
19430           # https://github.com/pytorch/pytorch/issues/66357
19431           check_batched_forward_grad=False,
19432           sample_inputs_func=sample_inputs_squeeze_multiple),
19433    UnaryUfuncInfo(
19434        'fill',
19435        ref=_fill_np,
19436        method_variant=None,
19437        sample_kwargs=_fill_sample_kwargs,
19438        sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}),
19439        supports_forward_ad=True,
19440        supports_fwgrad_bwgrad=True,
19441        # https://github.com/pytorch/pytorch/issues/66357
19442        check_batched_forward_grad=False,
19443        dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
19444        supports_out=False,
19445        skips=(
19446            # JIT has an issue when the op is passed as a lambda
19447            # AssertionError: JIT Test does not execute any logic
19448            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19449            DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'),
19450            DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'),
19451        )),
19452    OpInfo('resize_',
19453           op=lambda x, shape: x.clone().resize_(shape),
19454           method_variant=None,
19455           inplace_variant=torch.Tensor.resize_,
19456           # test_neg_view is disabled because resize_ doesn't work with imag views as the test expects
19457           # https://github.com/pytorch/pytorch/issues/65945
19458           test_neg_view=False,
19459           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19460           supports_out=False,
19461           supports_autograd=False,
19462           skips=(
19463               # Cannot resize variables that require grad
19464               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
19465               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19466               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
19467           ),
19468           sample_inputs_func=sample_inputs_resize_ops),
19469    OpInfo('resize_as_',
19470           op=lambda x, other: torch.resize_as_(x.clone(), other),
19471           method_variant=None,
19472           inplace_variant=torch.Tensor.resize_as_,
19473           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19474           supports_out=False,
19475           supports_autograd=False,
19476           skips=(
19477               # Cannot resize variables that require grad
19478               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
19479               DecorateInfo(unittest.skip('Allowed exception'), 'TestCompositeCompliance', 'test_operator'),
19480           ),
19481           sample_inputs_func=sample_inputs_resize_ops),
19482    OpInfo('take_along_dim',
19483           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19484           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19485           supports_inplace_autograd=False,
19486           supports_forward_ad=True,
19487           supports_fwgrad_bwgrad=True,
19488           # See https://github.com/pytorch/pytorch/pull/78358
19489           check_batched_forward_grad=False,
19490           sample_inputs_func=sample_inputs_take_along_dim,
19491           gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
19492           decorators=(
19493               # RuntimeError: view size is not compatible with input tensor's size and stride
19494               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
19495           )),
19496    ShapeFuncInfo('tile',
19497                  ref=np.tile,
19498                  dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19499                  # https://github.com/pytorch/pytorch/issues/80411
19500                  gradcheck_fast_mode=True,
19501                  supports_out=False,
19502                  supports_forward_ad=True,
19503                  supports_fwgrad_bwgrad=True,
19504                  sample_inputs_func=sample_repeat_tile),
19505    OpInfo('trapz',  # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid'
19506           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
19507           supports_out=False,
19508           supports_forward_ad=True,
19509           supports_fwgrad_bwgrad=True,
19510           # See https://github.com/pytorch/pytorch/pull/78358
19511           check_batched_forward_grad=False,
19512           decorators=[
19513               DecorateInfo(
19514                   toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
19515                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
19516               ),
19517           ],
19518           sample_inputs_func=sample_trapezoid),
19519    OpInfo('trapezoid',
19520           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
19521           supports_out=False,
19522           supports_forward_ad=True,
19523           supports_fwgrad_bwgrad=True,
19524           # See https://github.com/pytorch/pytorch/pull/78358
19525           check_batched_forward_grad=False,
19526           decorators=[
19527               DecorateInfo(
19528                   toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
19529                   'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
19530               ),
19531           ],
19532           sample_inputs_func=sample_trapezoid),
19533    OpInfo('cumulative_trapezoid',
19534           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
19535           supports_forward_ad=True,
19536           supports_fwgrad_bwgrad=True,
19537           # See https://github.com/pytorch/pytorch/pull/78358
19538           check_batched_forward_grad=False,
19539           supports_out=False,
19540           decorators=(
19541               DecorateInfo(
19542                   toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}),
19543                   'TestInductorOpInfo', 'test_comprehensive',
19544               ),
19545           ),
19546           sample_inputs_func=sample_cumulative_trapezoid,),
19547    OpInfo('unsqueeze',
19548           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19549           supports_out=False,
19550           supports_forward_ad=True,
19551           supports_fwgrad_bwgrad=True,
19552           # See https://github.com/pytorch/pytorch/pull/78358
19553           check_batched_forward_grad=False,
19554           # vmap does not support inplace views
19555           check_inplace_batched_forward_grad=False,
19556           assert_jit_shape_analysis=True,
19557           assert_autodiffed=True,
19558           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
19559           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
19560           sample_inputs_func=sample_unsqueeze),
19561    OpInfo('unsqueeze_copy',
19562           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19563           supports_out=True,
19564           supports_forward_ad=True,
19565           supports_fwgrad_bwgrad=True,
19566           # See https://github.com/pytorch/pytorch/pull/78358
19567           check_batched_forward_grad=False,
19568           # vmap does not support inplace views
19569           check_inplace_batched_forward_grad=False,
19570           assert_jit_shape_analysis=True,
19571           assert_autodiffed=True,
19572           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
19573           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
19574           sample_inputs_func=sample_unsqueeze,
19575           skips=(
19576               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
19577               DecorateInfo(
19578                   unittest.expectedFailure,
19579                   'TestJit',
19580                   'test_variant_consistency_jit',
19581                   dtypes=(torch.float32,),
19582               ),
19583           )),
19584    BinaryUfuncInfo('xlogy',
19585                    aliases=('special.xlogy',),
19586                    dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19587                    promotes_int_to_float=True,
19588                    supports_forward_ad=True,
19589                    supports_fwgrad_bwgrad=True,
19590                    supports_one_python_scalar=True,
19591                    # We don't test 0 because the gradient (x / y) is not finite there and would break the gradient checks
19592                    rhs_make_tensor_kwargs=dict(low=0.01)),
19593    OpInfo('zero_',
19594           op=lambda x: torch.zero_(x.clone()),
19595           method_variant=None,
19596           inplace_variant=torch.Tensor.zero_,
19597           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19598           # https://github.com/pytorch/pytorch/issues/80411
19599           gradcheck_fast_mode=True,
19600           supports_out=False,
19601           supports_forward_ad=True,
19602           supports_fwgrad_bwgrad=True,
19603           supports_gradgrad=True,
19604           skips=(
19605               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
19606           ),
19607           sample_inputs_func=sample_inputs_zero_),
19608    OpInfo('logsumexp',
19609           aliases=('special.logsumexp',),
19610           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
19611           assert_autodiffed=True,
19612           supports_forward_ad=True,
19613           supports_fwgrad_bwgrad=True,
19614           gradcheck_fast_mode=False,
19615           sample_inputs_func=sample_inputs_logsumexp,
19616           reference_inputs_func=reference_inputs_logsumexp),
19617    OpInfo('trace',
19618           dtypes=all_types_and_complex(),
19619           dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
19620           error_inputs_func=error_inputs_trace,
19621           supports_inplace_autograd=False,
19622           supports_out=False,
19623           supports_forward_ad=True,
19624           supports_fwgrad_bwgrad=True,
19625           sample_inputs_func=sample_inputs_trace),
19626    OpInfo('transpose',
19627           ref=_numpy_ref_transpose,
19628           aliases=('swapdims', 'swapaxes'),
19629           assert_jit_shape_analysis=True,
19630           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
19631           supports_out=False,
19632           supports_forward_ad=True,
19633           supports_fwgrad_bwgrad=True,
19634           # vmap does not support inplace views
19635           check_inplace_batched_forward_grad=False,
19636           sample_inputs_func=sample_inputs_transpose_swapdims),
19637    OpInfo('T',
19638           op=lambda x: x.T,
19639           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
19640           supports_out=False,
19641           supports_forward_ad=True,
19642           supports_fwgrad_bwgrad=True,
19643           skips=(
19644               # lambda impl
19645               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
19646               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
19647           sample_inputs_func=sample_inputs_T,
19648           error_inputs_func=error_inputs_T),
19649    OpInfo('H',
19650           op=lambda x: x.H,
19651           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
19652           supports_out=False,
19653           supports_forward_ad=True,
19654           supports_fwgrad_bwgrad=True,
19655           # See https://github.com/pytorch/pytorch/pull/78358
19656           check_batched_forward_grad=False,
19657           skips=(
19658               # lambda impl
19659               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
19660               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
19661           sample_inputs_func=sample_inputs_T),
19662    OpInfo('mT',
19663           op=lambda x: x.mT,
19664           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
19665           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
19666           gradcheck_fast_mode=True,
19667           supports_out=False,
19668           supports_forward_ad=True,
19669           supports_fwgrad_bwgrad=True,
19670           skips=(
19671               # lambda impl
19672               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
19673               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
19674           sample_inputs_func=sample_inputs_adjoint),
19675    OpInfo('mH',
19676           op=lambda x: x.mH,
19677           aliases=('adjoint',),
19678           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf),
19679           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
19680           gradcheck_fast_mode=True,
19681           supports_out=False,
19682           supports_forward_ad=True,
19683           supports_fwgrad_bwgrad=True,
19684           # See https://github.com/pytorch/pytorch/pull/78358
19685           check_batched_forward_grad=False,
19686           skips=(
19687               # lambda impl
19688               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
19689               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),),
19690           sample_inputs_func=sample_inputs_adjoint),
19691    OpInfo('tril',
19692           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
19693           supports_forward_ad=True,
19694           supports_fwgrad_bwgrad=True,
19695           error_inputs_func=error_inputs_tril_triu,
19696           sample_inputs_func=sample_inputs_tril_triu),
19697    OpInfo('triu',
19698           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
19699           supports_forward_ad=True,
19700           supports_fwgrad_bwgrad=True,
19701           error_inputs_func=error_inputs_tril_triu,
19702           sample_inputs_func=sample_inputs_tril_triu),
19703    OpInfo('triu_indices',
19704           dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
19705           sample_inputs_func=sample_inputs_trilu_indices,
19706           ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype),
19707           supports_out=False,
19708           supports_autograd=False,
19709           skips=(
19710               # skip these tests since we have non-tensor inputs
19711               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
19712               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
19713               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
19714               DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
19715           )),
19716    OpInfo('tril_indices',
19717           dtypes=_dispatch_dtypes((torch.int32, torch.int64)),
19718           sample_inputs_func=sample_inputs_trilu_indices,
19719           ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype),
19720           supports_out=False,
19721           supports_autograd=False,
19722           skips=(
19723               # skip these tests since we have non-tensor inputs
19724               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
19725               DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
19726               DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
19727               DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
19728           )),
19729    OpInfo('kron',
19730           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
19731           dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
19732           # Runs very slowly on slow gradcheck - alternatively reduce input sizes
19733           gradcheck_fast_mode=True,
19734           supports_inplace_autograd=False,
19735           supports_forward_ad=True,
19736           supports_fwgrad_bwgrad=True,
19737           sample_inputs_func=sample_inputs_kron,
19738           decorators=(
19739               # RuntimeError: view size is not compatible with input tensor's size and stride
19740               DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
19741           )),
19742    OpInfo('inner',
19743           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
19744           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19745           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
19746           supports_forward_ad=True,
19747           supports_fwgrad_bwgrad=True,
19748           # See https://github.com/pytorch/pytorch/pull/78358
19749           check_batched_forward_grad=False,
19750           sample_inputs_func=sample_inputs_inner,
19751           ),
19752    OpInfo('tensordot',
19753           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
19754           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19755           dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
19756           supports_forward_ad=True,
19757           supports_fwgrad_bwgrad=True,
19758           # See https://github.com/pytorch/pytorch/pull/78358
19759           check_batched_forward_grad=False,
19760           sample_inputs_func=sample_inputs_tensordot,
19761           skips=(
19762               # Skip operator schema test because this is a functional and not an operator.
19763               # Reference: https://github.com/pytorch/pytorch/issues/54574
19764               DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
19765           )
19766           ),
19767    OpInfo('to_sparse',
19768           op=lambda x, *args: x.to_sparse(*args),
19769           sample_inputs_func=sample_inputs_to_sparse,
19770           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19771           backward_dtypes=floating_types(),
19772           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
19773           supports_out=False,
19774           supports_sparse_csr=True,
19775           supports_sparse_csc=True,
19776           check_batched_grad=False,
19777           check_batched_gradgrad=False,
19778           skips=(
19779               # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend
19780               DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'),
19781               # TODO: FIXME: complex inputs requiring grad error in forward
19782               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
19783               # lambda impl
19784               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
19785               # Allowed exception: sparse tensors don't have strides
19786               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'),
19787               DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'),
19788               DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'),
19789               # TODO: implement csr.to_sparse(sample_dim) where sample_dim is 1.
19790               DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"),
19791                            'TestSparseCSR', 'test_sparse_csr_consistency'),
19792               # Compiler issue on ROCm. Might need to skip until ROCm5.5
19793               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
19794                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
19795           )
19796           ),
19797    OpInfo('logcumsumexp',
19798           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
19799           backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
19800           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
19801           supports_forward_ad=True,
19802           supports_fwgrad_bwgrad=True,
19803           skips=(
19804               # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
19805               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'),
19806               # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble'
19807               # Falling back to a non-numerically-stabilized exp, causing NaNs in the results.
19808               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]),
19809               DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]),
19810               DecorateInfo(
19811                   toleranceOverride({
19812                       torch.float16: tol(atol=7e-5, rtol=6e-3),
19813                   }),
19814                   "TestInductorOpInfo",
19815                   "test_comprehensive",
19816                   device_type="cuda"
19817               ),
19818           ),
19819           sample_inputs_func=sample_inputs_logcumsumexp,
19820           error_inputs_func=error_inputs_logcumsumexp),
19821    UnaryUfuncInfo('sigmoid',
19822                   aliases=('special.expit', 'nn.functional.sigmoid'),
19823                   aten_backward_name='sigmoid_backward',
19824                   ref=reference_sigmoid if TEST_SCIPY else None,
19825                   decorators=(precisionOverride({torch.float16: 1e-2,
19826                                                  torch.complex64: 1e-1,
19827                                                  torch.bfloat16: 1e-2}),),
19828                   skips=(
19829                       # Reference: https://github.com/pytorch/pytorch/issues/56012
19830                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
19831                                    dtypes=[torch.complex64, torch.cdouble]),
19832                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
19833                                    dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
19834                   dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
19835                   dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
19836                   supports_forward_ad=True,
19837                   supports_fwgrad_bwgrad=True,
19838                   promotes_int_to_float=True,
19839                   assert_autodiffed=True,
19840                   # sigmoid(z) = 1 / (1 + exp(-z)); at z = j * pi * odd_number the denominator is zero
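                   # (for odd k, exp(-j * pi * k) = cos(pi * k) - j * sin(pi * k) = -1, so 1 + exp(-z) = 0)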
19841                   reference_numerics_filter=NumericsFilter(
19842                       condition=lambda x: (close_to_int(x / (math.pi * 1j))
19843                                            if x.is_complex() else x.new_tensor(False, dtype=torch.bool)),
19844                       safe_val=0)),
19845    UnaryUfuncInfo('digamma',
19846                   ref=scipy.special.digamma if TEST_SCIPY else None,
19847                   aliases=('special.psi', 'special.digamma',),
19848                   decorators=(precisionOverride({torch.float16: 5e-1}),),
19849                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19850                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
19851                   supports_forward_ad=True,
19852                   supports_fwgrad_bwgrad=True,
19853                   promotes_int_to_float=True),
19854    UnaryUfuncInfo('erf',
19855                   ref=scipy.special.erf if TEST_SCIPY else None,
19856                   aliases=('special.erf', ),
19857                   decorators=(precisionOverride({torch.float16: 1e-2,
19858                                                  torch.bfloat16: 1e-2}),),
19859                   skips=(
19860                       DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
19861                                    'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
19862
19863                   ),
19864                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19865                   assert_autodiffed=True,
19866                   assert_jit_shape_analysis=True,
19867                   supports_sparse=True,
19868                   supports_sparse_csr=True,
19869                   supports_sparse_csc=True,
19870                   supports_sparse_bsr=True,
19871                   supports_sparse_bsc=True,
19872                   supports_forward_ad=True,
19873                   supports_fwgrad_bwgrad=True,
19874                   promotes_int_to_float=True),
19875    UnaryUfuncInfo('erfc',
19876                   ref=scipy.special.erfc if TEST_SCIPY else None,
19877                   aliases=('special.erfc', ),
19878                   decorators=(precisionOverride({torch.float16: 1e-2,
19879                                                  torch.bfloat16: 1e-2}),),
19880                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19881                   assert_autodiffed=True,
19882                   supports_forward_ad=True,
19883                   supports_fwgrad_bwgrad=True,
19884                   promotes_int_to_float=True),
19885    UnaryUfuncInfo('erfinv',
19886                   ref=scipy.special.erfinv if TEST_SCIPY else None,
19887                   aliases=('special.erfinv', ),
19888                   decorators=(precisionOverride({torch.float16: 1e-2,
19889                                                  torch.bfloat16: 1e-2,
19890                                                  torch.float32: 1e-4}),),
19891                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19892                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
19893                   supports_sparse_csr=True,
19894                   supports_sparse_csc=True,
19895                   supports_sparse_bsr=True,
19896                   supports_sparse_bsc=True,
19897                   supports_forward_ad=True,
19898                   supports_fwgrad_bwgrad=True,
19899                   promotes_int_to_float=True,
19900                   domain=(-1, 1),
19901                   skips=(
19902                       # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
19903                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
19904                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
19905                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
19906                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
19907                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
19908                                    active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
19909                   )),
19910    OpInfo("nn.functional.smooth_l1_loss",
19911           ref=reference_smooth_l1_loss,
19912           sample_inputs_func=sample_inputs_smooth_l1_loss,
19913           dtypes=floating_types_and(torch.float16, torch.bfloat16),
19914           backward_dtypes=floating_types_and(torch.bfloat16),
19915           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
19916           backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
19917           supports_out=False,
19918           supports_forward_ad=True,
19919           supports_fwgrad_bwgrad=True,
19920           skips=(
19921               # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
19922               # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
19923               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)),
19924    OpInfo(
19925        "nn.functional.l1_loss",
19926        ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)),
19927        sample_inputs_func=sample_inputs_l1_loss,
19928        error_inputs_func=error_inputs_l1_loss,
19929        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
19930        supports_out=False,
19931        supports_forward_ad=True,
19932        supports_fwgrad_bwgrad=True,
19933        skips=(
19934            # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
19935            # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
19936            DecorateInfo(
19937                unittest.expectedFailure,
19938                "TestJit",
19939                "test_variant_consistency_jit",
19940                dtypes=(torch.float32,),
19941            ),
19942        ),
19943    ),
19944    UnaryUfuncInfo('lgamma',
19945                   ref=reference_lgamma if TEST_SCIPY else None,
19946                   aliases=('special.gammaln', ),
19947                   decorators=(precisionOverride({torch.float16: 7e-1}),),
19948                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
19949                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
19950                   supports_forward_ad=True,
19951                   supports_fwgrad_bwgrad=True,
19952                   promotes_int_to_float=True,
19953                   skips=(
19954                       # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
19955                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
19956                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
19957                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
19958                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
19959                   ),
19960                   # lgamma has singularities at the non-positive integers (x = 0, -1, -2, ...)
19961                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
19962    OpInfo(
19963        'logdet',
19964        dtypes=floating_and_complex_types(),
19965        supports_out=False,
19966        supports_forward_ad=True,
19967        supports_fwgrad_bwgrad=True,
19968        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
19969        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
19970    # `log_softmax` supports different dtypes based on whether the `dtype` argument
19971    # is passed or not. Hence there are two OpInfo entries, one with dtype and the other without.
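    # Illustrative sketch of the difference (comment only, not part of the entries below):
    #   torch.log_softmax(torch.arange(3), dim=0)                       # integral input without `dtype`: not supported
    #   torch.log_softmax(torch.arange(3), dim=0, dtype=torch.float32)  # input is cast to float32 first: supported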
19972    OpInfo(
19973        'log_softmax',
19974        aliases=('special.log_softmax', 'nn.functional.log_softmax'),
19975        supports_out=True,
19976        aten_backward_name='_log_softmax_backward_data',
19977        dtypes=floating_types_and(torch.float16, torch.bfloat16),
19978        sample_inputs_func=sample_inputs_softmax_variant,
19979        supports_forward_ad=True,
19980        supports_fwgrad_bwgrad=True,
19981        assert_autodiffed=True),
19982    OpInfo(
19983        'log_softmax',
19984        variant_test_name='with_dtype',
19985        aliases=('special.log_softmax', 'nn.functional.log_softmax'),
19986        supports_out=True,
19987        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
19988        sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True),
19989        supports_forward_ad=True,
19990        supports_fwgrad_bwgrad=True,
19991        assert_autodiffed=True),
19992    UnaryUfuncInfo('logit',
19993                   aten_backward_name='logit_backward',
19994                   ref=scipy.special.logit if TEST_SCIPY else None,
19995                   domain=(0, 1),
19996                   aliases=('special.logit', ),
19997                   supports_forward_ad=True,
19998                   supports_fwgrad_bwgrad=True,
19999                   promotes_int_to_float=True,
20000                   decorators=(precisionOverride({torch.bfloat16: 5e-1,
20001                                                  torch.float16: 5e-1}),),
20002                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
20003                   sample_inputs_func=sample_inputs_logit),
20004    OpInfo('where',
20005           # Currently only the `input` is tested in gradcheck.
20006           # If we passed `condition` first, none of the inputs that support
20007           # autograd would be tested. Hence the following lambda.
20008           op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs),
20009           ref=lambda self, condition, other: np.where(condition, self, other),
20010           sample_inputs_func=sample_inputs_where,
20011           reference_inputs_func=reference_inputs_where,
20012           error_inputs_func=error_inputs_where,
20013           supports_forward_ad=True,
20014           supports_fwgrad_bwgrad=True,
20015           decorators=(
20016               DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),),
20017           skips=(
20018               # lambda impl
20019               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
20020               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
20021           ),
20022           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)),
20023    OpInfo('nonzero',
20024           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
20025           sample_inputs_func=sample_inputs_nonzero,
20026           supports_autograd=False,
20027           skips=(
20028               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20029               # nonzero(): argument 'out' must be Tensor, not tuple
20030               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
20031               # https://github.com/pytorch/pytorch/issues/67458
20032               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20033               # nonzero does not raise a warning when the out tensor is resized
20034               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
20035               # Can't find schemas for this operator for some reason
20036               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
20037               # Compiler issue on ROCm. Might need to skip until ROCm5.5
20038               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
20039                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
20040           )),
20041    OpInfo('nonzero_static',
20042           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
20043           sample_inputs_func=sample_inputs_nonzero_static,
20044           supports_out=False,
20045           supports_autograd=False,
20046           decorators=[onlyCPU],
20047           skips=(
20048               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
20049               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
20050               DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
20051               DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
20052               DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'),
20053               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values',
20054                            dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
20055           )),
20056    # The following tests are for jiterator's Python interface.
20057    # Jiterator can be used to author elementwise CUDA kernels:
20058    # jiterator._create_jit_fn returns a callable that behaves like a regular PyTorch op.
20059    # See create_jit_fn in jiterator.py for more information.
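    # Illustrative sketch of the interface (comment only, assuming a CUDA device is available):
    #   square_plus = torch.cuda.jiterator._create_jit_fn(
    #       "template <typename T> T f(T x) { return x * x + x; }")
    #   out = square_plus(torch.rand(8, device='cuda'))  # behaves like an ordinary elementwise op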
20060    UnaryUfuncInfo(
20061        'jiterator_unary',
20062        op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"),
20063        ref=lambda x: x * x + x,
20064        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
20065        supports_out=False,
20066        supports_autograd=False,  # jiterator ops don't have a backward defined
20067        decorators=[
20068            onlyCUDA,
20069            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
20070                         'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
20071            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
20072                         'TestUnaryUfuncs', 'test_reference_numerics_hard'),
20073            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
20074                         'TestUnaryUfuncs', 'test_reference_numerics_normal'),
20075            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
20076                         'TestUnaryUfuncs', 'test_reference_numerics_small'),
20077        ],
20078        skips=(
20079            # Jiterator ops don't support neg or conj views
20080            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
20081            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
20082            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
20083            # Jiterator ops don't support CompositeCompliantTensor
20084            # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
20085            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
20086            # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool
20087            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
20088                         dtypes=[torch.bool]),
20089            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard',
20090                         dtypes=[torch.bool]),
20091            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
20092                         dtypes=[torch.bool]),
20093            # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results
20094            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
20095                         dtypes=[torch.complex64], active_if=TEST_WITH_ROCM),
20096            # Expected failure: torch.jiterator_unary is not a valid op
20097            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20098            # Skip Nvfuser
20099            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
20100        )
20101    ),
20102    BinaryUfuncInfo(
20103        'jiterator_binary',
20104        op=torch.cuda.jiterator._create_jit_fn(
20105            "template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
20106        ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
20107            else np.add(input, np.multiply(alpha, other)),
20108        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
20109        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
20110        supports_out=False,
20111        supports_autograd=False,  # jiterator ops don't have a backward defined
20112        supports_rhs_python_scalar=False,
20113        decorators=[onlyCUDA],
20114        skips=(
20115            # Jiterator ops don't support neg or conj views
20116            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
20117            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
20118            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
20119            # Jiterator ops don't support CompositeCompliantTensor
20120            # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
20121            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
20122            # Expected failure: torch.jiterator_binary is not a valid op
20123            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20124            # Skip Nvfuser
20125            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
20126        )
20127    ),
20128    OpInfo(
20129        'jiterator_4inputs_with_extra_args',
20130        op=torch.cuda.jiterator._create_jit_fn(
20131            "template <typename T> T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }",
20132            alpha=1, beta=1),
20133        ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3,
20134        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
20135        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20),
20136        supports_out=False,
20137        supports_autograd=False,  # jiterator ops don't have a backward defined
20138        decorators=[onlyCUDA],
20139        skips=(
20140            # Jiterator ops don't support neg or conj views
20141            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
20142            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
20143            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
20144            # Jiterator ops don't support CompositeCompliantTensor
20145            # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
20146            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
20147            # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
20148            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20149            # Skip Nvfuser
20150            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
20151        )
20152    ),
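    # The next two entries exercise jiterator's multi-output interface:
    # _create_multi_output_jit_fn compiles a kernel that writes its results through reference
    # parameters; the returned callable produces the declared output tensor(s).
    # Illustrative sketch (comment only, assuming a CUDA device is available):
    #   add_sub = torch.cuda.jiterator._create_multi_output_jit_fn(
    #       "template <typename T> void f(T a, T b, T& s, T& d) { s = a + b; d = a - b; }",
    #       num_outputs=2)
    #   s, d = add_sub(torch.rand(4, device='cuda'), torch.rand(4, device='cuda'))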
20153    BinaryUfuncInfo(
20154        'jiterator_binary_return_by_ref',
20155        op=torch.cuda.jiterator._create_multi_output_jit_fn(
20156            """
20157            template <typename T>
20158            void binary_return_by_ref(T i0, T i1, T& out0) {
20159                out0 = i0 + i1;
20160            }
20161            """,
20162            num_outputs=1),
20163        ref=operator.add,
20164        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
20165        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42),
20166        supports_out=False,
20167        supports_autograd=False,  # jiterator ops don't have a backward defined
20168        supports_rhs_python_scalar=False,
20169        decorators=[onlyCUDA],
20170        skips=(
20171            # Jiterator ops don't support neg or conj views
20172            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
20173            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
20174            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
20175            # Jiterator ops don't support CompositeCompliantTensor
20176            # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
20177            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
20178            # Expected failure: torch.jiterator_binary_return_by_ref is not a valid op
20179            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20180            # Skip Nvfuser
20181            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
20182        )
20183    ),
20184    OpInfo(
20185        'jiterator_2inputs_2outputs',
20186        op=torch.cuda.jiterator._create_multi_output_jit_fn(
20187            """
20188            template <typename T>
20189            void binary_2outputs(T i0, T i1, T& out0, T& out1) {
20190                out0 = i0 + i1;
20191                out1 = i0 - i1;
20192            }
20193            """,
20194            num_outputs=2),
20195        ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1),
20196        dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
20197        sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2),
20198        supports_out=False,
20199        supports_autograd=False,  # jiterator ops don't have backward defined
20200        decorators=[onlyCUDA],
20201        skips=(
20202            # Jiterator ops don't support neg or conj views
20203            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
20204            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
20205            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
20206            # Jiterator ops don't support CompositeCompliantTensor
20207            # The following test should be an expectedFailure, but it causes cascading failures on CUDA, so it is skipped
20208            DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'),
20209            # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op
20210            # Expected failure: torch.jiterator_2inputs_2outputs is not a valid op
20211            # Skip Nvfuser
20212            DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'),
20213        )
20214    ),
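    # An illustrative sketch (comments only, not part of the test DB) of how multi-output
    # jiterator ops like the two entries above are driven. `_create_multi_output_jit_fn` is a
    # private API, so the calling convention here is inferred from the `ref` lambdas above:
    #
    #     fn = torch.cuda.jiterator._create_multi_output_jit_fn(
    #         "template <typename T> void f(T a, T b, T& s, T& d) { s = a + b; d = a - b; }",
    #         num_outputs=2)
    #     a, b = torch.randn(4, device='cuda'), torch.randn(4, device='cuda')
    #     s, d = fn(a, b)   # expected to match the ref: (a + b, a - b)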
20215    # `torch.norm` has multiple code paths depending on the value of `p`.
20216    # These paths have different dtype support, and the JIT supports
20217    # most variants but not all of them. So we split the OpInfo entries
20218    # for `norm` based on the code paths and JIT support.
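    # A minimal commented illustration of those code paths (plain `torch.norm` usage,
    # shown only for orientation):
    #
    #     x = torch.randn(3, 3)
    #     torch.norm(x)                   # default p=2 path (plain 'norm' entry)
    #     torch.norm(x, p='fro')          # Frobenius norm ('fro' variant)
    #     torch.norm(x, p='nuc')          # nuclear norm ('nuc' variant, needs LAPACK/MAGMA)
    #     torch.norm(x, p=float('inf'))   # infinity norm ('inf' variant)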
20219    OpInfo(
20220        "norm",
20221        sample_inputs_func=sample_inputs_norm,
20222        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
20223        # TODO Benchmark again with the new implementation
20224        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
20225        gradcheck_fast_mode=True,
20226        check_batched_forward_grad=False,
20227        supports_forward_ad=True,
20228        supports_fwgrad_bwgrad=True,
20229        skips=(
20230            # Dispatches in Python to vector_norm. Not sure how to make this test happy
20231            # Happens to pass on complex64. Also a mystery
20232            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
20233                         dtypes=(torch.float32,)),)
20234    ),
20235    OpInfo('norm',
20236           variant_test_name='nuc',
20237           sample_inputs_func=sample_inputs_norm_nuc,
20238           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
20239           check_batched_gradgrad=False,
20240           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
20241           # got: Could not allocate memory to change Tensor SizesAndStrides!
20242           check_batched_forward_grad=False,
20243           supports_forward_ad=True,
20244           supports_fwgrad_bwgrad=True,
20245           dtypes=floating_and_complex_types(),
20246           dtypesIfCUDA=floating_and_complex_types(),
20247           skips=(
20248               # Dispatches in Python to matrix_norm. Not sure how to make this test happy
20249               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
20250                            dtypes=(torch.complex64, torch.float32,)),)
20251           ),
20252    OpInfo('norm',
20253           variant_test_name='fro',
20254           sample_inputs_func=sample_inputs_norm_fro,
20255           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16),
20256           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
20257           supports_forward_ad=True,
20258           # torch.autograd.gradcheck.GradcheckError: While computing batched gradients
20259           # got: Could not allocate memory to change Tensor SizesAndStrides!
20260           check_batched_forward_grad=False,
20261           supports_fwgrad_bwgrad=True,
20262           skips=(
20263               # MPS has some mild accuracy issues for float16. We divide the tolerances by 10
20264               DecorateInfo(
20265                   toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}),
20266                   'TestConsistency',
20267                   'test_output_match',
20268
20269               ),
20270               # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
20271               DecorateInfo(
20272                   unittest.skip("Skipped!"),
20273                   'TestSchemaCheckModeOpInfo',
20274                   'test_schema_correctness',
20275                   dtypes=(torch.complex64, torch.complex128)),
20276               # Dispatches in Python to vector_norm. Not sure how to make this test happy
20277               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
20278                            dtypes=(torch.complex64, torch.float32,)),)
20279           ),
20280    OpInfo(
20281        "norm",
20282        variant_test_name="inf",
20283        sample_inputs_func=sample_inputs_norm_inf,
20284        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
20285        supports_forward_ad=True,
20286        check_batched_forward_grad=False,
20287        supports_fwgrad_bwgrad=True,
20288        # fast gradcheck produces NaNs
20289        gradcheck_fast_mode=False,
20290        skips=(
20291            DecorateInfo(
20292                toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}),
20293                'TestInductorOpInfo', 'test_comprehensive', device_type='cuda',
20294            ),
20295            # Dispatches in Python to vector_norm. Not sure how to make this test happy
20296            # Happens to pass on complex64. Also a mystery
20297            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
20298                         dtypes=(torch.float32,))
20299        ),
20300    ),
20301    OpInfo('t',
20302           sample_inputs_func=sample_inputs_t,
20303           supports_out=False,
20304           supports_forward_ad=True,
20305           supports_fwgrad_bwgrad=True,
20306           # See https://github.com/pytorch/pytorch/pull/78358
20307           check_batched_forward_grad=False,
20308           # vmap does not support inplace views
20309           check_inplace_batched_forward_grad=False,
20310           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
20311           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
20312           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20313           assert_autodiffed=True,
20314           error_inputs_func=error_inputs_t),
20315    OpInfo('t_copy',
20316           sample_inputs_func=sample_inputs_t,
20317           supports_out=True,
20318           supports_forward_ad=True,
20319           supports_fwgrad_bwgrad=True,
20320           # See https://github.com/pytorch/pytorch/pull/78358
20321           check_batched_forward_grad=False,
20322           # vmap does not support inplace views
20323           check_inplace_batched_forward_grad=False,
20324           autodiff_fusible_nodes=[],  # aliases inputs, shouldn't be fused
20325           autodiff_nonfusible_nodes=[],  # aliases inputs, shouldn't be fused
20326           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20327           assert_autodiffed=True,
20328           error_inputs_func=error_inputs_t),
20329    OpInfo(
20330        "nn.functional.dropout",
20331        op=lambda input, *args, **kwargs:
20332            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs),
20333        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20334        skips=(
20335            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20336            # Probably because we have used a lambda for the op here;
20336            # the wrapper_set_seed wrapping is sketched in comments after this entry
20337            # AssertionError: JIT Test does not execute any logic
20338            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20339            # The inplace variant dispatches to the dropout kernel, while on CUDA
20340            # the op dispatches to _fused_dropout (subject to a few more conditions);
20341            # hence the different values and this skip
20342            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
20343            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
20344        supports_forward_ad=True,
20345        supports_fwgrad_bwgrad=True,
20346        # https://github.com/pytorch/pytorch/issues/66357
20347        check_batched_forward_grad=False,
20348        supports_out=False,
20349        sample_inputs_func=sample_inputs_dropout,
20350        inplace_variant=lambda input, *args, **kwargs:
20351            wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)),
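    # A rough commented sketch of why the dropout op above is wrapped with wrapper_set_seed:
    # dropout samples a random mask, so the wrapper reseeds the RNG before each invocation so
    # that repeated calls inside a test can be compared (the tensors below are hypothetical):
    #
    #     x = torch.randn(8, 8)
    #     torch.manual_seed(0); a = torch.nn.functional.dropout(x, p=0.5)
    #     torch.manual_seed(0); b = torch.nn.functional.dropout(x, p=0.5)
    #     assert torch.equal(a, b)   # same seed => same mask on the same device/config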
20352    OpInfo(
20353        "native_dropout_backward",
20354        op=torch.ops.aten.native_dropout_backward.default,
20355        aten_name="native_dropout_backward",
20356        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
20357        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
20358        supports_out=False,
20359        sample_inputs_func=sample_inputs_dropout_backward,
20360        skips=(
20361            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
20362            # Lazy tensor failures
20363            DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
20364            # These tests fail only when built with ASAN
20365            DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN),
20366            DecorateInfo(
20367                unittest.skip("Fails with ASAN"),
20368                'TestLazyOpInfo',
20369                'test_correctness_with_reusing_ir',
20370                active_if=TEST_WITH_ASAN
20371            ),
20372        ),
20373    ),
20374    OpInfo(
20375        "nn.functional.dropout2d",
20376        op=lambda input, *args, **kwargs:
20377            wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs),
20378        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20379        skips=(
20380            # lambda impl
20381            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20382            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20383            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
20384        supports_forward_ad=True,
20385        supports_fwgrad_bwgrad=True,
20386        supports_out=False,
20387        check_batched_forward_grad=False,
20388        # As per the docs, valid input dims are (3, 4)
20389        sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)),
20390        inplace_variant=lambda input, *args, **kwargs:
20391            wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)),
20392    OpInfo(
20393        "nn.functional.dropout3d",
20394        op=lambda input, *args, **kwargs:
20395            wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs),
20396        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20397        skips=(
20398            # lambda impl
20399            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20400            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20401            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
20402        supports_forward_ad=True,
20403        supports_fwgrad_bwgrad=True,
20404        supports_out=False,
20405        check_batched_forward_grad=False,
20406        # As per the docs, valid input dims are (4, 5); see the commented sketch after this entry
20407        sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)),
20408        inplace_variant=lambda input, *args, **kwargs:
20409            wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)),
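    # Commented sketch of the "valid input dims" noted in the dropout2d/dropout3d entries above;
    # these are ordinary channel-first inputs, shown only as an illustration:
    #
    #     torch.nn.functional.dropout2d(torch.randn(2, 3, 8, 8), p=0.5)     # 4-D (N, C, H, W)
    #     torch.nn.functional.dropout2d(torch.randn(3, 8, 8), p=0.5)        # 3-D (C, H, W)
    #     torch.nn.functional.dropout3d(torch.randn(2, 3, 4, 8, 8), p=0.5)  # 5-D (N, C, D, H, W)
    #     torch.nn.functional.dropout3d(torch.randn(3, 4, 8, 8), p=0.5)     # 4-D (C, D, H, W)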
20410    OpInfo(
20411        "nn.functional.alpha_dropout",
20412        op=lambda input, *args, **kwargs:
20413            wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs),
20414        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20415        gradcheck_wrapper=wrapper_set_seed,
20416        supports_forward_ad=True,
20417        supports_fwgrad_bwgrad=True,
20418        supports_out=False,
20419        sample_inputs_func=sample_inputs_dropout,
20420        check_batched_forward_grad=False,
20421        inplace_variant=lambda input, *args, **kwargs:
20422            wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True),
20423        skips=(
20424            # lambda impl
20425            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20426            # AssertionError: Tensor-likes are not close!
20427            # Fails in cuda11.7
20428            # Error Log: https://github.com/pytorch/pytorch/actions/runs/3440108478/jobs/5738475757
20429            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'),
20430            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),),
20431    # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype,
20432    # unlike when `train=False`, where complex inputs are supported; hence two OpInfos to cover both cases.
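    # A small commented sketch of that split (the dtype behaviour is taken from the comment
    # above; the calls themselves are ordinary functional API usage):
    #
    #     x = torch.randn(2, 3, 4, 5, dtype=torch.complex64)
    #     torch.nn.functional.feature_alpha_dropout(x, p=0.5, training=False)  # complex ok when not training
    #     # with training=True, complex inputs are currently unsupported (the "with_train" entry below)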
20433    OpInfo(
20434        "nn.functional.feature_alpha_dropout",
20435        op=lambda input, *args, **kwargs:
20436            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
20437        variant_test_name="with_train",
20438        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20439        skips=(
20440            # lambda impl
20441            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20442            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20443            # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got:
20444            # vmap: We do not yet support calling random operations inside of vmap.
20445            # Please perform random operations outside of vmap as a workaround
20446            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"),
20447            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"),
20448            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
20449        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
20450        gradcheck_fast_mode=True,
20451        supports_forward_ad=True,
20452        supports_fwgrad_bwgrad=True,
20453        supports_out=False,
20454        # As per the docs, valid input dims are (4, 5)
20455        sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)),
20456        inplace_variant=lambda input, *args, **kwargs:
20457            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
20458    OpInfo(
20459        "nn.functional.feature_alpha_dropout",
20460        op=lambda input, *args, **kwargs:
20461            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs),
20462        variant_test_name="without_train",
20463        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20464        skips=(
20465            # lambda impl
20466            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20467            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),
20468        gradcheck_wrapper=wrapper_set_seed,
20469        supports_forward_ad=True,
20470        supports_fwgrad_bwgrad=True,
20471        supports_out=False,
20472        sample_inputs_func=partial(sample_inputs_dropout, train=False),
20473        inplace_variant=lambda input, *args, **kwargs:
20474            wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)),
20475    OpInfo(
20476        "nn.functional.one_hot",
20477        ref=reference_one_hot,
20478        supports_out=False,
20479        dtypes=_dispatch_dtypes((torch.int64,)),
20480        sample_inputs_func=sample_inputs_one_hot,
20481    ),
20482    OpInfo(
20483        "nn.functional.embedding",
20484        aten_backward_name="embedding_dense_backward",
20485        # We use a lambda to reshuffle the positional arguments.
20486        # This is because currently only the `input` field of SampleInput
20487        # is tested in gradient tests (see the commented sketch after this entry).
20488        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs),
20489        dtypes=floating_types_and(torch.bfloat16, torch.float16),
20490        sample_inputs_func=sample_inputs_embedding,
20491        allow_cow_input_materialize_forward=[0],
20492        error_inputs_func=error_inputs_embedding,
20493        supports_forward_ad=True,
20494        supports_fwgrad_bwgrad=True,
20495        skips=(
20496            # lambda impl
20497            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20498            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20499            # Fails on CI https://github.com/pytorch/pytorch/issues/85377
20500            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'),
20501            # Reference: https://github.com/pytorch/pytorch/issues/67084
20502            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
20503            # Not a problem: embedding does weird stuff to its input (it renormalizes)
20504            DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
20505            # Fails due to non-determinism (see issue #74679)
20506            # TODO: Investigate why more granular skips in the test don't work in CI
20507            DecorateInfo(unittest.skip('Skipped!'),
20508                         'TestExpandedWeightFunctional',
20509                         'test_expanded_weight_forward'),
20510        ),
20511        supports_expanded_weight=True,
20512        supports_out=False,
20513    ),
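    # Commented sketch of the argument reshuffle mentioned in the embedding entry above
    # (the tensors are hypothetical):
    #
    #     weight = torch.randn(10, 3, requires_grad=True)
    #     idx = torch.tensor([1, 2, 4])
    #     torch.nn.functional.embedding(idx, weight)   # native order: (input_indices, weight)
    #     # the OpInfo lambda is invoked as op(weight, idx), so `weight` becomes
    #     # SampleInput.input and is what the gradient tests differentiate through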
20514    OpInfo(
20515        "nn.functional.embedding_bag",
20516        # We use a lambda to reshuffle the positional arguments.
20517        # This is because currently only the `input` field of SampleInput
20518        # is tested in gradient tests.
20519        op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs),
20520        dtypes=floating_types_and(torch.bfloat16, torch.float16),
20521        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
20522        # backward is not supported for mode `max` and dtype `bfloat16`
20523        backward_dtypesIfCUDA=floating_types_and(torch.float16),
20524        sample_inputs_func=sample_inputs_embedding_bag,
20525        skips=(
20526            # lambda impl
20527            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20528            DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
20529            # Not a problem: embedding_bag does weird stuff to its input (it renormalizes)
20530            DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
20531        ),
20532        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
20533        supports_out=False,
20534        supports_gradgrad=False,
20535        allow_cow_input_materialize_forward=[0],
20536    ),
20537    OpInfo(
20538        "nn.functional.multi_head_attention_forward",
20539        op=lambda input, *args, **kwargs:
20540            wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs),
20541        dtypes=floating_types_and(torch.bfloat16, torch.float16),
20542        sample_inputs_func=sample_inputs_multi_head_attention_forward,
20543        skips=(
20544            # Tensor-likes are not close
20545            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)),
20546            DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'),
20547
20548            # TODO: skip this for now since we can't skip based on runtime arch support (taken from scaled_dot_product_attention)
20549            DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'),
20550            # randomness
20551            DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
20552            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
20553            # lambda impl
20554            # AssertionError: JIT Test does not execute any logic
20555            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
20556            DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
20557            # These tests run so slowly that they break the slow-test jobs, so we skip them instead of using `slowTest`.
20558            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'),
20559            DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
20560            DecorateInfo(
20561                unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"),
20562                'TestDecomp',
20563                'test_comprehensive',
20564                dtypes=(torch.bfloat16, torch.float16),
20565            ),
20566            DecorateInfo(
20567                unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"),
20568                'TestDecomp',
20569                'test_quick',
20570                dtypes=(torch.bfloat16, torch.float16))),
20571        supports_out=False,
20572        supports_gradgrad=True,
20573        supports_forward_ad=True,
20574        supports_fwgrad_bwgrad=True,
20575        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
20576        gradcheck_fast_mode=True,
20577    ),
20578    UnaryUfuncInfo(
20579        "nn.functional.softplus",
20580        aten_backward_name='softplus_backward',
20581        ref=reference_softplus,
20582        sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}),
20583        sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}),
20584        supports_forward_ad=True,
20585        supports_fwgrad_bwgrad=True,
20586        dtypes=floating_types_and(torch.bfloat16, torch.float16),
20587        decorators=(
20588            DecorateInfo(
20589                toleranceOverride
20590                ({
20591                    torch.half: tol(atol=1e-2, rtol=1e-2),
20592                    torch.bfloat16: tol(atol=1e-2, rtol=1e-2),
20593                }),
20594                'TestUnaryUfuncs'),
20595        ),
20596    ),
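    # For context on the `sample_kwargs` in the softplus entry above (a restatement of the
    # documented formula, not a separate reference implementation):
    #
    #     softplus(x; beta, threshold) = (1 / beta) * log(1 + exp(beta * x))
    #     # and the op falls back to the linear function once beta * x > threshold,
    #     # so beta=3, threshold=0.2 exercises that crossover on small inputs.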
20597    OpInfo(
20598        "nn.functional.mse_loss",
20599        aten_backward_name='mse_loss_backward',
20600        ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2),
20601        sample_inputs_func=sample_inputs_loss,
20602        supports_out=False,
20603        supports_forward_ad=True,
20604        supports_fwgrad_bwgrad=True,
20605        dtypes=floating_types_and(torch.float16),
20606        backward_dtypes=floating_types(),
20607        dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
20608        backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
20609        skips=(
20610            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
20611            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
20612            # please report a bug to PyTorch.
20613            DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
20614        ),
20615    ),
20616    OpInfo(
20617        "nn.functional.grid_sample",
20618        dtypes=floating_types(),
20619        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
20620        supports_out=False,
20621        sample_inputs_func=sample_inputs_grid_sample,
20622        reference_inputs_func=reference_inputs_grid_sample,
20623        supports_gradgrad=False,
20624        gradcheck_nondet_tol=1e-15),
20625    # TODO: delete this OpInfo once we add meta support for grid_sampler_3d
20626    OpInfo(
20627        "grid_sampler_2d",
20628        dtypes=floating_types(),
20629        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
20630        supports_out=False,
20631        sample_inputs_func=sample_inputs_grid_sampler_2d,
20632        supports_gradgrad=False,
20633        gradcheck_nondet_tol=1e-15,
20634        skips=(
20635            DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64),
20636                         active_if=IS_WINDOWS),
20637        ),),
20638    OpInfo(
20639        "argwhere",
20640        ref=np.argwhere,
20641        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20642        supports_out=False,
20643        supports_autograd=False,
20644        sample_inputs_func=sample_inputs_argwhere,
20645        skips=(
20646            # Compiler issue on ROCm. Might need to skip until ROCm 5.5
20647            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values',
20648                         dtypes=[torch.bool], active_if=TEST_WITH_ROCM),
20649        ),
20650    ),
20651    ReductionOpInfo(
20652        'all',
20653        identity=True,
20654        supports_autograd=False,
20655        result_dtype=torch.bool,
20656        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20657        ref=reference_reduction_numpy(np.all),
20658        skips=(
20659            # FIXME: uint8 input returns uint8 instead of bool
20660            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
20661        ),
20662    ),
20663    ReductionOpInfo(
20664        'any',
20665        identity=False,
20666        supports_autograd=False,
20667        result_dtype=torch.bool,
20668        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20669        ref=reference_reduction_numpy(np.any),
20670        skips=(
20671            # FIXME: uint8 input returns uint8 instead of bool (sketched in comments after this entry)
20672            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
20673        ),
20674    ),
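    # Commented sketch of the uint8 FIXME in the `all`/`any` entries above (the result dtype
    # mirrors the input instead of being bool for uint8 inputs):
    #
    #     t = torch.tensor([1, 0], dtype=torch.uint8)
    #     torch.all(t)   # tensor(0, dtype=torch.uint8) rather than a bool tensor
    #     torch.any(t)   # tensor(1, dtype=torch.uint8) rather than a bool tensor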
20675    ReductionOpInfo(
20676        'amax',
20677        nan_policy='propagate',
20678        supports_forward_ad=True,
20679        check_batched_forward_grad=False,
20680        supports_fwgrad_bwgrad=True,
20681        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
20682        ref=reference_reduction_numpy(np.amax),
20683        skips=(
20684            # FIXME: reduces all dimensions when dim=[]
20685            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
20686            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
20687        ),
20688        error_inputs_func=error_inputs_aminmax_amax_amin,
20689    ),
20690    ReductionOpInfo(
20691        'amin',
20692        nan_policy='propagate',
20693        supports_forward_ad=True,
20694        check_batched_forward_grad=False,
20695        supports_fwgrad_bwgrad=True,
20696        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
20697        ref=reference_reduction_numpy(np.amin),
20698        skips=(
20699            # FIXME: reduces all dimensions when dim=[]
20700            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
20701            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
20702        ),
20703        error_inputs_func=error_inputs_aminmax_amax_amin,
20704    ),
20705    ReductionOpInfo(
20706        'argmax',
20707        supports_multiple_dims=False,
20708        supports_autograd=False,
20709        assert_jit_shape_analysis=True,
20710        result_dtype=torch.int64,
20711        dtypes=all_types_and(torch.float16, torch.bfloat16),
20712        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
20713    ),
20714    ReductionOpInfo(
20715        'argmin',
20716        supports_multiple_dims=False,
20717        supports_autograd=False,
20718        result_dtype=torch.int64,
20719        dtypes=all_types_and(torch.float16, torch.bfloat16),
20720        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
20721    ),
20722    ReductionOpInfo(
20723        'count_nonzero',
20724        identity=0,
20725        supports_out=False,
20726        supports_autograd=False,
20727        result_dtype=torch.int64,
20728        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20729        sample_inputs_func=sample_inputs_reduction_count_nonzero,
20730        ref=reference_reduction_numpy(np.count_nonzero),
20731        skips=(
20732            # FIXME: count_nonzero does not accept keepdim kwarg
20733            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20734            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
20735            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'),
20736            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20737            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'),
20738            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'),
20739            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'),
20740            # FIXME: dim=[] reduces all dimensions
20741            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20742        ),
20743    ),
20744    ReductionOpInfo(
20745        'mean',
20746        nan_policy='propagate',
20747        supports_forward_ad=True,
20748        supports_fwgrad_bwgrad=True,
20749        # FIXME: mean needs the 'dim' parameter when using the 'out' overload (see the commented sketch after this entry).
20750        # Adding it with 'generate_args_kwargs' does not work, since these also get passed
20751        # on to the reference implementations.
20752        supports_out=True,
20753        assert_autodiffed=True,
20754        assert_jit_shape_analysis=True,
20755        promotes_int_to_float=True,
20756        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
20757        ref=reference_reduction_numpy(np.mean),
20758        error_inputs_func=error_inputs_mean,
20759        skips=(
20760            # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result
20761            # of dtype torch.float32 into an out= with dtype torch.long
20762            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]),
20763            # FIXME: mean does not support passing keepdim without passing dim
20764            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20765            # FIXME: mean reduces all dimensions when dim=[]
20766            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20767            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20768            # FIXME: improve precision
20769            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20770                         dtypes=[torch.float16]),
20771            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
20772                         device_type='cuda', dtypes=[torch.complex64]),
20773        ),
20774    ),
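    # Commented sketch of the 'out' overload FIXME in the `mean` entry above (a restatement of
    # that comment, not newly verified behaviour):
    #
    #     x = torch.randn(4, 5)
    #     res = torch.empty(5)
    #     torch.mean(x, dim=0, out=res)   # out= works together with dim
    #     # torch.mean(x, out=res)        # per the FIXME, the dim-less overload lacks out= support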
20775    ReductionOpInfo(
20776        'nanmean',
20777        nan_policy='omit',
20778        assert_autodiffed=True,
20779        promotes_int_to_float=True,
20780        supports_forward_ad=True,
20781        check_batched_forward_grad=False,
20782        supports_fwgrad_bwgrad=True,
20783        dtypes=floating_types_and(torch.float16, torch.bfloat16),
20784        dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
20785        sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
20786        ref=reference_reduction_numpy(np.nanmean),
20787        skips=(
20788            # AssertionError: False is not true :
20789            # Failure in testing nodes' autodifferentiation.
20790            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
20791            # FIXME: nanmean reduces all dimensions when dim=[]
20792            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20793            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20794            # FIXME: improve precision
20795            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20796                         dtypes=[torch.float16]),
20797            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
20798                         device_type='cuda', dtypes=[torch.float16]),
20799            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values',
20800                         device_type='cuda', dtypes=[torch.complex64]),
20801        ),
20802    ),
20803    ReductionOpInfo(
20804        'std',
20805        nan_policy='propagate',
20806        supports_out=True,
20807        complex_to_real=True,
20808        supports_forward_ad=True,
20809        supports_fwgrad_bwgrad=True,
20810        assert_autodiffed=True,
20811        promotes_int_to_float=True,
20812        check_batched_forward_grad=False,
20813        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
20814        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
20815        sample_inputs_func=sample_inputs_std_var,
20816        ref=reference_std_var(np.std),
20817        generate_args_kwargs=generate_std_var_kwargs,
20818        skips=(
20819            # FIXME: cannot specify keepdim without dim
20820            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20821            # FIXME: dim=[] reduces all dimensions
20822            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20823            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20824            # FIXME: improve precision
20825            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20826                         dtypes=(torch.float16,)),
20827            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
20828                         dtypes=(torch.float16,)),
20829        ),
20830    ),
20831    ReductionOpInfo(
20832        'std',
20833        variant_test_name='unbiased',
20834        nan_policy='propagate',
20835        supports_out=False,
20836        complex_to_real=True,
20837        supports_forward_ad=True,
20838        supports_fwgrad_bwgrad=True,
20839        assert_autodiffed=True,
20840        promotes_int_to_float=True,
20841        check_batched_forward_grad=False,
20842        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
20843        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
20844        sample_inputs_func=sample_inputs_std_var_unbiased,
20845        skips=(
20846            # FIXME: dim=[] reduces all dimensions
20847            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20848            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20849        ),
20850    ),
20851    ReductionOpInfo(
20852        'var',
20853        nan_policy='propagate',
20854        supports_out=True,
20855        assert_autodiffed=True,
20856        promotes_int_to_float=True,
20857        complex_to_real=True,
20858        supports_forward_ad=True,
20859        supports_fwgrad_bwgrad=True,
20860        check_batched_forward_grad=False,
20861        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
20862        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
20863        sample_inputs_func=sample_inputs_std_var,
20864        ref=reference_std_var(np.var),
20865        generate_args_kwargs=generate_std_var_kwargs,
20866        skips=(
20867            # FIXME: cannot specify keepdim without dim
20868            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20869            # FIXME: dim=[] reduces all dimensions
20870            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20871            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20872            # FIXME: improve precision
20873            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
20874            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
20875            # NumPy is giving NaN for this
20876            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
20877        ),
20878    ),
20879    ReductionOpInfo(
20880        'var',
20881        variant_test_name='unbiased',
20882        nan_policy='propagate',
20883        supports_out=False,
20884        complex_to_real=True,
20885        supports_forward_ad=True,
20886        supports_fwgrad_bwgrad=True,
20887        assert_autodiffed=True,
20888        promotes_int_to_float=True,
20889        check_batched_forward_grad=False,
20890        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
20891        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
20892        sample_inputs_func=sample_inputs_std_var_unbiased,
20893        skips=(
20894            # FIXME: dim=[] reduces all dimensions
20895            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20896            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20897        ),
20898    ),
20899    ReductionOpInfo(
20900        'prod',
20901        identity=1,
20902        nan_policy='propagate',
20903        supports_multiple_dims=False,
20904        # https://github.com/pytorch/pytorch/issues/80411
20905        gradcheck_fast_mode=True,
20906        supports_out=False,
20907        supports_forward_ad=True,
20908        supports_fwgrad_bwgrad=True,
20909        promotes_int_to_int64=True,
20910        gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
20911        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
20912        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
20913        sample_inputs_func=sample_inputs_prod,
20914        ref=prod_numpy,
20915        skips=(
20916            # FIXME: prod does not support passing keepdim without passing dim
20917            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20918            # FIXME: prod reduces all dimensions when dim=[]
20919            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20920            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20921            # FIXME: prod does not support passing None to dim
20922            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
20923            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
20924            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20925                         dtypes=[torch.float16, torch.complex64]),
20926            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
20927                         dtypes=[torch.uint8, torch.float16, torch.complex64]),
20928            # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match
20929            DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
20930                         dtypes=[torch.float16]),
20931        ),
20932    ),
20933    ReductionOpInfo(
20934        'sum',
20935        identity=0,
20936        nan_policy='propagate',
20937        supports_out=False,
20938        supports_forward_ad=True,
20939        supports_fwgrad_bwgrad=True,
20940        promotes_int_to_int64=True,
20941        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
20942        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
20943        ref=reference_reduction_numpy(np.sum),
20944        error_inputs_sparse_func=error_inputs_sparse_reduction_sum,
20945        sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo),
20946        sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr),
20947        sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc),
20948        sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr),
20949        sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc),
20950        skips=(
20951            # FIXME: sum does not support passing keepdim without passing dim
20952            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
20953            # FIXME: sum reduces all dimensions when dim=[]
20954            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
20955            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
20956            # FIXME: improve precision
20957            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20958                         dtypes=[torch.float16]),
20959            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
20960                         dtypes=[torch.float16]),
20961            DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
20962                         dtypes=[torch.float32]),
20963        ),
20964    ),
20965    ReductionOpInfo(
20966        'nansum',
20967        identity=0,
20968        nan_policy='omit',
20969        supports_out=True,
20970        promotes_int_to_int64=True,
20971        supports_forward_ad=True,
20972        check_batched_forward_grad=False,
20973        supports_fwgrad_bwgrad=True,
20974        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
20975        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
20976        sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True),
20977        ref=reference_reduction_numpy(np.nansum),
20978        skips=(
20979            # please report a bug to PyTorch.
20980            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
20981            # FIXME: nansum reduces all dimensions when dim=[]
20982            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
20983            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
20984            # FIXME: flaky test so skipped instead of xfailed
20985            # possibly bad low precision reference in numpy
20986            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
20987                         dtypes=[torch.float16]),
20988        ),
20989    ),
20990    OpInfo(
20991        "nn.functional.ctc_loss",
20992        dtypes=floating_types(),
20993        supports_out=False,
20994        sample_inputs_func=sample_inputs_ctc_loss,
20995        skips=(
20996            # https://github.com/pytorch/pytorch/issues/67462
20997            # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0
20998            DecorateInfo(
20999                unittest.expectedFailure,
21000                "TestBwdGradients",
21001                "test_fn_grad",
21002                dtypes=(torch.float64,),
21003            ),
21004            # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
21005            DecorateInfo(
21006                unittest.expectedFailure,
21007                "TestBwdGradients",
21008                "test_fn_gradgrad",
21009                dtypes=(torch.float64,),
21010            ),
21011            # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented
21012            DecorateInfo(
21013                unittest.skip("Skipped!"),
21014                "TestJit",
21015                "test_variant_consistency_jit",
21016                dtypes=(torch.float32,),
21017            ),
21018            # Ref: https://github.com/pytorch/pytorch/issues/85231
21019            DecorateInfo(unittest.skip("Fails with ASAN"),
21020                         'TestProxyTensorOpInfo',
21021                         'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN),
21022        ),
21023    ),
21024    OpInfo(
21025        "nn.functional.cosine_embedding_loss",
21026        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
21027        supports_out=False,
21028        supports_forward_ad=True,
21029        supports_fwgrad_bwgrad=True,
21030        decorators=[
21031            DecorateInfo(
21032                toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
21033                'TestInductorOpInfo', 'test_comprehensive', device_type="cuda",
21034            ),
21035        ],
21036        sample_inputs_func=sample_inputs_cosine_embedding_loss,
21037    ),
21038    OpInfo(
21039        "nn.functional.nll_loss",
21040        dtypes=floating_types_and(torch.float16, torch.bfloat16),
21041        supports_out=False,
21042        sample_inputs_func=sample_inputs_nll_loss,
21043        supports_forward_ad=True,
21044        supports_fwgrad_bwgrad=True,
21045        assert_jit_shape_analysis=True,
21046        skips=(
21047            # RuntimeError:
21048            # undefined value tensor:
21049            #   File "<string>", line 3
21050            # def the_method(i0, i1):
21051            #     return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32))
21052            #                                                        ~~~~~~ <--- HERE
21053            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
21054            # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782
21055            DecorateInfo(
21056                unittest.skip("Skipped!"),
21057                "TestCompositeCompliance",
21058                "test_cow_input",
21059                device_type='cuda',
21060            ),
21061            DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"),
21062                         dtypes=(torch.half,), device_type="mps"),
21063
21064        ),
21065    ),
21066    OpInfo(
21067        "nn.functional.gaussian_nll_loss",
21068        dtypes=floating_types_and(torch.half, torch.bfloat16),
21069        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
21070        gradcheck_fast_mode=True,
21071        supports_out=False,
21072        supports_forward_ad=True,
21073        supports_fwgrad_bwgrad=True,
21074        sample_inputs_func=sample_inputs_gaussian_nll_loss,
21075        error_inputs_func=error_inputs_gaussian_nll_loss,
21076        skips=(
21077            # Pre-existing condition (calls .item); needs to be fixed
21078            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
21079            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
21080            # Pre-existing condition (calls .item); needs to be fixed
21081            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
21082            # JIT does not support variadic tensors.
21083            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
21084            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
21085            # please report a bug to PyTorch.
21086            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
21087        ),
21088    ),
21089    OpInfo(
21090        "nn.functional.hinge_embedding_loss",
21091        dtypes=floating_types_and(torch.half, torch.bfloat16),
21092        supports_out=False,
21093        supports_forward_ad=True,
21094        supports_fwgrad_bwgrad=True,
21095        sample_inputs_func=sample_inputs_hinge_embedding_loss,
21096        error_inputs_func=error_inputs_hinge_embedding_loss,
21097        reference_inputs_func=reference_inputs_hinge_embedding_loss,
21098    ),
21099    OpInfo(
21100        "nn.functional.huber_loss",
21101        aten_backward_name='huber_loss_backward',
21102        dtypes=floating_types_and(torch.float16, torch.bfloat16),
21103        supports_out=False,
21104        supports_forward_ad=True,
21105        sample_inputs_func=sample_inputs_huber_loss,
21106        error_inputs_func=error_inputs_huber_loss,
21107        skips=(
21108            # JIT does not support variadic tensors.
21109            # RuntimeError: input->type()->kind() == TypeKind::OptionalType
21110            # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270,
21111            # please report a bug to PyTorch.
21112            DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),),
21113        )
21114    ),
21115    OpInfo(
21116        "nn.functional.pdist",
21117        ref=reference_pdist,
21118        sample_inputs_func=sample_inputs_pdist,
21119        dtypes=floating_types(),
21120        supports_out=False,
21121        supports_gradgrad=False,
21122        skips=(
21123            DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
21124        )
21125    ),
21126    OpInfo(
21127        "nn.functional.poisson_nll_loss",
21128        dtypes=all_types_and(torch.half, torch.bfloat16),
21129        supports_out=False,
21130        supports_forward_ad=True,
21131        supports_fwgrad_bwgrad=True,
21132        sample_inputs_func=sample_inputs_poisson_nll_loss,
21133        error_inputs_func=error_inputs_poisson_nll_loss,
21134    ),
21135    OpInfo(
21136        "argsort",
21137        dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
21138        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
21139        sample_inputs_func=sample_inputs_sort,
21140        supports_out=False,
21141        supports_autograd=False,
21142        skips=(
21143            DecorateInfo(
21144                unittest.skip("Skipped!"),
21145                "TestJit",
21146                "test_variant_consistency_jit",
21147                dtypes=(torch.float32,),
21148            ),
21149        ),
21150    ),
21151    OpInfo(
21152        "repeat_interleave",
21153        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
21154        backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf),
21155        sample_inputs_func=sample_inputs_repeat_interleave,
21156        supports_out=False,
21157        supports_forward_ad=True,
21158        supports_fwgrad_bwgrad=True,
21159        # See https://github.com/pytorch/pytorch/pull/78358
21160        check_batched_forward_grad=False,
21161        skips=(
21162            DecorateInfo(
21163                unittest.skip("Skipped!"),
21164                "TestJit",
21165                "test_variant_consistency_jit",
21166                dtypes=(torch.float32, torch.complex64),
21167            ),
21168        ),
21169    ),
21170    OpInfo(
21171        "nn.functional.pairwise_distance",
21172        ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: (
21173            np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p)
21174        ),  # the ref formula is restated in a comment after this entry
21175        sample_inputs_func=sample_inputs_pairwise_distance,
21176        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
21177        supports_out=False,
21178        supports_forward_ad=True,
21179        supports_fwgrad_bwgrad=True,
21180        skips=(
21181            DecorateInfo(
21182                unittest.skip("Skipped!"),
21183                "TestJit",
21184                "test_variant_consistency_jit",
21185                dtypes=(torch.float32, torch.complex64),
21186            ),
21187        ),
21188    ),
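    # The NumPy `ref` in the nn.functional.pairwise_distance entry above is just the p-norm of
    # the eps-shifted difference, reduced over the last dimension:
    #
    #     dist(a, b) = (sum_i |a_i - b_i + eps| ** p) ** (1 / p)
    #
    # mirroring how torch.nn.functional.pairwise_distance uses eps as a small numerical stabilizer.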
21189    OpInfo(
21190        "nn.functional.pixel_shuffle",
21191        sample_inputs_func=sample_inputs_pixel_shuffle,
21192        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
21193        supports_out=False,
21194        supports_forward_ad=True,
21195        supports_fwgrad_bwgrad=True,
21196        skips=(
21197            DecorateInfo(
21198                unittest.skip("Skipped!"),
21199                "TestJit",
21200                "test_variant_consistency_jit",
21201                dtypes=(torch.float32, torch.complex64),
21202            ),
21203        ),
21204    ),
21205    OpInfo(
21206        "nn.functional.pixel_unshuffle",
21207        sample_inputs_func=sample_inputs_pixel_unshuffle,
21208        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
21209        supports_out=False,
21210        supports_forward_ad=True,
21211        supports_fwgrad_bwgrad=True,
21212        skips=(
21213            DecorateInfo(
21214                unittest.skip("Skipped!"),
21215                "TestJit",
21216                "test_variant_consistency_jit",
21217                dtypes=(torch.float32, torch.complex64),
21218            ),
21219        ),
21220    ),
21221    OpInfo(
21222        "nn.functional.channel_shuffle",
21223        sample_inputs_func=sample_inputs_channel_shuffle,
21224        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
21225        supports_out=False,
21226        supports_forward_ad=True,
21227        supports_fwgrad_bwgrad=True,
21228        allow_cow_input_materialize_forward=[0],
21229        allow_cow_input_materialize_backward=[0, 'output grad 0'],
21230        skips=(
21231            # Skip due to NotImplementedError for MPS device.
21232            DecorateInfo(unittest.expectedFailure, 'TestConsistency'),
21233            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
21234            DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"),
21235        ),
21236    ),
21237    OpInfo(
21238        "nn.functional.kl_div",
21239        sample_inputs_func=sample_inputs_kl_div,
21240        dtypes=floating_types_and(torch.float16, torch.bfloat16),
21241        supports_out=False,
21242        supports_forward_ad=True,
21243        supports_fwgrad_bwgrad=True,
21244    ),
21245    OpInfo(
21246        "diagflat",
21247        ref=lambda input, offset=0: np.diagflat(input, k=offset),
21248        sample_inputs_func=sample_inputs_diagflat,
21249        dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
21250        dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
21251        supports_out=False,
21252        supports_forward_ad=True,
21253        supports_fwgrad_bwgrad=True,
21254        # See https://github.com/pytorch/pytorch/pull/78358
21255        check_batched_forward_grad=False,
21256    ),
21257    OpInfo(
21258        'scatter_reduce',
21259        variant_test_name='sum',
21260        # complex not added to dtypes as complex gradients are not properly handled
21261        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
21262        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
21263        supports_forward_ad=True,
21264        supports_fwgrad_bwgrad=True,
21265        sample_inputs_func=sample_inputs_scatter_reduce,
21266    ),
21267    OpInfo(
21268        'scatter_reduce',
21269        variant_test_name='prod',
21270        # complex not added to dtypes as complex gradients are not properly handled
21271        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
21272        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
21273        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
21274        sample_inputs_func=sample_inputs_scatter_reduce,
21275        skips=(
21276            # Not implemented
21277            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
21278            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
21279            DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
21280        ),
21281    ),
21282    OpInfo(
21283        'scatter_reduce',
21284        variant_test_name='mean',
21285        # complex not added to dtypes as complex gradients are not properly handled
21286        # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet
21287        dtypes=all_types_and(torch.float16, torch.bfloat16),
21288        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
21289        supports_forward_ad=True,
21290        supports_fwgrad_bwgrad=True,
21291        sample_inputs_func=sample_inputs_scatter_reduce,
21292    ),
21293    OpInfo(
21294        'scatter_reduce',
21295        variant_test_name='amin',
21296        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
21297        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
21298        supports_forward_ad=True,
21299        check_batched_forward_grad=False,
21300        supports_fwgrad_bwgrad=True,
21301        sample_inputs_func=sample_inputs_scatter_reduce,
21302    ),
21303    OpInfo(
21304        'scatter_reduce',
21305        variant_test_name='amax',
21306        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
21307        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
21308        supports_forward_ad=True,
21309        check_batched_forward_grad=False,
21310        supports_fwgrad_bwgrad=True,
21311        sample_inputs_func=sample_inputs_scatter_reduce,
21312    ),
21313    OpInfo(
21314        '_segment_reduce',
21315        aten_name='segment_reduce',
21316        variant_test_name='lengths',
21317        dtypes=floating_types_and(torch.float16, torch.bfloat16),
21318        supports_out=False,
21319        # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
21320        supports_gradgrad=False,
21321        sample_inputs_func=sample_inputs_segment_reduce,
21322        skips=(
21323            # FIXME: CUDA driver API confirmed a leak in
21324            # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
21325            DecorateInfo(
21326                unittest.skip("Skipped!"),
21327                "TestJit",
21328                "test_variant_consistency_jit",
21329                device_type="cuda",
21330            ),
21331        ),
21332    ),
21333    OpInfo(
21334        '_segment_reduce',
21335        aten_name='segment_reduce',
21336        variant_test_name='offsets',
21337        dtypes=floating_types_and(torch.float16, torch.bfloat16),
21338        supports_out=False,
21339        # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
21340        supports_gradgrad=False,
21341        sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
21342        skips=(
21343            # FIXME: CUDA driver API confirmed a leak in
21344            # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
21345            DecorateInfo(
21346                unittest.skip("Skipped!"),
21347                "TestJit",
21348                "test_variant_consistency_jit",
21349                device_type="cuda",
21350            ),
21351        ),
21352    ),
21353]
21354op_db += opinfo.definitions.op_db
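# A minimal, hypothetical sketch of how a test suite typically consumes op_db:
# the @ops decorator parametrizes a test over every OpInfo entry, and each entry
# supplies device/dtype-appropriate sample inputs. The test class and test name
# below are illustrative only and not part of this module.
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests, ops)
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestOpSamples(TestCase):
#         @ops(op_db, allowed_dtypes=(torch.float32,))
#         def test_forward(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 # OpInfo.__call__ dispatches to the underlying operator
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestOpSamples, globals())
#
#     if __name__ == "__main__":
#         run_tests()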
21355
21356
21357# Separate registry for experimental Python Reference OpInfos.
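# Each PythonRefInfo / ElementwiseUnaryPythonRefInfo below links a torch._refs
# implementation to the eager OpInfo named by torch_opinfo_name (and, where given,
# torch_opinfo_variant_name), inheriting that entry's sample inputs, dtypes, and
# metadata; the entries here mostly layer ref-specific skips and tolerance
# adjustments on top.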
21358python_ref_db = [
21359    #
21360    # Elementwise Unary OpInfos
21361    #
21362    ElementwiseUnaryPythonRefInfo(
21363        "_refs.abs",
21364        torch_opinfo_name="abs",
21365        skips=(
21366            # Reference: https://github.com/pytorch/pytorch/issues/49224
21367            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21368                         'test_reference_numerics_small',
21369                         dtypes=[torch.int8], active_if=TEST_WITH_ASAN),
21370        ),
21371    ),
21372    ElementwiseUnaryPythonRefInfo(
21373        "_refs.acos",
21374        torch_opinfo_name="acos",
21375        skips=(
21376            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21377                         'test_reference_numerics_normal',
21378                         device_type='cuda', dtypes=[torch.cdouble],
21379                         active_if=IS_WINDOWS),
21380            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21381                         'test_reference_numerics_extremal',
21382                         device_type='cuda', dtypes=[torch.cdouble],
21383                         active_if=IS_WINDOWS),
21384            # Failing with wrong imaginary sign on at least some Windows jobs
21385            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21386                         'test_reference_numerics_small',
21387                         device_type='cuda', dtypes=[torch.cdouble],
21388                         active_if=IS_WINDOWS),
21389            # Failing with wrong imaginary sign on at least some Windows jobs
21390            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21391                         'test_reference_numerics_large',
21392                         device_type='cuda', dtypes=[torch.cdouble],
21393                         active_if=IS_WINDOWS),
21394            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21395                         'test_reference_numerics_large',
21396                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21397            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21398                         'test_reference_numerics_extremal',
21399                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21400        )
21401    ),
21402    ElementwiseUnaryPythonRefInfo(
21403        "_refs.acosh",
21404        torch_opinfo_name="acosh",
21405        skips=(
21406            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21407                         'test_reference_numerics_normal',
21408                         device_type='cuda', dtypes=[torch.cdouble],
21409                         active_if=IS_WINDOWS),
21410            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21411                         'test_reference_numerics_extremal',
21412                         device_type='cuda', dtypes=[torch.cdouble],
21413                         active_if=IS_WINDOWS),
21414            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21415                         'test_reference_numerics_extremal',
21416                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21417            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21418                         'test_reference_numerics_large',
21419                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21424            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21425                         'test_reference_numerics_large',
21426                         device_type='cuda', dtypes=[torch.cdouble],
21427                         active_if=IS_WINDOWS),
21428            # Failing with wrong imaginary sign on at least some Windows jobs
21429            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21430                         'test_reference_numerics_small',
21431                         device_type='cuda', dtypes=[torch.cdouble],
21432                         active_if=IS_WINDOWS),
21433        ),
21434    ),
21435    ElementwiseUnaryPythonRefInfo(
21436        "_refs.asin",
21437        torch_opinfo_name="asin",
21438        decorators=[
21439            DecorateInfo(
21440                toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
21441                'TestUnaryUfuncs', device_type='cuda'),
21442            precisionOverride({torch.bfloat16: 1e-2}),
21443        ],
21444        skips=(
21445            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21446                         'test_reference_numerics_extremal',
21447                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21448            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21449                         'test_reference_numerics_large',
21450                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21451            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21452                         'test_reference_numerics_extremal',
21453                         device_type='cuda', dtypes=[torch.cdouble],
21454                         active_if=IS_WINDOWS),
21455            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21456                         'test_reference_numerics_large',
21457                         device_type='cuda', dtypes=[torch.cdouble],
21458                         active_if=IS_WINDOWS),
21459        ),
21460    ),
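    # Unlike the skip entries, the toleranceOverride/precisionOverride decorators
    # used above (and elsewhere in this file) keep the test running and merely
    # loosen its per-dtype comparison thresholds.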
21461    ElementwiseUnaryPythonRefInfo(
21462        "_refs.asinh",
21463        torch_opinfo_name="asinh",
21464        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
21465        skips=(
21466            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21467                         'test_reference_numerics_extremal',
21468                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21469            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21470                         'test_reference_numerics_large',
21471                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21472            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21473                         'test_reference_numerics_small',
21474                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21475            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21476                         'test_reference_numerics_normal',
21477                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21478            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21479                         'test_reference_numerics_extremal',
21480                         device_type='cuda', dtypes=[torch.cdouble],
21481                         active_if=IS_WINDOWS),
21482            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21483                         'test_reference_numerics_large',
21484                         device_type='cuda', dtypes=[torch.cdouble],
21485                         active_if=IS_WINDOWS),
21486        ),
21487    ),
21488    PythonRefInfo(
21489        "_refs.lerp",
21490        torch_opinfo_name="lerp",
21491    ),
21492    PythonRefInfo(
21493        "_refs.ones",
21494        torch_opinfo_name="ones",
21495        skips=(
21496            # Tests that assume input is a tensor or sequence of tensors
21497            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21498            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21499            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21500        ),
21501    ),
21502    PythonRefInfo(
21503        "_refs.zeros",
21504        torch_opinfo_name="zeros",
21505        skips=(
21506            # Tests that assume input is a tensor or sequence of tensors
21507            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21508            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21509            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21510        ),
21511    ),
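    # The random-sampling refs that follow (_refs.cauchy, exponential, geometric,
    # log_normal, normal) cannot be compared numerically against their eager
    # counterparts, so the comparison-based tests are skipped for each of them
    # rather than given looser tolerances.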
21512    PythonRefInfo(
21513        "_refs.cauchy",
21514        torch_opinfo_name="cauchy",
21515        decorators=(
21516            # TODO: RuntimeError: no _refs support for torch.rand_like
21517            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21518                         'TestCommon',
21519                         'test_python_ref'),
21520            # AssertionError: Tensor-likes are not close!
21521            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
21522                         'TestCommon',
21523                         'test_out'),
21524            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
21525                         'TestCommon',
21526                         'test_out_warning'),
21527            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
21528            DecorateInfo(unittest.skip("Expected: cauchy is not comparable"),
21529                         'TestCommon',
21530                         'test_python_ref_torch_fallback'),
21531            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21532            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21533        )
21534    ),
21535    PythonRefInfo(
21536        "_refs.exponential",
21537        torch_opinfo_name="exponential",
21538        supports_out=True,
21539        decorators=(
21540            # dtypes that do not support check_uniform_bounds of rand_like
21541            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
21542                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
21543            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
21544
21545            # TODO: RuntimeError: no _refs support for torch.rand_like
21546            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21547                         'TestCommon',
21548                         'test_python_ref'),
21549
21550            # AssertionError: Tensor-likes are not close!
21551            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
21552                         'TestCommon',
21553                         'test_out'),
21554            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
21555                         'TestCommon',
21556                         'test_out_warning'),
21557            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
21558            DecorateInfo(unittest.skip("Expected: exponential is not comparable"),
21559                         'TestCommon',
21560                         'test_python_ref_torch_fallback'),
21561            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21562            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21563        )
21564    ),
21565    PythonRefInfo(
21566        "_refs.geometric",
21567        torch_opinfo_name="geometric",
21568        supports_out=True,
21569        decorators=(
21570            # dtypes that do not support check_uniform_bounds of rand_like
21571            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'),
21572            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
21573                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
21574            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21575                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)),
21576
21577            # TODO: RuntimeError: no _refs support for torch.rand_like
21578            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21579                         'TestCommon',
21580                         'test_python_ref'),
21581            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
21582                         'TestCommon',
21583                         'test_python_ref_executor', device_type='cuda'),
21584
21585            # AssertionError: Tensor-likes are not close!
21586            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
21587                         'TestCommon',
21588                         'test_out'),
21589            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
21590                         'TestCommon',
21591                         'test_out_warning'),
21592            DecorateInfo(unittest.skip("Expected: geometric is not comparable"),
21593                         'TestCommon',
21594                         'test_python_ref_torch_fallback'),
21595            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21596            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21597        )
21598    ),
21599    PythonRefInfo(
21600        "_refs.log_normal",
21601        torch_opinfo_name="log_normal",
21602        supports_out=True,
21603        decorators=(
21604            # TODO: RuntimeError: no _refs support for torch.rand_like
21605            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21606                         'TestCommon',
21607                         'test_python_ref'),
21608            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
21609                         'TestCommon',
21610                         'test_python_ref_executor', device_type='cuda'),
21611
21612            # AssertionError: Tensor-likes are not close!
21613            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
21614                         'TestCommon',
21615                         'test_out'),
21616            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
21617                         'TestCommon',
21618                         'test_out_warning'),
21619            DecorateInfo(unittest.skip("Expected: log_normal is not comparable"),
21620                         'TestCommon',
21621                         'test_python_ref_torch_fallback'),
21622            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21623            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21624        )
21625    ),
21626    PythonRefInfo(
21627        "_refs.normal",
21628        torch_opinfo_name="normal",
21629        supports_out=True,
21630        decorators=(
21631            # TODO: RuntimeError: no _refs support for torch.rand_like
21632            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21633                         'TestCommon',
21634                         'test_python_ref'),
21635
21636            # AssertionError: Tensor-likes are not close!
21637            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21638                         'TestCommon',
21639                         'test_out'),
21640            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21641                         'TestCommon',
21642                         'test_out_warning'),
21643            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21644                         'TestCommon',
21645                         'test_python_ref_torch_fallback'),
21646            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
21647            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21648            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
21649            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21650            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21651            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21652        )
21653    ),
21654    PythonRefInfo(
21655        "_refs.normal",
21656        torch_opinfo_name="normal",
21657        torch_opinfo_variant_name="number_mean",
21658        supports_out=True,
21659        decorators=(
21660            # TODO: RuntimeError: no _refs support for torch.rand_like
21661            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21662                         'TestCommon',
21663                         'test_python_ref'),
21664
21665            # AssertionError: Tensor-likes are not close!
21666            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21667                         'TestCommon',
21668                         'test_out'),
21669            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21670                         'TestCommon',
21671                         'test_out_warning'),
21672            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21673                         'TestCommon',
21674                         'test_python_ref_torch_fallback'),
21675            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
21676            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21677            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
21678            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21679            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21680            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21681        )
21682    ),
21683    PythonRefInfo(
21684        "_refs.normal_",
21685        op=torch.Tensor.normal_,
21686        torch_opinfo_name="normal",
21687        torch_opinfo_variant_name="in_place",
21688        supports_out=False,
21689        decorators=(
21690            # TODO: RuntimeError: no _refs support for torch.rand_like
21691            DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"),
21692                         'TestCommon',
21693                         'test_python_ref'),
21694
21695            # AssertionError: Tensor-likes are not close!
21696            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21697                         'TestCommon',
21698                         'test_out'),
21699            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21700                         'TestCommon',
21701                         'test_out_warning'),
21702            DecorateInfo(unittest.skip("Expected: normal is not comparable"),
21703                         'TestCommon',
21704                         'test_python_ref_torch_fallback'),
21705            DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'),
21706            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
21707            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'),
21708            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21709            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21710            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21711        )
21712    ),
21713    PythonRefInfo(
21714        "_refs.arange",
21715        torch_opinfo_name="arange",
21716        skips=(
21717            # Tests that assume input is a tensor or sequence of tensors
21718            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21719            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21720            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21721        ),
21722    ),
21723    PythonRefInfo(
21724        "_refs.linspace",
21725        torch_opinfo_name="linspace",
21726        skips=(
21727            # Tests that assume input is a tensor or sequence of tensors
21728            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21729            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21730            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21731
21732            # cpu implementation is wrong on some integral types
21733            # https://github.com/pytorch/pytorch/issues/81996
21734            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21735                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
21736            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21737                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
21738
21739            # cuda implementation is off-by-one on some inputs due to precision issues
21740            # https://github.com/pytorch/pytorch/issues/82230
21741            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21742                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21743                         device_type="cuda"),
21744            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21745                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21746                         device_type="cuda"),
21747            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
21748                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21749                         device_type="cuda"),
21750        ),
21751    ),
21752    PythonRefInfo(
21753        "_refs.linspace",
21754        torch_opinfo_name="linspace",
21755        torch_opinfo_variant_name="tensor_overload",
21756        skips=(
21757            # TypeError: 'int' object is not subscriptable
21758            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21759            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21760
21761            # cpu implementation is wrong on some integral types
21762            # https://github.com/pytorch/pytorch/issues/81996
21763            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21764                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
21765            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21766                         dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"),
21767
21768            # cuda implementation is off-by-one on some inputs due to precision issues
21769            # https://github.com/pytorch/pytorch/issues/82230
21770            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21771                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21772                         device_type="cuda"),
21773            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21774                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21775                         device_type="cuda"),
21776            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
21777                         dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64),
21778                         device_type="cuda"),
21779        ),
21780    ),
21781    PythonRefInfo(
21782        "_refs.logspace",
21783        torch_opinfo_name="logspace",
21784        skips=(
21785            # Tests that assume input is a tensor or sequence of tensors
21786            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21787            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21788            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
21789
21790            # Off-by-one issue when casting floats to ints
21791            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21792                         dtypes=(torch.int16, torch.int32, torch.int64),
21793                         device_type="cuda"),
21794            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21795                         dtypes=(torch.int16, torch.int32, torch.int64),
21796                         device_type="cuda"),
21797            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
21798                         dtypes=(torch.int16, torch.int32, torch.int64),
21799                         device_type="cuda"),
21800        ),
21801    ),
21802    PythonRefInfo(
21803        "_refs.logspace",
21804        torch_opinfo_name="logspace",
21805        torch_opinfo_variant_name="tensor_overload",
21806        skips=(
21807            # TypeError: 'int' object is not subscriptable
21808            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
21809            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
21810
21811            # Off-by-one issue when casting floats to ints
21812            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
21813                         dtypes=(torch.int16, torch.int32, torch.int64),
21814                         device_type="cuda"),
21815            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
21816                         dtypes=(torch.int16, torch.int32, torch.int64),
21817                         device_type="cuda"),
21818            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
21819                         dtypes=(torch.int16, torch.int32, torch.int64),
21820                         device_type="cuda"),
21821        ),
21822    ),
21823    PythonRefInfo(
21824        "_refs.meshgrid",
21825        torch_opinfo_name="meshgrid",
21826        torch_opinfo_variant_name="variadic_tensors",
21827    ),
21828    PythonRefInfo(
21829        "_refs.take_along_dim",
21830        torch_opinfo_name="take_along_dim",
21831        skips=(
21832            DecorateInfo(unittest.expectedFailure,
21833                         'TestCommon',
21834                         'test_python_ref'),
21835        ),
21836    ),
21837    PythonRefInfo(
21838        "_refs.to",
21839        torch_opinfo_name="to",
21840    ),
21841    PythonRefInfo(
21842        "_refs.triu",
21843        torch_opinfo_name="triu",
21844    ),
21845    PythonRefInfo(
21846        "_refs.tril",
21847        torch_opinfo_name="tril",
21848    ),
21849    PythonRefInfo(
21850        "_refs.triu_indices",
21851        torch_opinfo_name="triu_indices",
21852        # the implementation uses torch.stack, which violates view consistency
21853        validate_view_consistency=False,
21854        skips=(
21855            # skip these tests since we have non-tensor inputs
21856            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
21857            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
21858            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
21859            DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
21860        )),
21861    PythonRefInfo(
21862        "_refs.tril_indices",
21863        torch_opinfo_name="tril_indices",
21864        # the implementation uses torch.stack, which violates view consistency
21865        validate_view_consistency=False,
21866        skips=(
21867            # skip these tests since we have non-tensor inputs
21868            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'),
21869            DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'),
21870            DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
21871            DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'),
21872        )),
21873    PythonRefInfo(
21874        "_refs.meshgrid",
21875        torch_opinfo_name="meshgrid",
21876        torch_opinfo_variant_name="list_of_tensors",
21877    ),
21878    PythonRefInfo(
21879        "_refs.movedim",
21880        aliases=('moveaxis',),
21881        torch_opinfo_name="movedim",
21882    ),
21883    PythonRefInfo(
21884        "_refs.bucketize",
21885        torch_opinfo_name="bucketize",
21886        skips=(
21887            # RuntimeError: It appears that you're trying to get value out of a tracing tensor with
21888            #  aten._local_scalar_dense.default - erroring out! [...]
21889            # triggered by mid_val = boundaries[mid]
21890            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"),
21891        )
21892    ),
21893    PythonRefInfo(
21894        "_refs.equal",
21895        torch_opinfo_name="equal",
21896        skips=(
21897            # RuntimeError: Cannot cast FakeTensor to number
21898            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',),
21899        )
21900    ),
21901    ElementwiseUnaryPythonRefInfo(
21902        "_refs.atan",
21903        torch_opinfo_name="atan",
21904        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
21905        skips=(
21906            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21907                         'test_reference_numerics_extremal',
21908                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21909            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21910                         'test_reference_numerics_large',
21911                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21912            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21913                         'test_reference_numerics_small',
21914                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21915            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21916                         'test_reference_numerics_extremal',
21917                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
21918                         active_if=IS_WINDOWS),
21919            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21920                         'test_reference_numerics_large',
21921                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
21922                         active_if=IS_WINDOWS),
21923        ),
21924    ),
21925    ElementwiseUnaryPythonRefInfo(
21926        "_refs.atanh",
21927        torch_opinfo_name="atanh",
21928        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
21929        skips=(
21930            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21931                         'test_reference_numerics_small',
21932                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21933            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21934                         'test_reference_numerics_extremal',
21935                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21936            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21937                         'test_reference_numerics_large',
21938                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
21939            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21940                         'test_reference_numerics_extremal',
21941                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
21942                         active_if=IS_WINDOWS),
21943            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21944                         'test_reference_numerics_large',
21945                         device_type='cuda', dtypes=[torch.cfloat],
21946                         active_if=IS_WINDOWS),
21947        ),
21948    ),
21949    ElementwiseUnaryPythonRefInfo(
21950        "_refs.bitwise_not",
21951        torch_opinfo_name="bitwise_not",
21952    ),
21953    ElementwiseUnaryPythonRefInfo(
21954        "_refs.ceil",
21955        torch_opinfo_name="ceil",
21956        # Fails on int32
21957        # https://github.com/pytorch/pytorch/issues/85258
21958    ),
21959    PythonRefInfo(
21960        "_refs.item",
21961        torch_opinfo_name="item",
21962        skips=(
21963            # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number
21964            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
21965            # ValueError: Can't convert a tensor with 10 elements to a number!
21966            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),),
21967    ),
21968    ElementwiseUnaryPythonRefInfo(
21969        "_refs.conj_physical",
21970        torch_opinfo_name="conj_physical",
21971    ),
21972    ElementwiseUnaryPythonRefInfo(
21973        "_refs.cos",
21974        torch_opinfo_name="cos",
21975        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
21976        skips=(
21977            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21978                         'test_reference_numerics_large',
21979                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
21980                         active_if=IS_WINDOWS),
21981            # This fails on CUDA but passes on ROCm
21982            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21983                         'test_reference_numerics_large',
21984                         dtypes=(torch.cdouble,), device_type='cuda'),
21985            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21986                         'test_reference_numerics_extremal',
21987                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
21988            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
21989                         'test_reference_numerics_extremal',
21990                         device_type='cpu',
21991                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
21992            # AssertionError: Tensor-likes are not close!
21993            # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed)
21994            # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
21995            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
21996                         'test_reference_numerics_large',
21997                         device_type='cuda',
21998                         dtypes=(torch.chalf,), active_if=IS_WINDOWS),
21999        ),
22000    ),
22001    ElementwiseUnaryPythonRefInfo(
22002        "_refs.cosh",
22003        torch_opinfo_name="cosh",
22004        skips=(
22005            # Reference: https://github.com/pytorch/pytorch/issues/48641
22006            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22007                         'test_reference_numerics_large',
22008                         device_type='cpu', dtypes=[torch.int8]),
22009            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22010                         'test_reference_numerics_large',
22011                         dtypes=[torch.cdouble]),
22012            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22013                         'test_reference_numerics_extremal',
22014                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
22015            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22016                         'test_reference_numerics_large',
22017                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
22018            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22019                         'test_reference_numerics_extremal',
22020                         device_type='cpu',
22021                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
22022            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22023                         'test_reference_numerics_large',
22024                         device_type='cpu',
22025                         dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
22026            # AssertionError: Tensor-likes are not close!
22027            # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed)
22028            # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed)
22029            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
22030                         'test_reference_numerics_large',
22031                         device_type='cuda',
22032                         dtypes=(torch.chalf,), active_if=IS_WINDOWS),
22033        ),
22034    ),
22035    ElementwiseUnaryPythonRefInfo(
22036        "_refs.digamma",
22037        torch_opinfo_name="digamma",
22038    ),
22039    ElementwiseUnaryPythonRefInfo(
22040        "_refs.erf",
22041        torch_opinfo_name="erf",
22042    ),
22043    ElementwiseUnaryPythonRefInfo(
22044        "_refs.erfinv",
22045        torch_opinfo_name="erfinv",
22046        decorators=(precisionOverride({torch.float16: 1e-2,
22047                                       torch.bfloat16: 1e-2,
22048                                       torch.float32: 1e-4}),),
22049        skips=(
22050            # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611
22051            DecorateInfo(
22052                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22053                'test_reference_numerics_extremal',
22054                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
22055            DecorateInfo(
22056                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22057                'test_reference_numerics_large',
22058                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
22059            DecorateInfo(
22060                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22061                'test_reference_numerics_small',
22062                active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")),
22063        ),
22064    ),
22065    ElementwiseUnaryPythonRefInfo(
22066        "_refs.erfc",
22067        torch_opinfo_name="erfc",
22068    ),
22069    ElementwiseUnaryPythonRefInfo(
22070        "_refs.exp",
22071        torch_opinfo_name="exp",
22072        skips=(
22073            # Reference: https://github.com/pytorch/pytorch/issues/48010
22074            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22075                         'test_reference_numerics_extremal',
22076                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
22077            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22078                         'test_reference_numerics_large',
22079                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
22080        ),
22081    ),
22082    ElementwiseUnaryPythonRefInfo(
22083        "_refs.expm1",
22084        torch_opinfo_name="expm1",
22085    ),
22086    ElementwiseUnaryPythonRefInfo(
22087        "_refs.exp2",
22088        torch_opinfo_name="exp2",
22089        skips=(
22090            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22091                         'test_reference_numerics_large',
22092                         dtypes=[torch.cdouble]),
22093            # Reference: https://github.com/pytorch/pytorch/issues/48010
22094            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22095                         'test_reference_numerics_extremal',
22096                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
22097            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22098                         'test_reference_numerics_large',
22099                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
22100        ),
22101    ),
22102    ElementwiseUnaryPythonRefInfo(
22103        "_refs.fill",
22104        torch_opinfo_name="fill",
22105        supports_out=True,
22106    ),
22107    ElementwiseUnaryPythonRefInfo(
22108        "_refs.floor",
22109        torch_opinfo_name="floor",
22110        # Fails on int32
22111        # https://github.com/pytorch/pytorch/issues/85258
22112    ),
22113    ElementwiseUnaryPythonRefInfo(
22114        "_refs.frexp",
22115        torch_opinfo_name="frexp",
22116        # Skipped due to numerical failures on Windows CI.
22117        # This is also skipped in frexp earlier in the file.
22118        # The same test is skipped for the frexp OpInfo earlier in this file.
22119            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
22120                         active_if=IS_WINDOWS),
22121        ),
22122    ),
22123    ElementwiseUnaryPythonRefInfo(
22124        "_refs.frac",
22125        torch_opinfo_name="frac",
22126        skips=(
22127            DecorateInfo(
22128                unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22129                'test_reference_numerics_extremal',
22130                dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)),
22131        ),
22132    ),
22133    ElementwiseUnaryPythonRefInfo(
22134        "_refs.imag",
22135        torch_opinfo_name="imag",
22136    ),
22137    ElementwiseUnaryPythonRefInfo(
22138        "_refs.isfinite",
22139        torch_opinfo_name="isfinite",
22140        supports_out=True,
22141    ),
22142    ElementwiseUnaryPythonRefInfo(
22143        "_refs.isinf",
22144        torch_opinfo_name="isinf",
22145        supports_out=True,
22146    ),
22147    ElementwiseUnaryPythonRefInfo(
22148        "_refs.isposinf",
22149        torch_opinfo_name="isposinf",
22150        supports_out=True,
22151    ),
22152    ElementwiseUnaryPythonRefInfo(
22153        "_refs.isneginf",
22154        torch_opinfo_name="isneginf",
22155        supports_out=True,
22156    ),
22157    ElementwiseUnaryPythonRefInfo(
22158        "_refs.isnan",
22159        torch_opinfo_name="isnan",
22160        supports_out=True,
22161    ),
22162    ElementwiseUnaryPythonRefInfo(
22163        "_refs.isreal",
22164        torch_opinfo_name="isreal",
22165        supports_out=True,
22166    ),
22167    ElementwiseUnaryPythonRefInfo(
22168        "_refs.i0",
22169        torch_opinfo_name="i0",
22170        decorators=(precisionOverride({torch.bfloat16: 3e-1,
22171                                       torch.float16: 5e-1}),),
22172        skips=(
22173            DecorateInfo(unittest.skip("Skipped!"),
22174                         'TestUnaryUfuncs',
22175                         'test_reference_numerics_large',
22176                         dtypes=(torch.int8,)),
22177        ),
22178    ),
22179    ElementwiseUnaryPythonRefInfo(
22180        "_refs.lgamma",
22181        torch_opinfo_name="lgamma",
22182        decorators=(precisionOverride({torch.float16: 7e-1}),),
22183        skips=(
22184            # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
22185            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22186                         'test_reference_numerics_extremal',
22187                         dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
22188            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22189                         'test_reference_numerics_large',
22190                         dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
22191        ),
22192    ),
22193    ElementwiseUnaryPythonRefInfo(
22194        "_refs.special.multigammaln",
22195        torch_opinfo_name="mvlgamma",
22196        torch_opinfo_variant_name="mvlgamma_p_1",
22197        skips=skips_mvlgamma(),
22198        decorators=(
22199            DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs',
22200                         'test_reference_numerics_large'),
22201            DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs',
22202                         'test_reference_numerics_large'),
22203        ),
22204    ),
22205    ElementwiseUnaryPythonRefInfo(
22206        "_refs.special.multigammaln",
22207        torch_opinfo_name="mvlgamma",
22208        torch_opinfo_variant_name="mvlgamma_p_3",
22209        skips=skips_mvlgamma(),
22210    ),
22211    ElementwiseUnaryPythonRefInfo(
22212        "_refs.special.multigammaln",
22213        torch_opinfo_name="mvlgamma",
22214        torch_opinfo_variant_name="mvlgamma_p_5",
22215        skips=skips_mvlgamma(),
22216    ),
22217    ElementwiseUnaryPythonRefInfo(
22218        "_refs.log",
22219        torch_opinfo_name="log",
22220        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
22221        skips=(
22222            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22223                         'test_reference_numerics_extremal',
22224                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22225                         active_if=IS_WINDOWS),
22226        ),
22227    ),
22228    ElementwiseUnaryPythonRefInfo(
22229        "_refs.log1p",
22230        torch_opinfo_name="log1p",
22231    ),
22232    ElementwiseUnaryPythonRefInfo(
22233        "_refs.log10",
22234        torch_opinfo_name="log10",
22235        decorators=(precisionOverride({torch.bfloat16: 5e-2}),),
22236        skips=(
22237            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22238                         'test_reference_numerics_extremal',
22239                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22240                         active_if=IS_WINDOWS),
22241        ),
22242    ),
22243    ElementwiseUnaryPythonRefInfo(
22244        "_refs.log2",
22245        torch_opinfo_name="log2",
22246        decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
22247        skips=(
22248            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22249                         'test_reference_numerics_extremal',
22250                         dtypes=[torch.cfloat, torch.cdouble]),
22251        ),
22252    ),
22253    PythonRefInfo(
22254        "_refs.logsumexp",
22255        torch_opinfo_name="logsumexp",
22256        # When keepdim=False, logsumexp uses a squeeze operation that is not
22257        # yet exposed in nvFuser's Python API.
22258    ),
22259    PythonRefInfo(
22260        "_refs.log_softmax",
22261        torch_opinfo_name="log_softmax",
22262        torch_opinfo_variant_name="with_dtype",
22263    ),
22264    ElementwiseUnaryPythonRefInfo(
22265        "_refs.nan_to_num",
22266        torch_opinfo_name="nan_to_num",
22267    ),
22268    ElementwiseUnaryPythonRefInfo(
22269        "_refs.neg",
22270        torch_opinfo_name="neg",
22271    ),
22272    ElementwiseUnaryPythonRefInfo(
22273        "_refs.positive",
22274        torch_opinfo_name="positive",
22275    ),
22276    ElementwiseUnaryPythonRefInfo(
22277        "_refs.real",
22278        torch_opinfo_name="real",
22279    ),
22280    ElementwiseUnaryPythonRefInfo(
22281        "_refs.reciprocal",
22282        torch_opinfo_name="reciprocal",
22283        skips=(
22284            # Reference: https://github.com/pytorch/pytorch/issues/45690
22285            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22286                         'test_reference_numerics_extremal',
22287                         dtypes=[torch.cfloat, torch.cdouble]),
22288        ),
22289    ),
22290    ElementwiseUnaryPythonRefInfo(
22291        "_refs.round",
22292        torch_opinfo_name="round",
22293        # Fails on int32
22294        # https://github.com/pytorch/pytorch/issues/85258
22295        skips=(
22296            DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
22297                         "TestUnaryUfuncs", "test_reference_numerics_extremal",
22298                         device_type="cuda"),
22299            DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}),
22300                         "TestUnaryUfuncs", "test_reference_numerics_normal",
22301                         device_type="cuda"),
22302        ),
22303    ),
22304    ElementwiseUnaryPythonRefInfo(
22305        "_refs.rsqrt",
22306        torch_opinfo_name="rsqrt",
22307        decorators=(precisionOverride({torch.half: 5e-2}),),
22308        skips=(
22309            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22310                         'test_reference_numerics_extremal',
22311                         dtypes=(torch.cfloat, torch.cdouble)),
22312            # AssertionError: Tensor-likes are not close!
22313            # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed)
22314            # Greatest relative difference: nan at index (700,) (up to 0.001 allowed)
22315            DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs',
22316                         'test_reference_numerics_large',
22317                         dtypes=(torch.chalf,)),
22318        ),
22319    ),
22320    ElementwiseUnaryPythonRefInfo(
22321        "_refs.sigmoid",
22322        torch_opinfo_name="sigmoid",
22323        aliases=('_refs.special.expit',),
22324        # Reference: https://github.com/pytorch/pytorch/issues/56012
22325        handles_complex_extremal_values=False,
22326        handles_large_floats=False,
22327        decorators=(precisionOverride({torch.float16: 1e-2,
22328                                       torch.complex64: 1e-1,
22329                                       torch.bfloat16: 1e-2}),),
22330        skips=(
22331            # Reference: https://github.com/pytorch/pytorch/issues/56012
22332            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22333                         'test_reference_numerics_extremal',
22334                         dtypes=[torch.complex64, torch.cdouble]),
22335            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22336                         'test_reference_numerics_large',
22337                         dtypes=[torch.chalf, torch.complex64, torch.cdouble])
22338        ),
22339    ),
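    # The handles_complex_extremal_values=False / handles_large_floats=False flags
    # above (roughly) exclude complex extremal values and very large floats from
    # the unary-ufunc numerics comparisons for this op instead of skipping the
    # tests outright; see the issue linked above for why sigmoid needs both.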
22340    ElementwiseUnaryPythonRefInfo(
22341        "_refs.sign",
22342        torch_opinfo_name="sign",
22343        skips=(
22344            # Reference: https://github.com/pytorch/pytorch/issues/41245
22345            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22346                         'test_reference_numerics_extremal',
22347                         dtypes=[torch.bfloat16, torch.float16, torch.float32,
22348                                 torch.float64]),
22349        ),
22350    ),
22351    ElementwiseUnaryPythonRefInfo(
22352        "_refs.sgn",
22353        torch_opinfo_name="sgn",
22354        # This is an issue with the vectorised abs on CPU
22355        handles_complex_extremal_values=False,
22356        handles_large_floats=False,
22357        skips=(
22358            # Reference: https://github.com/pytorch/pytorch/issues/41245
22359            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22360                         'test_reference_numerics_extremal',
22361                         dtypes=[torch.bfloat16, torch.float16, torch.float32,
22362                                 torch.float64]),
22363        ),
22364    ),
22365    ElementwiseUnaryPythonRefInfo(
22366        "_refs.signbit",
22367        torch_opinfo_name="signbit",
22368    ),
22369    ElementwiseUnaryPythonRefInfo(
22370        "_refs.sin",
22371        torch_opinfo_name="sin",
22372        decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
22373        skips=(
22374            # Fails on CUDA but passes on ROCm
22375            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22376                         'test_reference_numerics_large',
22377                         dtypes=(torch.cdouble,), device_type='cuda'),
22378            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22379                         'test_reference_numerics_extremal',
22380                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
22381                         active_if=IS_WINDOWS),
22382            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22383                         'test_reference_numerics_large',
22384                         dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu',
22385                         active_if=IS_WINDOWS),
22386        ),
22387    ),
22388    ElementwiseUnaryPythonRefInfo(
22389        "_refs.sinc",
22390        torch_opinfo_name="sinc",
22391        decorators=(precisionOverride({torch.bfloat16: 1e-2,
22392                                       torch.float16: 1e-2}),),
22393        skips=(
22394            # Reference: https://github.com/pytorch/pytorch/issues/49133
22395            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22396                         'test_reference_numerics_small',
22397                         dtypes=[torch.cfloat]),
22398        ),
22399    ),
22400    ElementwiseUnaryPythonRefInfo(
22401        "_refs.sinh",
22402        torch_opinfo_name="sinh",
22403        decorators=(precisionOverride({torch.float16: 1e-2}),),
22404        skips=(
22405            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22406                         'test_reference_numerics_extremal',
22407                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22408                         active_if=(IS_MACOS or IS_WINDOWS)),
22409            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22410                         'test_reference_numerics_large',
22411                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22412                         active_if=(IS_MACOS or IS_WINDOWS)),
22413            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22414                         'test_reference_numerics_large',
22415                         dtypes=(torch.cdouble,)),
22416            # Reference: https://github.com/pytorch/pytorch/issues/48641
22417            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22418                         'test_reference_numerics_large',
22419                         device_type='cpu', dtypes=[torch.int8]),
22420        ),
22421    ),
22422    PythonRefInfo(
22423        "_refs.softmax",
22424        torch_opinfo_name="softmax",
22425        torch_opinfo_variant_name="with_dtype",
22426    ),
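    # A PythonRefInfo resolves most of its metadata (sample inputs, dtypes,
    # supported features) from the torch OpInfo named by torch_opinfo_name,
    # optionally narrowed by torch_opinfo_variant_name, so the entry above reuses
    # the OpInfo("softmax", variant_test_name="with_dtype") definition rather than
    # declaring its own samples.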
22427    ElementwiseUnaryPythonRefInfo(
22428        "_refs.sqrt",
22429        torch_opinfo_name="sqrt",
22430        decorators=(
22431            precisionOverride({torch.bfloat16: 7e-2}),
22432            DecorateInfo(
22433                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
22434                'TestUnaryUfuncs', 'test_reference_numerics_large'),
22435        ),
22436        skips=(
22437            # Reference: https://github.com/pytorch/pytorch/issues/47358
22438            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22439                         'test_reference_numerics_large',
22440                         device_type='cpu', dtypes=(torch.cfloat, torch.cdouble),
22441                         active_if=IS_MACOS),
22442            # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436
22443            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22444                         'test_reference_numerics_large',
22445                         dtypes=(torch.bfloat16,)),
22446        ),
22447    ),
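    # Inside `decorators`, a bare decorator such as precisionOverride applies to
    # every test run against the op, while a DecorateInfo-wrapped decorator (like
    # the chalf toleranceOverride above) only applies to the named test; mixing the
    # two, as above, is the usual pattern in this list.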
22448    ElementwiseUnaryPythonRefInfo(
22449        "_refs.square",
22450        torch_opinfo_name="square",
22451        decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),),
22452        skips=(
22453            # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation
22454            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)),
22455            # Reference: https://github.com/pytorch/pytorch/issues/52549
22456            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22457                         'test_reference_numerics_large',
22458                         dtypes=[torch.cfloat, torch.cdouble]),
22459            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22460                         'test_reference_numerics_extremal',
22461                         device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
22462        ),
22463    ),
22464    ElementwiseUnaryPythonRefInfo(
22465        "_refs.tan",
22466        torch_opinfo_name="tan",
22467        decorators=[
22468            DecorateInfo(
22469                toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}),
22470                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
22471        ],
22472        skips=(
22473            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22474                         'test_reference_numerics_extremal',
22475                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22476                         active_if=(IS_MACOS or IS_WINDOWS)),
22477            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22478                         'test_reference_numerics_large',
22479                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22480                         active_if=(IS_MACOS or IS_WINDOWS)),
22481        )
22482    ),
22483    ElementwiseUnaryPythonRefInfo(
22484        "_refs.tanh",
22485        torch_opinfo_name="tanh",
22486        decorators=[
22487            DecorateInfo(
22488                toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}),
22489                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
22490        ],
22491        skips=(
22492            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22493                         'test_reference_numerics_extremal',
22494                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22495                         active_if=(IS_MACOS or IS_WINDOWS)),
22496            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22497                         'test_reference_numerics_large',
22498                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
22499                         active_if=(IS_MACOS or IS_WINDOWS)),
22500        ),
22501    ),
22502    ElementwiseUnaryPythonRefInfo(
22503        "_refs.trunc",
22504        torch_opinfo_name="trunc",
22505        # Fails on int32
22506        # https://github.com/pytorch/pytorch/issues/85258
22507    ),
22508    PythonRefInfo(
22509        "_refs.special.log_softmax",
22510        torch_opinfo_name="log_softmax",  # alias
22511        torch_opinfo_variant_name="with_dtype",
22512        supports_out=False,
22513    ),
22514    PythonRefInfo(
22515        "_refs.special.softmax",
22516        torch_opinfo_name="softmax",  # alias
22517        torch_opinfo_variant_name="with_dtype",
22518        supports_out=False,
22519    ),
22520    #
22521    # Elementwise Unary Special OpInfos
22522    #
22523    ElementwiseUnaryPythonRefInfo(
22524        "_refs.special.logit",
22525        torch_opinfo_name="logit",
22526    ),
22527    #
22528    # Elementwise Unary nn.functional OpInfos
22529    #
22530    PythonRefInfo(
22531        "_refs.nn.functional.alpha_dropout",
22532        torch_opinfo_name="nn.functional.alpha_dropout",
22533        decorators=(
22534            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22535                         'TestCommon',
22536                         'test_python_ref'),
22537            # AssertionError: Tensor-likes are not close!
22538            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22539                         'TestCommon',
22540                         'test_python_ref_torch_fallback'),
22541            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22542                         'TestCommon',
22543                         'test_python_ref_executor', device_type='cuda'),
22544            # AssertionError: Tensor-likes are not close!
22545            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22546                         'TestMathBits',
22547                         'test_neg_view'),
22548            # AssertionError: Tensor-likes are not close!
22549            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22550                         'TestCommon',
22551                         'test_compare_cpu'),
22552        )
22553    ),
22554    ElementwiseUnaryPythonRefInfo(
22555        "_refs.nn.functional.celu",
22556        torch_opinfo_name="nn.functional.celu",
22557        supports_out=True,
22558    ),
22559    PythonRefInfo(
22560        "_refs.nn.functional.channel_shuffle",
22561        torch_opinfo_name="nn.functional.channel_shuffle",
22562        supports_out=True,
22563    ),
22564    ElementwiseUnaryPythonRefInfo(
22565        "_refs.nn.functional.threshold",
22566        torch_opinfo_name="nn.functional.threshold",
22567        supports_out=True,
22568    ),
22569    PythonRefInfo(
22570        "_refs.nn.functional.dropout",
22571        torch_opinfo_name="nn.functional.dropout",
22572        decorators=(
22573            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22574                         'TestCommon',
22575                         'test_python_ref'),
22576            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22577                         'TestCommon',
22578                         'test_python_ref_torch_fallback'),
22579            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22580                         'TestCommon',
22581                         'test_out'),
22582            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22583                         'TestCommon',
22584                         'test_out_warning'),
22585            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22586                         'TestMathBits',
22587                         'test_conj_view'),
22588            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22589                         'TestMathBits',
22590                         'test_neg_conj_view'),
22591            DecorateInfo(unittest.skip("Expected: dropout is not comparable"),
22592                         'TestMathBits',
22593                         'test_neg_view'),
22594            # dropout is not comparable
22595            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
22596            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
22597        )
22598    ),
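    # Dropout output depends on the RNG state, so the entries above skip or xfail
    # essentially every test that value-compares the ref against eager or against a
    # CPU run. Where reproducibility is needed, this file typically pins the seed
    # before calling the op, roughly:
    #   wrapper_set_seed(torch.nn.functional.dropout,
    #                    make_tensor((S,), dtype=torch.float32, device='cpu'), p=0.5)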
22599    ElementwiseUnaryPythonRefInfo(
22600        "_refs.nn.functional.elu",
22601        torch_opinfo_name="nn.functional.elu",
22602        supports_out=True,
22603        decorators=[
22604            DecorateInfo(
22605                toleranceOverride({
22606                    torch.float16: tol(atol=1e-03, rtol=1.2e-03),
22607                    torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03)
22608                }),
22609                'TestUnaryUfuncs', device_type='cuda',
22610            ), ],
22611    ),
22612    ElementwiseUnaryPythonRefInfo(
22613        "_refs.nn.functional.hardtanh",
22614        torch_opinfo_name="nn.functional.hardtanh",
22615        supports_out=True,
22616    ),
22617    PythonRefInfo(  # TODO: Port this to an UnaryOpInfo
22618        "_refs.nn.functional.gelu",
22619        torch_opinfo_name="nn.functional.gelu",
22620    ),
22621    PythonRefInfo(
22622        "_refs.nn.functional.layer_norm",
22623        torch_opinfo_name="nn.functional.layer_norm",
22624        skips=(
22625            # Reference result was farther (3.5762786809723224e-07) from the precise computation
22626            # than the torch result was (2.5068410824946596e-07)!
22627            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
22628                         dtypes=(torch.float32,), device_type='cpu'),
22629        ),
22630    ),
22631    PythonRefInfo(
22632        "_refs.nn.functional.glu",
22633        torch_opinfo_name="nn.functional.glu",
22634        supports_out=True,
22635    ),
22636    PythonRefInfo(
22637        "_refs.nn.functional.pairwise_distance",
22638        torch_opinfo_name="nn.functional.pairwise_distance",
22639        supports_out=True,
22640    ),
22641    PythonRefInfo(
22642        "_refs.nn.functional.pdist",
22643        torch_opinfo_name="nn.functional.pdist",
22644        supports_out=True,
22645        skips=(
22646            # RuntimeError: no _refs support for torch.Tensor.index_select
22647            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
22648            # Reference result was farther (1.946091651916504e-05) from the precise
22649            # computation than the torch result was (1.1920928955078125e-06)!
22650            DecorateInfo(
22651                unittest.expectedFailure,
22652                'TestCommon',
22653                'test_python_ref_torch_fallback',
22654                dtypes=(torch.float32,),
22655                device_type='cpu',
22656            ),
22657        )),
22658    PythonRefInfo(
22659        "_refs.nn.functional.leaky_relu",
22660        torch_opinfo_name="nn.functional.leaky_relu",
22661        supports_out=True,
22662    ),
22663    PythonRefInfo(
22664        "_refs.nn.functional.log_softmax",
22665        torch_opinfo_name="log_softmax",  # alias
22666        torch_opinfo_variant_name="with_dtype",
22667        supports_out=False,
22668    ),
22669    PythonRefInfo(
22670        "_refs.nn.functional.pixel_shuffle",
22671        torch_opinfo_name="nn.functional.pixel_shuffle",
22672    ),
22673    PythonRefInfo(
22674        "_refs.nn.functional.pixel_unshuffle",
22675        torch_opinfo_name="nn.functional.pixel_unshuffle",
22676    ),
22677    PythonRefInfo(
22678        "_refs.nn.functional.poisson_nll_loss",
22679        torch_opinfo_name="nn.functional.poisson_nll_loss",
22680    ),
22681    ElementwiseUnaryPythonRefInfo(
22682        "_refs.nn.functional.prelu",
22683        torch_opinfo_name="nn.functional.prelu",
22684    ),
22685    ElementwiseUnaryPythonRefInfo(
22686        "_refs.nn.functional.relu",
22687        torch_opinfo_name="nn.functional.relu",
22688        supports_out=True,
22689    ),
22690    ElementwiseUnaryPythonRefInfo(
22691        "_refs.nn.functional.relu6",
22692        torch_opinfo_name="nn.functional.relu6",
22693        supports_out=True,
22694    ),
22695    ElementwiseUnaryPythonRefInfo(
22696        "_refs.nn.functional.mish",
22697        torch_opinfo_name="nn.functional.mish",
22698        supports_out=True,
22699        decorators=[
22700            DecorateInfo(
22701                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
22702                'TestUnaryUfuncs',), ],
22703    ),
22704    ElementwiseUnaryPythonRefInfo(
22705        "_refs.nn.functional.selu",
22706        torch_opinfo_name="nn.functional.selu",
22707        supports_out=True,
22708        decorators=[
22709            DecorateInfo(
22710                toleranceOverride({
22711                    torch.float16: tol(atol=1e-2, rtol=1.8e-2),
22712                    torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2)
22713                }),
22714                'TestUnaryUfuncs', device_type='cuda',
22715            ), ],
22716    ),
22717    PythonRefInfo(
22718        "_refs.nn.functional.softmax",
22719        torch_opinfo_name="softmax",  # alias
22720        torch_opinfo_variant_name="with_dtype",
22721        supports_out=False,
22722    ),
22723    PythonRefInfo(
22724        "_refs.nn.functional.softmin",
22725        torch_opinfo_name="nn.functional.softmin",
22726        torch_opinfo_variant_name="with_dtype",
22727        supports_out=False,
22728    ),
22729    ElementwiseUnaryPythonRefInfo(
22730        "_refs.nn.functional.softplus",
22731        torch_opinfo_name="nn.functional.softplus",
22732    ),
22733    PythonRefInfo(
22734        "_refs.nn.functional.l1_loss",
22735        torch_opinfo_name="nn.functional.l1_loss",
22736    ),
22737    PythonRefInfo(
22738        "_refs.nn.functional.margin_ranking_loss",
22739        torch_opinfo_name="nn.functional.margin_ranking_loss",
22740    ),
22741    PythonRefInfo(
22742        "_refs.nn.functional.mse_loss",
22743        torch_opinfo_name="nn.functional.mse_loss",
22744    ),
22745    PythonRefInfo(
22746        "_refs.nn.functional.smooth_l1_loss",
22747        torch_opinfo_name="nn.functional.smooth_l1_loss",
22748    ),
22749    PythonRefInfo(
22750        "_refs.nn.functional.hinge_embedding_loss",
22751        torch_opinfo_name="nn.functional.hinge_embedding_loss",
22752        skips=(
22753            # Reference result was farther (0.29562714856322714) from the precise
22754            # computation than the torch result was (0.20437285143677286)!
22755            DecorateInfo(
22756                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
22757                dtypes=(torch.bfloat16,), device_type="cpu"
22758            ),
22759        ),
22760    ),
22761    PythonRefInfo(
22762        "_refs.nn.functional.nll_loss",
22763        torch_opinfo_name="nn.functional.nll_loss",
22764        # The corresponding PyTorch op doesn't support out.  But the ref is
22765        # registered as a decomp and ATen has an out variant.
22766        supports_out=True,
22767        # For simpler indexing, we flatten target indices, then reshape the result tensor.
22768        # This creates view state that is inconsistent with the reference impl.
22769        validate_view_consistency=False,
22770        skips=(
22771            # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out!
22772            DecorateInfo(
22773                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda"
22774            ),
22775        ),
22776    ),
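    # validate_view_consistency=False disables the check that the ref returns a
    # view exactly when the torch op does; as the comment above explains, the
    # nll_loss ref flattens and reshapes internally, so its output aliasing can
    # differ from eager even though the values agree.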
22777    PythonRefInfo(
22778        "_refs.nn.functional.huber_loss",
22779        torch_opinfo_name="nn.functional.huber_loss",
22780        # The corresponding PyTorch op doesn't support out.  But the ref is
22781        # registered as a decomp and ATen has an out variant.
22782        supports_out=True,
22783    ),
22784    ElementwiseUnaryPythonRefInfo(
22785        "_refs.nn.functional.tanhshrink",
22786        torch_opinfo_name="nn.functional.tanhshrink",
22787        decorators=[
22788            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
22789                         'test_reference_numerics_normal',
22790                         device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
22791            DecorateInfo(
22792                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02),
22793                                   torch.complex64: tol(atol=6e-04, rtol=1e-05)}),
22794                'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'),
22795        ],
22796        skips=(
22797            # In each case, PyTorch produces a NaN while NumPy does not
22798            DecorateInfo(unittest.skip("Fails on some jobs, works on others!"),
22799                         'TestUnaryUfuncs', "test_reference_numerics_large",
22800                         dtypes=(torch.complex64, torch.complex128),
22801                         active_if=(IS_MACOS)),
22802            DecorateInfo(unittest.skip("Fails on some jobs, works on others!"),
22803                         'TestUnaryUfuncs', "test_reference_numerics_extremal",
22804                         dtypes=(torch.complex64, torch.complex128),
22805                         device_type='cpu',
22806                         active_if=(IS_MACOS or IS_WINDOWS)),
22807        ),
22808    ),
22809    ElementwiseUnaryPythonRefInfo(
22810        "_refs.nn.functional.hardshrink",
22811        torch_opinfo_name="nn.functional.hardshrink",
22812    ),
22813    ElementwiseUnaryPythonRefInfo(
22814        "_refs.nn.functional.softshrink",
22815        torch_opinfo_name="nn.functional.softshrink",
22816    ),
22817    #
22818    # Elementwise Binary Reference OpInfos
22819    #
22820    ElementwiseBinaryPythonRefInfo(
22821        "_refs.add",
22822        torch_opinfo_name="add",
22823        # https://github.com/pytorch/pytorch/issues/76944
22824        supports_two_python_scalars=True,
22825        supports_one_python_scalar=True,
22826        decorators=(
22827            DecorateInfo(
22828                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
22829                'TestBinaryUfuncs', 'test_reference_numerics'),
22830        ),
22831        skips=(
22832            DecorateInfo(unittest.skip("Skipped!"),
22833                         'TestBinaryUfuncs',
22834                         'test_reference_numerics_extremal_values',
22835                         dtypes=(torch.complex64, torch.complex128)),
22836        ),
22837    ),
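    # supports_one_python_scalar / supports_two_python_scalars opt the entry above
    # into extra binary samples where one or both operands are plain Python numbers
    # (see the linked issue), so the generated calls include, roughly:
    #   refs.add(2.0, make_tensor((S,), dtype=torch.float32, device='cpu'))
    #   refs.add(2.0, 3.0)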
22838    ElementwiseBinaryPythonRefInfo(
22839        "_refs.atan2",
22840        torch_opinfo_name="atan2",
22841    ),
22842    ElementwiseBinaryPythonRefInfo(
22843        "_refs.bitwise_and",
22844        torch_opinfo_name="bitwise_and",
22845    ),
22846    ElementwiseBinaryPythonRefInfo(
22847        "_refs.bitwise_left_shift",
22848        torch_opinfo_name="bitwise_left_shift",
22849        skips=(
22850            # https://github.com/pytorch/pytorch/issues/70904
22851            DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
22852        ),
22853    ),
22854    ElementwiseBinaryPythonRefInfo(
22855        "_refs.bitwise_right_shift",
22856        torch_opinfo_name="bitwise_right_shift",
22857        skips=(
22858            # https://github.com/pytorch/pytorch/issues/70904
22859            DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'),
22860        ),
22861    ),
22862    ElementwiseBinaryPythonRefInfo(
22863        "_refs.bitwise_or",
22864        torch_opinfo_name="bitwise_or",
22865    ),
22866    ElementwiseBinaryPythonRefInfo(
22867        "_refs.bitwise_xor",
22868        torch_opinfo_name="bitwise_xor",
22869    ),
22870    ElementwiseBinaryPythonRefInfo(
22871        "_refs.copysign",
22872        torch_opinfo_name="copysign",
22873        skips=(
22874            # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu!
22875            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'),
22876            # FIXME output 0: meta disagrees with real impl
22877            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
22878        )
22879    ),
22880    ElementwiseBinaryPythonRefInfo(
22881        "_refs.div",
22882        torch_opinfo_name="div",
22883        torch_opinfo_variant_name="no_rounding_mode",
22884        # https://github.com/pytorch/pytorch/issues/76944
22885        supports_two_python_scalars=True,
22886        supports_one_python_scalar=True,
22887        skips=(
22888            # NotImplementedError: argument of type: <class 'complex'>
22889            DecorateInfo(
22890                unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor',
22891                dtypes=(torch.complex32, torch.complex64, torch.complex128,)
22892            ),
22893            # Reference result was farther (0.7433461727239705) from the precise
22894            # computation than the torch result was (nan)!
22895            DecorateInfo(
22896                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
22897                dtypes=(torch.complex32,), device_type="cuda"
22898            ),
22899            # Reference result was farther (0.7433461727239705) from the precise
22900            # computation than the torch result was (nan)!
22901            DecorateInfo(
22902                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
22903                dtypes=(torch.complex32,), device_type="cuda"
22904            ),
22905        ),
22906    ),
22907    ElementwiseBinaryPythonRefInfo(
22908        "_refs.div",
22909        torch_opinfo_name="div",
22910        torch_opinfo_variant_name="trunc_rounding",
22911        # https://github.com/pytorch/pytorch/issues/76944
22912        supports_two_python_scalars=True,
22913        supports_one_python_scalar=True,
22914        decorators=(
22915            # See https://github.com/pytorch/pytorch/issues/111126
22916            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
22917        ),
22918    ),
22919    ElementwiseBinaryPythonRefInfo(
22920        "_refs.div",
22921        torch_opinfo_name="div",
22922        torch_opinfo_variant_name="floor_rounding",
22923        # https://github.com/pytorch/pytorch/issues/76944
22924        supports_two_python_scalars=True,
22925        supports_one_python_scalar=True,
22926        decorators=(
22927            # See https://github.com/pytorch/pytorch/issues/111126
22928            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
22929            # Reference result was farther (nan) from the precise computation than the
22930            # torch result was (inf)!
22931            DecorateInfo(
22932                unittest.expectedFailure,
22933                "TestCommon",
22934                "test_python_ref",
22935                dtypes=(torch.bfloat16,),
22936                device_type="cpu",
22937            ),
22938        ),
22939    ),
22940    ElementwiseBinaryPythonRefInfo(
22941        "_refs.eq",
22942        torch_opinfo_name="eq",
22943    ),
22944    ElementwiseBinaryPythonRefInfo(
22945        "_refs.float_power",
22946        torch_opinfo_name="float_power",
22947        skips=(
22948            # Test doesn't account for float -> double type promotion
22949            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
22950            # Complex values error with: Greatest absolute difference: nan at index
22951            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
22952                         'test_reference_numerics_small_values',
22953                         dtypes=[torch.complex64, torch.complex128]),
22954            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
22955                         'test_reference_numerics_large_values',
22956                         dtypes=[torch.complex64, torch.complex128]),
22957            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
22958                         'test_reference_numerics_extremal_values',
22959                         dtypes=[torch.complex64, torch.complex128]),
22960        ),
22961    ),
22962    ElementwiseBinaryPythonRefInfo(
22963        "_refs.logaddexp",
22964        torch_opinfo_name="logaddexp",
22965        skips=(
22966            # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be
22967            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='cpu',
22968                         dtypes=(torch.complex64, torch.complex128)),
22969            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='cpu',
22970                         dtypes=(torch.complex64, torch.complex128)),
22971        ),
22972    ),
22973    PythonRefInfo(
22974        "_refs.logaddexp2",
22975        torch_opinfo_name="logaddexp2",
22976    ),
22977    ElementwiseBinaryPythonRefInfo(
22978        "_refs.floor_divide",
22979        torch_opinfo_name="floor_divide",
22980        rhs_make_tensor_kwargs=dict(exclude_zero=True),
22981        # https://github.com/pytorch/pytorch/issues/76944
22982        supports_two_python_scalars=True,
22983        supports_one_python_scalar=True,
22984        # bfloat16 floor_divide compared with a float32 reference works inconsistently
22985        skips=(
22986            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
22987                         dtypes=(torch.bfloat16,)),
22988            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
22989                         dtypes=(torch.bfloat16,)),
22991            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
22992                         dtypes=(torch.bfloat16,)),
22993            # int8 floor divide has different results for -128 // -1 vs. NumPy
22994            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
22995                         'test_reference_numerics_small_values',
22996                         dtypes=(torch.int8,)),
22997            # The following tests fail on some jobs
22998            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
22999                         'test_reference_numerics_extremal_values',
23000                         dtypes=(torch.float16,)),
23001            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
23002                         'TestBinaryUfuncs', 'test_reference_numerics'),
23003            # FIXME output 0: meta disagrees with real impl
23004            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
23005        ),
23006    ),
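    # rhs_make_tensor_kwargs is forwarded to make_tensor when building the
    # right-hand operand, so exclude_zero=True above keeps the generated divisors
    # non-zero, roughly:
    #   make_tensor((S,), dtype=torch.float32, device='cpu', exclude_zero=True)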
23007    ElementwiseBinaryPythonRefInfo(
23008        "_refs.fmax",
23009        torch_opinfo_name="fmax",
23010        supports_rhs_python_scalar=False,
23011    ),
23012    ElementwiseBinaryPythonRefInfo(
23013        "_refs.fmin",
23014        torch_opinfo_name="fmin",
23015        supports_rhs_python_scalar=False,
23016    ),
23017    ElementwiseBinaryPythonRefInfo(
23018        "_refs.fmod",
23019        torch_opinfo_name="fmod",
23020        rhs_make_tensor_kwargs={'exclude_zero': True},
23021        supports_rhs_python_scalar=True,
23022        skips=(
23023            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
23024                         dtypes=(torch.bfloat16,), device_type='cpu'),
23025            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
23026                         dtypes=(torch.bfloat16,), device_type='cpu'),
23027            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23028                         'test_contig_vs_every_other',
23029                         dtypes=(torch.bfloat16,)),
23030            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23031                         'test_non_contig',
23032                         dtypes=(torch.bfloat16,)),
23033            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23034                         'test_reference_numerics',
23035                         dtypes=(torch.bfloat16,)),
23036            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23037                         'test_reference_numerics_small_values',
23038                         dtypes=(torch.uint8,)),
23039        ),
23040    ),
23041    ElementwiseBinaryPythonRefInfo(
23042        "_refs.gcd",
23043        torch_opinfo_name="gcd",
23044        skips=(
23045            DecorateInfo(unittest.expectedFailure,
23046                         'TestBinaryUfuncs',
23047                         'test_reference_numerics_small_values',
23048                         dtypes=(torch.int8,)),
23049        ),
23050    ),
23051    ElementwiseBinaryPythonRefInfo(
23052        "_refs.ge",
23053        torch_opinfo_name="ge",
23054    ),
23055    ElementwiseBinaryPythonRefInfo(
23056        "_refs.gt",
23057        torch_opinfo_name="gt",
23058    ),
23059    ElementwiseBinaryPythonRefInfo(
23060        "_refs.heaviside",
23061        torch_opinfo_name="heaviside",
23062        supports_rhs_python_scalar=False,
23063        skips=(
23064            # PyTorch's heaviside does not appear to propagate NaNs
23065            DecorateInfo(unittest.skip("Skipped!"),
23066                         'TestBinaryUfuncs',
23067                         'test_reference_numerics_extremal_values'),
23068        ),
23069    ),
23070    ElementwiseBinaryPythonRefInfo(
23071        "_refs.hypot",
23072        torch_opinfo_name="hypot",
23073        supports_rhs_python_scalar=False,
23074    ),
23075    ElementwiseBinaryPythonRefInfo(
23076        "_refs.igamma",
23077        torch_opinfo_name="igamma",
23078    ),
23079    ElementwiseBinaryPythonRefInfo(
23080        "_refs.igammac",
23081        torch_opinfo_name="igammac",
23082    ),
23083    ElementwiseBinaryPythonRefInfo(
23084        "_refs.isclose",
23085        torch_opinfo_name="isclose",
23086        skips=(
23087            # Intentional xfail -- isclose does not type promote
23088            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
23089            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
23090            DecorateInfo(unittest.skip("Skipped!"),
23091                         'TestBinaryUfuncs',
23092                         'test_reference_numerics_extremal_values'),
23093        ),
23094    ),
23095    ElementwiseBinaryPythonRefInfo(
23096        "_refs.lcm",
23097        torch_opinfo_name="lcm",
23098    ),
23099    ElementwiseBinaryPythonRefInfo(
23100        "_refs.le",
23101        torch_opinfo_name="le",
23102    ),
23103    ElementwiseBinaryPythonRefInfo(
23104        "_refs.logical_and",
23105        torch_opinfo_name="logical_and",
23106    ),
23107    ElementwiseUnaryPythonRefInfo(
23108        "_refs.logical_not",
23109        torch_opinfo_name="logical_not",
23110    ),
23111    ElementwiseBinaryPythonRefInfo(
23112        "_refs.logical_or",
23113        torch_opinfo_name="logical_or",
23114    ),
23115    ElementwiseBinaryPythonRefInfo(
23116        "_refs.logical_xor",
23117        torch_opinfo_name="logical_xor",
23118    ),
23119    ElementwiseBinaryPythonRefInfo(
23120        "_refs.lt",
23121        torch_opinfo_name="lt",
23122    ),
23123    ElementwiseBinaryPythonRefInfo(
23124        "_refs.maximum",
23125        torch_opinfo_name="maximum",
23126        skips=(
23127            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23128        ),
23129    ),
23130    ElementwiseBinaryPythonRefInfo(
23131        "_refs.minimum",
23132        torch_opinfo_name="minimum",
23133        skips=(
23134            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23135        ),
23136    ),
23137    ElementwiseBinaryPythonRefInfo(
23138        "_refs.mul",
23139        torch_opinfo_name="mul",
23140        # https://github.com/pytorch/pytorch/issues/76944
23141        supports_two_python_scalars=True,
23142        supports_one_python_scalar=True,
23143        skips=(
23144            # Reference result was farther (0.0) from the precise computation
23145            # than the torch result was (nan)!
23146            DecorateInfo(
23147                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
23148                dtypes=(torch.complex32,),
23149            ),
23150            # Reference result was farther (0.0) from the precise computation
23151            # than the torch result was (nan)!
23152            DecorateInfo(
23153                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
23154                dtypes=(torch.complex32,), device_type='cuda'
23155            ),
23156            # Reference result was farther (0.0) from the precise computation
23157            # than the torch result was (nan)!
23158            DecorateInfo(
23159                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
23160                dtypes=(torch.complex32,), device_type='cuda'
23161            ),
23162        )
23163    ),
23164    ElementwiseBinaryPythonRefInfo(
23165        "_refs.ne",
23166        torch_opinfo_name="ne",
23167    ),
23168    ElementwiseBinaryPythonRefInfo(
23169        "_refs.nextafter",
23170        torch_opinfo_name="nextafter",
23171    ),
23172    ElementwiseBinaryPythonRefInfo(
23173        "_refs.pow",
23174        torch_opinfo_name="pow",
23175        decorators=(
23176            DecorateInfo(
23177                toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}),
23178                'TestBinaryUfuncs', 'test_reference_numerics'),
23179            DecorateInfo(
23180                toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05),
23181                                   torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}),
23182                'TestBinaryUfuncs', 'test_scalar_support'),
23183        ),
23184        skips=(
23185            # Reference result was farther (inf) from the precise
23186            # computation than the torch result was (nan)!
23187            DecorateInfo(
23188                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
23189                dtypes=(torch.complex32,),
23190            ),
23191            # Reference result was farther (inf) from the precise
23192            # computation than the torch result was (nan)!
23193            DecorateInfo(
23194                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
23195                dtypes=(torch.complex32,), device_type="cuda"
23196            ),
23197            # Reference result was farther (inf) from the precise
23198            # computation than the torch result was (nan)!
23199            DecorateInfo(
23200                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
23201                dtypes=(torch.complex32,), device_type="cuda"
23202            ),
23203            # Skipping integer dtypes because raising them to negative powers causes an error
23204            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
23205                         'test_reference_numerics_small_values',
23206                         dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]),
23207            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs',
23208                         'test_reference_numerics_large_values',
23209                         dtypes=[torch.int16, torch.int32, torch.int64]),
23210            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23211                         'test_reference_numerics',
23212                         dtypes=(torch.complex32,)),
23213            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23214                         'test_reference_numerics_small_values',
23215                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
23216            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23217                         'test_reference_numerics_large_values',
23218                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
23219            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23220                         'test_reference_numerics_extremal_values',
23221                         dtypes=(torch.complex32, torch.complex64, torch.complex128)),
23222        ),
23223    ),
23224    ElementwiseBinaryPythonRefInfo(
23225        "_refs.remainder",
23226        torch_opinfo_name="remainder",
23227        skips=(
23228            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
23229                         dtypes=(torch.bfloat16,), device_type='cpu'),
23230            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
23231                         dtypes=(torch.bfloat16,), device_type='cpu'),
23232            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23233                         'test_reference_numerics',
23234                         dtypes=(torch.bfloat16,)),
23235            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
23236                         'test_reference_numerics_small_values',
23237                         dtypes=(torch.uint8,)),
23238        ),
23239    ),
23240    ElementwiseBinaryPythonRefInfo(
23241        "_refs.rsub",
23242        torch_opinfo_name="rsub",
23243        # https://github.com/pytorch/pytorch/issues/76944
23244        skips=(
23245            # Reference result was farther (nan) from the precise computation than
23246            # the torch result was (nan)!
23247            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',
23248                         dtypes=(torch.chalf,), device_type='cpu'),
23249            # Reference result was farther (nan) from the precise computation than
23250            # the torch result was (nan)!
23251            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
23252                         dtypes=(torch.chalf,), device_type='cpu'),
23253        ),
23254    ),
23255    ElementwiseBinaryPythonRefInfo(
23256        "_refs.sub",
23257        torch_opinfo_name="sub",
23258        # https://github.com/pytorch/pytorch/issues/76944
23259        supports_two_python_scalars=True,
23260        supports_one_python_scalar=True,
23261        decorators=(
23262            DecorateInfo(
23263                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0),
23264                                   torch.bfloat16: tol(atol=1e-5, rtol=5e-3),
23265                                   torch.complex32: tol(atol=1e-5, rtol=1e-3)}),
23266                'TestBinaryUfuncs', 'test_reference_numerics'),
23267            DecorateInfo(
23268                toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}),
23269                'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'),
23270            DecorateInfo(
23271                toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
23272                'TestDecomp', 'test_comprehensive', device_type='cpu'),
23273            DecorateInfo(
23274                toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}),
23275                'TestDecomp', 'test_quick', device_type='cpu'),
23276        ),
23277        skips=(
23278            DecorateInfo(unittest.skip("Skipped!"),
23279                         'TestBinaryUfuncs',
23280                         'test_reference_numerics',
23281                         dtypes=(torch.uint8,)),
23282            DecorateInfo(unittest.skip("Skipped!"),
23283                         'TestBinaryUfuncs',
23284                         'test_reference_numerics_small_values',
23285                         dtypes=(torch.uint8,)),
23286        ),
23287    ),
23288    ElementwiseBinaryPythonRefInfo(
23289        "_refs.true_divide",
23290        torch_opinfo_name="true_divide",
23291        # https://github.com/pytorch/pytorch/issues/76944
23292        supports_two_python_scalars=True,
23293        supports_one_python_scalar=True,
23294        skips=(
23295            # Reference result was farther (0.7433461727239705) from the precise
23296            # computation than the torch result was (nan)!
23297            DecorateInfo(
23298                unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor',
23299                dtypes=(torch.complex32,),
23300            ),
23301            # Reference result was farther (0.7433461727239705) from the precise
23302            # computation than the torch result was (nan)!
23303            DecorateInfo(
23304                unittest.expectedFailure, 'TestCommon', 'test_python_ref',
23305                dtypes=(torch.complex32,), device_type="cuda"
23306            ),
23307            # Reference result was farther (0.7433461727239705) from the precise
23308            # computation than the torch result was (nan)!
23309            DecorateInfo(
23310                unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback',
23311                dtypes=(torch.complex32,), device_type="cuda"
23312            ),
23313        ),
23314    ),
23315    #
23316    # Elementwise Ternary Reference OpInfos
23317    #
23318    PythonRefInfo(
23319        "_refs.addcdiv",
23320        torch_opinfo_name="addcdiv",
23321    ),
23322    PythonRefInfo(
23323        "_refs.addcmul",
23324        torch_opinfo_name="addcmul",
23325        skips=(
23326            # Reference result was farther (1.3343989849090576e-05)
23327            # from the precise computation than the torch result
23328            # was (9.592622518539429e-06)!
23329            # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper
23330            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
23331                         dtypes=(torch.float16,), device_type="cpu"),
23332            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
23333                         dtypes=(torch.float16,), device_type="cpu"),
23334        ),
23335    ),
23336    ElementwiseBinaryPythonRefInfo(
23337        "_refs.clamp_min",
23338        torch_opinfo_name="clamp_min",
23339        skips=(
23340            # The error-input test is disabled since a non-tensor Python scalar rhs is supported
23341            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23342        ),
23343    ),
23344    ElementwiseBinaryPythonRefInfo(
23345        "_refs.clamp_max",
23346        torch_opinfo_name="clamp_max",
23347        skips=(
23348            # The error-input test is disabled since a non-tensor Python scalar rhs is supported
23349            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23350        ),
23351    ),
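    # For both clamp_min and clamp_max above, the error-input test is xfailed
    # because a plain Python scalar rhs is in fact accepted (roughly,
    # torch.clamp_min(x, 0.0) is valid), so the expected "non-tensor rhs" error
    # never fires.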
23352    PythonRefInfo(
23353        "_refs.clamp",
23354        torch_opinfo_name="clamp",
23355    ),
23356    PythonRefInfo(
23357        "_refs.nn.functional.triplet_margin_loss",
23358        torch_opinfo_name="nn.functional.triplet_margin_loss",
23359        supports_out=False,
23360        # TODO: Uses minimum and clamp
23361        skips=(
23362            # AssertionError: Tensor-likes are not close!
23363            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
23364            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
23365            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
23366                         dtypes=(torch.uint8,), device_type="cpu"),
23367        )
23368    ),
23369    ElementwiseBinaryPythonRefInfo(
23370        "_refs.xlogy",
23371        torch_opinfo_name="xlogy",
23372        supports_one_python_scalar=True,
23373    ),
23374    #
23375    # Elementwise Binary Special OpInfos
23376    #
23377    ElementwiseBinaryPythonRefInfo(
23378        "_refs.special.xlog1py",
23379        torch_opinfo_name="special.xlog1py",
23380        supports_one_python_scalar=True,
23381    ),
23382    #
23383    # Data Conversion & Data Movement Opinfos
23384    #
23385    ElementwiseUnaryPythonRefInfo(
23386        "_refs._conversions.bfloat16",
23387        torch_opinfo_name="bfloat16",
23388        # TODO: If self already has the correct dtype and device, then self is
23389        # returned ignoring memory_format.
23390        # https://github.com/pytorch/pytorch/issues/86558
23391        validate_view_consistency=False,
23392    ),
23393    ElementwiseUnaryPythonRefInfo(
23394        "_refs._conversions.bool",
23395        torch_opinfo_name="bool",
23396        # TODO: If self already has the correct dtype and device, then self is
23397        # returned ignoring memory_format.
23398        # https://github.com/pytorch/pytorch/issues/86558
23399        validate_view_consistency=False,
23400    ),
23401    ElementwiseUnaryPythonRefInfo(
23402        "_refs._conversions.byte",
23403        torch_opinfo_name="byte",
23404        # TODO: If self already has the correct dtype and device, then self is
23405        # returned ignoring memory_format.
23406        # https://github.com/pytorch/pytorch/issues/86558
23407        validate_view_consistency=False,
23408        skips=(
23409            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
23410        )
23411    ),
23412    ElementwiseUnaryPythonRefInfo(
23413        "_refs._conversions.char",
23414        torch_opinfo_name="char",
23415        # TODO: If self already has the correct dtype and device, then self is
23416        # returned ignoring memory_format.
23417        # https://github.com/pytorch/pytorch/issues/86558
23418        validate_view_consistency=False,
23419        skips=(
23420            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
23421        )
23422    ),
23423    ElementwiseBinaryPythonRefInfo(
23424        "_refs._conversions.complex",
23425        torch_opinfo_name="complex",
23426        error_inputs_func=partial(error_inputs_complex, is_ref=True),
23427        skips=(
23428            # Tests don't account for complex's type promotion semantics
23429            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
23430            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
23431        )
23432    ),
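    # error_inputs_func overrides how invalid inputs are generated for this ref;
    # here it reuses the error_inputs_complex helper, with is_ref=True signalling
    # (by its name) that the Python reference, rather than the ATen op, is being
    # exercised.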
23433    ElementwiseBinaryPythonRefInfo(
23434        "_refs._conversions.polar",
23435        torch_opinfo_name="polar",
23436        skips=(
23437            # Tests don't account for complex's type promotion semantics
23438            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
23439            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
23440        )
23441    ),
23442    ElementwiseUnaryPythonRefInfo(
23443        "_refs._conversions.double",
23444        torch_opinfo_name="double",
23445        # TODO: If self already has the correct dtype and device, then self is
23446        # returned ignoring memory_format.
23447        # https://github.com/pytorch/pytorch/issues/86558
23448        validate_view_consistency=False,
23449    ),
23450    ElementwiseUnaryPythonRefInfo(
23451        "_refs._conversions.float",
23452        torch_opinfo_name="float",
23453        # TODO: If self already has the correct dtype and device, then self is
23454        # returned ignoring memory_format.
23455        # https://github.com/pytorch/pytorch/issues/86558
23456        validate_view_consistency=False,
23457    ),
23458    ElementwiseUnaryPythonRefInfo(
23459        "_refs._conversions.half",
23460        torch_opinfo_name="half",
23461        # TODO: If self already has the correct dtype and device, then self is
23462        # returned ignoring memory_format.
23463        # https://github.com/pytorch/pytorch/issues/86558
23464        validate_view_consistency=False,
23465    ),
23466    ElementwiseUnaryPythonRefInfo(
23467        "_refs._conversions.int",
23468        torch_opinfo_name="int",
23469        # TODO: If self already has the correct dtype and device, then self is
23470        # returned ignoring memory_format.
23471        # https://github.com/pytorch/pytorch/issues/86558
23472        validate_view_consistency=False,
23473        skips=(
23474            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
23475        )
23476    ),
23477    ElementwiseUnaryPythonRefInfo(
23478        "_refs._conversions.long",
23479        torch_opinfo_name="long",
23480        # TODO: If self already has the correct dtype and device, then self is
23481        # returned ignoring memory_format.
23482        # https://github.com/pytorch/pytorch/issues/86558
23483        validate_view_consistency=False,
23484        skips=(
23485            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
23486        )
23487    ),
23488    ElementwiseUnaryPythonRefInfo(
23489        "_refs._conversions.short",
23490        torch_opinfo_name="short",
23491        # TODO: If self already has the correct dtype and device, then self is
23492        # returned ignoring memory_format.
23493        # https://github.com/pytorch/pytorch/issues/86558
23494        validate_view_consistency=False,
23495        skips=(
23496            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
23497        )
23498    ),
23499    ElementwiseUnaryPythonRefInfo(
23500        "_refs._conversions.chalf",
23501        torch_opinfo_name="chalf",
23502        # TODO: If self already has the correct dtype and device, then self is
23503        # returned ignoring memory_format.
23504        # https://github.com/pytorch/pytorch/issues/86558
23505        validate_view_consistency=False,
23506    ),
23507    ElementwiseUnaryPythonRefInfo(
23508        "_refs._conversions.cfloat",
23509        torch_opinfo_name="cfloat",
23510        # TODO: If self already has the correct dtype and device, then self is
23511        # returned ignoring memory_format.
23512        # https://github.com/pytorch/pytorch/issues/86558
23513        validate_view_consistency=False,
23514    ),
23515    ElementwiseUnaryPythonRefInfo(
23516        "_refs._conversions.cdouble",
23517        torch_opinfo_name="cdouble",
23518        # TODO: If self already has the correct dtype and device, then self is
23519        # returned ignoring memory_format.
23520        # https://github.com/pytorch/pytorch/issues/86558
23521        validate_view_consistency=False,
23522    ),
23523    PythonRefInfo(
23524        "_refs.clone",
23525        torch_opinfo_name="clone",
23526    ),
23527    #
23528    # View & Shape OpInfos
23529    #
23530    PythonRefInfo(
23531        "_refs.alias_copy",
23532        torch_opinfo_name="alias_copy",
23533        supports_out=True,
23534    ),
23535    PythonRefInfo(
23536        "_refs.atleast_1d",
23537        torch_opinfo_name="atleast_1d",
23538        validate_view_consistency=False,
23539    ),
23540    PythonRefInfo(
23541        "_refs.atleast_2d",
23542        torch_opinfo_name="atleast_2d",
23543        validate_view_consistency=False,
23544    ),
23545    PythonRefInfo(
23546        "_refs.atleast_3d",
23547        torch_opinfo_name="atleast_3d",
23548        validate_view_consistency=False,
23549    ),
23550    PythonRefInfo(
23551        "_refs.as_strided",
23552        torch_opinfo_name="as_strided",
23553        # FIXME: doesn't support chalf
23554        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
23555        skips=(
23556            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
23557            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
23558            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
23559            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
23560        ),
23561    ),
23562    PythonRefInfo(
23563        "_refs.as_strided_copy",
23564        torch_opinfo_name="as_strided_copy",
23565        supports_out=True,
23566        # FIXME: doesn't support chalf
23567        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
23568        skips=(
23569            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
23570            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
23571            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
23572            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
23573            # The view function this decomposes into does not have a ref
23574            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
23575        ),
23576    ),
23577    PythonRefInfo(
23578        "_refs.as_strided",
23579        torch_opinfo_name="as_strided",
23580        torch_opinfo_variant_name="partial_views",
23581        # FIXME: doesn't support chalf
23582        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
23583        skips=(
23584            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
23585            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
23586            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
23587            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
23588            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
23589        ),
23590    ),
23591    PythonRefInfo(
23592        "_refs.as_strided_scatter",
23593        torch_opinfo_name="as_strided_scatter",
23594        # returns a view of an intermediate tensor (as_strided)
23595        validate_view_consistency=False,
23596    ),
23597    PythonRefInfo(
23598        "_refs.block_diag",
23599        torch_opinfo_name="block_diag",
23600    ),
23601    PythonRefInfo(
23602        "_refs.broadcast_shapes",
23603        torch_opinfo_name="broadcast_shapes",
23604    ),
23605    PythonRefInfo(
23606        "_refs.broadcast_tensors",
23607        torch_opinfo_name="broadcast_tensors",
23608    ),
23609    PythonRefInfo(
23610        "_refs.broadcast_to",
23611        torch_opinfo_name="broadcast_to",
23612    ),
23613    PythonRefInfo(
23614        "_refs.cat",
23615        torch_opinfo_name="cat",
23616        skips=(
23617            # FIXME: AssertionError: RuntimeError not raised
23618            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23619        ),
23620    ),
23621    PythonRefInfo(
23622        "_refs.chunk",
23623        torch_opinfo_name="chunk",
23624    ),
23625    PythonRefInfo(
23626        "_refs.column_stack",
23627        torch_opinfo_name="column_stack",
23628    ),
23629    ElementwiseUnaryPythonRefInfo(
23630        "_refs.conj",
23631        torch_opinfo_name="conj",
23632    ),
23633    PythonRefInfo(
23634        "_refs.constant_pad_nd",
23635        torch_opinfo_name="constant_pad_nd",
23636    ),
23637    PythonRefInfo(
23638        "_refs.contiguous",
23639        torch_opinfo_name="contiguous",
23640    ),
23641    ElementwiseUnaryPythonRefInfo(
23642        "_refs.deg2rad",
23643        torch_opinfo_name="deg2rad",
23644        decorators=(precisionOverride({torch.bfloat16: 7e-1,
23645                                       torch.float16: 7e-1}),),
23646    ),
23647    PythonRefInfo(
23648        "_refs.dsplit",
23649        torch_opinfo_name="dsplit",
23650    ),
23651    PythonRefInfo(
23652        "_refs.diag",
23653        torch_opinfo_name="diag",
23654    ),
23655    PythonRefInfo(
23656        "_refs.diagonal",
23657        torch_opinfo_name="diagonal",
23658    ),
23659    PythonRefInfo(
23660        "_refs.diagonal_copy",
23661        torch_opinfo_name="diagonal_copy",
23662        supports_out=True,
23663    ),
23664    PythonRefInfo(
23665        "_refs.diagonal_scatter",
23666        torch_opinfo_name="diagonal_scatter",
23667        supports_out=True,
23668        # returns a view of an intermediate tensor (as_strided)
23669        validate_view_consistency=False,
23670    ),
23671    PythonRefInfo(
23672        "_refs.diag_embed",
23673        torch_opinfo_name="diag_embed",
23674        supports_out=True,
23675    ),
23676    PythonRefInfo(
23677        "_refs.dstack",
23678        torch_opinfo_name="dstack",
23679        skips=(
23680            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23681        ),
23682    ),
23683    PythonRefInfo(
23684        "_refs.expand",
23685        torch_opinfo_name="expand",
23686    ),
23687    PythonRefInfo(
23688        "_refs.expand_as",
23689        torch_opinfo_name="expand_as",
23690    ),
23691    PythonRefInfo(
23692        "_refs.expand_copy",
23693        torch_opinfo_name="expand_copy",
23694        supports_out=True,
23695    ),
23696    PythonRefInfo(
23697        "_refs.flatten",
23698        torch_opinfo_name="flatten",
23699    ),
23700    PythonRefInfo(
23701        "_refs.flip",
23702        torch_opinfo_name="flip",
23703    ),
23704    PythonRefInfo(
23705        "_refs.fliplr",
23706        torch_opinfo_name="fliplr",
23707    ),
23708    PythonRefInfo(
23709        "_refs.flipud",
23710        torch_opinfo_name="flipud",
23711    ),
23712    PythonRefInfo(
23713        "_refs.hstack",
23714        torch_opinfo_name="hstack",
23715        skips=(
23716            # https://github.com/pytorch/pytorch/issues/78613
23717            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23718        ),
23719    ),
23720    PythonRefInfo(
23721        "_refs.narrow",
23722        torch_opinfo_name="narrow",
23723        error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True),
23724    ),
23725    PythonRefInfo(
23726        "_refs.narrow_copy",
23727        torch_opinfo_name="narrow_copy",
23728        supports_out=True,
23729        error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True),
23730        skips=(
23731            # The view function this decomposes into does not have a ref
23732            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
23733        ),
23734    ),
23735    PythonRefInfo(
23736        "_refs.nn.functional.group_norm",
23737        torch_opinfo_name="nn.functional.group_norm",
23738        validate_view_consistency=False,
23739    ),
23740    PythonRefInfo(
23741        "_refs.native_layer_norm",
23742        torch_opinfo_name="native_layer_norm",
23743        skips=(
23744            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref",
23745                         device_type="cpu", dtypes=(torch.float32,)),
23746            DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback",
23747                         device_type="cpu", dtypes=(torch.float32,)),
23748        ),
23749    ),
23750    PythonRefInfo(
23751        "_refs.permute",
23752        torch_opinfo_name="permute",
23753    ),
23754    ElementwiseUnaryPythonRefInfo(
23755        "_refs.rad2deg",
23756        torch_opinfo_name="rad2deg",
23757        decorators=(precisionOverride({torch.bfloat16: 7e-1,
23758                                       torch.float16: 7e-1}),),
23759    ),
23760    PythonRefInfo(
23761        "_refs.ravel",
23762        torch_opinfo_name="ravel",
23763    ),
23764    PythonRefInfo(
23765        "_refs.renorm",
23766        torch_opinfo_name="renorm",
23767    ),
23768    PythonRefInfo(
23769        "_refs.repeat",
23770        torch_opinfo_name="repeat",
23771        validate_view_consistency=False,
23772    ),
23773    PythonRefInfo(
23774        "_refs.reshape",
23775        torch_opinfo_name="reshape",
23776    ),
23777    PythonRefInfo(
23778        "_refs.reshape_as",
23779        torch_opinfo_name="reshape_as",
23780    ),
23781    PythonRefInfo(
23782        "_refs.roll",
23783        torch_opinfo_name="roll",
23784        validate_view_consistency=False,
23785    ),
23786    PythonRefInfo(
23787        "_refs.rot90",
23788        torch_opinfo_name="rot90",
23789        validate_view_consistency=False,
23790    ),
23791    PythonRefInfo(
23792        "_refs.select_scatter",
23793        torch_opinfo_name="select_scatter",
23794    ),
23795    PythonRefInfo(
23796        "_refs.stack",
23797        torch_opinfo_name="stack",
23798        validate_view_consistency=False,
23799    ),
23800    PythonRefInfo(
23801        "_refs.squeeze",
23802        torch_opinfo_name="squeeze",
23803    ),
23804    PythonRefInfo(
23805        "_refs.squeeze",
23806        torch_opinfo_name="squeeze",
23807        torch_opinfo_variant_name="multiple",
23808    ),
23809    PythonRefInfo(
23810        "_refs.tensor_split",
23811        torch_opinfo_name="tensor_split",
23812        skips=(
23813            # RuntimeError: no _refs support for torch.Tensor.tolist
23814            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
23815        ),
23816    ),
23817    PythonRefInfo(
23818        "_refs.hsplit",
23819        torch_opinfo_name="hsplit",
23820    ),
23821    PythonRefInfo(
23822        "_refs.vsplit",
23823        torch_opinfo_name="vsplit",
23824    ),
23825    PythonRefInfo(
23826        "_refs.dot",
23827        torch_opinfo_name="dot",
23828        error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True),
23829        # .conj() does not set ._is_view() correctly in ATen
23830        validate_view_consistency=False,
23831        skips=(
23832            # RuntimeError: no _refs support for torch.Tensor.is_conj
23833            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]),
23834        ),
23835    ),
23836    PythonRefInfo(
23837        "_refs.vdot",
23838        torch_opinfo_name="vdot",
23839        error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True),
23840        # .conj() does not set ._is_view() correctly in ATen
23841        validate_view_consistency=False,
23842        skips=(
23843            # RuntimeError: no _refs support for torch.Tensor.is_conj
23844            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]),
23845        ),
23846    ),
23847    PythonRefInfo(
23848        "_refs.transpose",
23849        torch_opinfo_name="transpose",
23850    ),
23851    PythonRefInfo(
23852        "_refs.t",
23853        torch_opinfo_name="t",
23854    ),
23855    PythonRefInfo(
23856        "_refs.t_copy",
23857        torch_opinfo_name="t_copy",
23858        supports_out=True,
23859    ),
23860    PythonRefInfo(
23861        "_refs.T",
23862        torch_opinfo_name="T",
23863        error_inputs_func=partial(error_inputs_T, has_ndims_error=True),
23864    ),
23865    PythonRefInfo(
23866        "_refs.unfold",
23867        torch_opinfo_name="unfold",
23868    ),
23869    PythonRefInfo(
23870        "_refs.unfold_copy",
23871        torch_opinfo_name="unfold_copy",
23872        supports_out=True,
23873    ),
23874    PythonRefInfo(
23875        "_refs.unsqueeze",
23876        torch_opinfo_name="unsqueeze",
23877    ),
23878    PythonRefInfo(
23879        "_refs.unsqueeze_copy",
23880        torch_opinfo_name="unsqueeze_copy",
23881        supports_out=True,
23882    ),
23883    PythonRefInfo(
23884        "_refs.view",
23885        torch_opinfo_name="view",
23886    ),
23887    PythonRefInfo(
23888        "_refs.view_as",
23889        torch_opinfo_name="view_as",
23890    ),
23891    PythonRefInfo(
23892        "_refs.view_copy",
23893        torch_opinfo_name="view_copy",
23894        supports_out=True,
23895    ),
23896    PythonRefInfo(
23897        "_refs.vstack",
23898        torch_opinfo_name="vstack",
23899        skips=(
23900            # https://github.com/pytorch/pytorch/issues/78613
23901            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
23902        ),
23903    ),
23904    PythonRefInfo(
23905        "_refs.unflatten",
23906        torch_opinfo_name="unflatten",
23907    ),
23908    PythonRefInfo(
23909        "_refs.unbind",
23910        torch_opinfo_name="unbind",
23911    ),
23912    #
23913    # Reduction Reference OpInfos
23914    #
23915    ReductionPythonRefInfo(
23916        "_refs.all",
23917        torch_opinfo_name="all",
23918        skips=(
23919            # FIXME: uint8 input returns uint8 instead of bool
23920            DecorateInfo(
23921                unittest.expectedFailure, 'TestReductions', 'test_result_dtype',
23922                dtypes=[torch.uint8]),
23923        ),
23924    ),
23925    ReductionPythonRefInfo(
23926        "_refs.amax",
23927        torch_opinfo_name="amax",
23928        error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
23929        skips=(
23930            # FIXME: reduces all dimensions when dim=[]
23931            DecorateInfo(
23932                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
23933            DecorateInfo(
23934                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
23935        ),
23936    ),
23937    ReductionPythonRefInfo(
23938        "_refs.amin",
23939        torch_opinfo_name="amin",
23940        error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True),
23941        skips=(
23942            # FIXME: reduces all dimensions when dim=[]
23943            DecorateInfo(
23944                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
23945            DecorateInfo(
23946                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
23947        ),
23948    ),
23949    ReductionPythonRefInfo(
23950        "_refs.any",
23951        torch_opinfo_name="any",
23952        skips=(
23953            # FIXME: uint8 input returns uint8 instead of bool
23954            DecorateInfo(
23955                unittest.expectedFailure, 'TestReductions', 'test_result_dtype',
23956                dtypes=[torch.uint8]),
23957        ),
23958    ),
23959    ReductionPythonRefInfo(
23960        "_refs.count_nonzero",
23961        torch_opinfo_name="count_nonzero",
23962        skips=(
23963            # FIXME: count_nonzero does not accept keepdim kwarg
23964            DecorateInfo(
23965                unittest.skip("Skipped!"), 'TestReductions',
23966                'test_dim_default_keepdim'),
23967            DecorateInfo(
23968                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
23969            DecorateInfo(
23970                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'),
23971            DecorateInfo(
23972                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
23973            DecorateInfo(
23974                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'),
23975            DecorateInfo(
23976                unittest.skip("Skipped!"), 'TestReductions',
23977                'test_dim_multi_unsorted_keepdim'),
23978            # FIXME: dim=[] reduces all dimensions
23979            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
23980        ),
23981    ),
23982    ReductionPythonRefInfo(
23983        "_refs.mean",
23984        torch_opinfo_name="mean",
23985        supports_out=True,
23986        error_inputs_func=partial(error_inputs_mean, is_ref=True),
23987        skips=(
23988            # FIXME: reduces all dimensions when dim=[]
23989            DecorateInfo(
23990                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
23991            DecorateInfo(
23992                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
23993        ),
23994    ),
23995    ReductionPythonRefInfo(
23996        "_refs.std",
23997        torch_opinfo_name="std",
23998        supports_out=True,
23999        skips=(
24000            # FIXME: reduces all dimensions when dim=[]
24001            DecorateInfo(
24002                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
24003            DecorateInfo(
24004                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
24005            # FIXME: improve precision
24006            DecorateInfo(
24007                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
24008                dtypes=(torch.float16,)),
24009            DecorateInfo(
24010                unittest.skip("Skipped!"), 'TestReductions',
24011                'test_ref_duplicate_values',
24012                dtypes=(torch.float16,)),
24013        ),
24014    ),
24015    # std_mean and var_mean are not ReductionInfos
24016    PythonRefInfo(
24017        "_refs.std_mean",
24018        torch_opinfo_name="std_mean",
24019    ),
24020    ReductionPythonRefInfo(
24021        "_refs.sum",
24022        torch_opinfo_name="sum",
24023        supports_out=True,
24024        skips=(
24025            # FIXME: doesn't test out behavior properly for this operator
24026            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
24027            # FIXME: sum reduces all dimensions when dim=[]
24028            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
24029            DecorateInfo(
24030                unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
24031            # FIXME: improve precision
24032            DecorateInfo(
24033                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
24034                dtypes=[torch.float16]),
24035            DecorateInfo(
24036                unittest.skip("Skipped!"), 'TestReductions',
24037                'test_ref_duplicate_values',
24038                dtypes=[torch.float16]),
24039            DecorateInfo(
24040                unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all',
24041                dtypes=[torch.float32]),
24042        ),
24043    ),
24044    PythonRefInfo(
24045        "_refs.cumsum",
24046        torch_opinfo_name="cumsum",
24047        supports_out=True,
24048        skips=(
24049            # doesn't test out behavior properly for this operator
24050            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
24051        ),
24052    ),
24053    PythonRefInfo(
24054        "_refs.cumprod",
24055        torch_opinfo_name="cumprod",
24056        supports_out=True,
24057        skips=(
24058            # doesn't test out behavior properly for this operator
24059            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
24060        ),
24061    ),
24062    PythonRefInfo(
24063        "_refs.sum_to_size",
24064        torch_opinfo_name="sum_to_size",
24065        validate_view_consistency=False,
24066    ),
24067    ReductionPythonRefInfo(
24068        "_refs.prod",
24069        torch_opinfo_name="prod",
24070        supports_out=True,
24071        supports_multiple_dims=True,
24072        skips=(
24073            # FIXME: doesn't test out behavior properly for this operator
24074            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
24075            # FIXME: reduces all dimensions when dim=[]
24076            DecorateInfo(
24077                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
24078            DecorateInfo(
24079                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
24080            # FIXME: improve precision
24081            DecorateInfo(
24082                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
24083                dtypes=[torch.float16, torch.complex64]),
24084        ),
24085    ),
24086    ReductionPythonRefInfo(
24087        "_refs.var",
24088        torch_opinfo_name="var",
24089        supports_out=True,
24090        skips=(
24091            # FIXME: reduces all dimensions when dim=[]
24092            DecorateInfo(
24093                unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
24094            DecorateInfo(
24095                unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
24096            # FIXME: improve precision
24097            DecorateInfo(
24098                unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
24099        ),
24100    ),
24101    PythonRefInfo(
24102        "_refs.var_mean",
24103        torch_opinfo_name="var_mean",
24104        validate_view_consistency=False,
24105    ),
24106    #
24107    # Linear Algebra Operators
24108    #
24109    PythonRefInfo(
24110        "_refs.addr",
24111        torch_opinfo_name="addr",
24112        decorators=(
24113            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
24114        ),
24115    ),
24116    PythonRefInfo(
24117        "_refs.trace",
24118        torch_opinfo_name="trace",
24119    ),
24120    PythonRefInfo(
24121        "_refs.norm",
24122        torch_opinfo_name="norm",
24123        supports_out=True,
24124        # Uses vector_norm inside and vector_norm is affected by
24125        # https://github.com/pytorch/pytorch/issues/77216
24126        validate_view_consistency=False,
24127    ),
24128    #
24129    # Tensor Creation Reference OpInfos
24130    #
24131    PythonRefInfo(
24132        "_refs.empty",
24133        torch_opinfo_name="empty",
24134        skips=(
24135            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24136                         'TestCommon',
24137                         'test_python_ref'),
24138            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24139                         'TestCommon',
24140                         'test_python_ref_torch_fallback'),
24141            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24142                         'TestCommon',
24143                         'test_out'),
24144            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24145                         'TestCommon',
24146                         'test_out_warning'),
24147            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24148                         'TestMathBits',
24149                         'test_conj_view'),
24150            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24151                         'TestMathBits',
24152                         'test_neg_conj_view'),
24153            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24154                         'TestMathBits',
24155                         'test_neg_view'),
24156            # FIXME: shouldn't check empty results
24157            DecorateInfo(unittest.skip("Can't check result for empty"), 'TestCommon', 'test_python_ref_executor'),
24158            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
24159        ),
24160    ),
24161    PythonRefInfo(
24162        "_refs.empty_like",
24163        torch_opinfo_name="empty_like",
24164        skips=(
24165            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24166                         'TestCommon',
24167                         'test_python_ref'),
24168            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24169                         'TestCommon',
24170                         'test_python_ref_torch_fallback'),
24171            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24172                         'TestCommon',
24173                         'test_out'),
24174            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24175                         'TestCommon',
24176                         'test_out_warning'),
24177            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24178                         'TestMathBits',
24179                         'test_conj_view'),
24180            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24181                         'TestMathBits',
24182                         'test_neg_conj_view'),
24183            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24184                         'TestMathBits',
24185                         'test_neg_view'),
24186            # FIXME: should not compare results of empty_like
24187            DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'),
24188            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
24189        ),
24190    ),
24191    PythonRefInfo(
24192        "_refs.randn",
24193        torch_opinfo_name="randn",
24194        op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs),
24195        skips=(
24196            # see https://github.com/pytorch/pytorch/issues/85121
24197            DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"),
24198                         'TestCommon',
24199                         'test_python_ref_executor'),
24200            # These tests expect the input to be a tensor or a sequence of tensors
24201            DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"),
24202            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'),
24203            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'),
24204            DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'),
24205        ),
24206    ),
24207    PythonRefInfo(
24208        "_refs.eye",
24209        torch_opinfo_name="eye",
24210        skips=(
24211            # skip these tests since we have non-tensor input
24212            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
24213            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
24214            DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
24215        ),
24216    ),
24217    PythonRefInfo(
24218        "_refs.new_empty",
24219        torch_opinfo_name="new_empty",
24220        skips=(
24221            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24222                         'TestCommon',
24223                         'test_python_ref'),
24224            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24225                         'TestCommon',
24226                         'test_python_ref_torch_fallback'),
24227            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24228                         'TestCommon',
24229                         'test_out'),
24230            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24231                         'TestCommon',
24232                         'test_out_warning'),
24233            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24234                         'TestMathBits',
24235                         'test_conj_view'),
24236            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24237                         'TestMathBits',
24238                         'test_neg_conj_view'),
24239            DecorateInfo(unittest.skip("Expected: empty is not comparable"),
24240                         'TestMathBits',
24241                         'test_neg_view'),
24242            # FIXME: should not compare results of new_empty
24243            DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'),
24244            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
24245        ),
24246    ),
24247    PythonRefInfo(
24248        "_refs.new_empty_strided",
24249        torch_opinfo_name="new_empty_strided",
24250        skips=(
24251            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24252                         'TestCommon',
24253                         'test_python_ref'),
24254            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24255                         'TestCommon',
24256                         'test_python_ref_torch_fallback'),
24257            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24258                         'TestMathBits',
24259                         'test_conj_view'),
24260            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24261                         'TestMathBits',
24262                         'test_neg_conj_view'),
24263            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24264                         'TestMathBits',
24265                         'test_neg_view'),
24266            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24267                         'TestCommon',
24268                         'test_python_ref_executor'),
24269            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
24271        ),
24272    ),
24273    PythonRefInfo(
24274        "_refs.empty_strided",
24275        torch_opinfo_name="empty_strided",
24276        skips=(
24277            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24278                         'TestCommon',
24279                         'test_python_ref'),
24280            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24281                         'TestCommon',
24282                         'test_python_ref_torch_fallback'),
24283            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24284                         'TestMathBits',
24285                         'test_conj_view'),
24286            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24287                         'TestMathBits',
24288                         'test_neg_conj_view'),
24289            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24290                         'TestMathBits',
24291                         'test_neg_view'),
24292            DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"),
24293                         'TestCommon',
24294                         'test_python_ref_executor'),
24295            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
24296        ),
24297    ),
24298    PythonRefInfo(
24299        "_refs.new_full",
24300        torch_opinfo_name="new_full",
24301    ),
24302    PythonRefInfo(
24303        "_refs.new_ones",
24304        torch_opinfo_name="new_ones",
24305    ),
24306    PythonRefInfo(
24307        "_refs.new_zeros",
24308        torch_opinfo_name="new_zeros",
24309    ),
24310    #
24311    # Conditional Reference OpInfos
24312    #
24313    PythonRefInfo(
24314        "_refs.masked_fill",
24315        torch_opinfo_name="masked_fill",
24316        skips=(
24317            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
24318        ),
24319    ),
24320    PythonRefInfo(
24321        "_refs.where",
24322        torch_opinfo_name="where",
24323        op=lambda self, condition, other: refs.where(condition, self, other),
24324        supports_out=False,
24325        skips=(
24326            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'),
24327        ),
24328    ),
24329    PythonRefInfo(
24330        "_refs.index_select",
24331        torch_opinfo_name="index_select",
24332        # empty_strided
24333        skips=(
24334            # no _refs support for Tensor.__setitem__
24335            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
24336            # Sample provides an out= tensor with a stride of zero; this out= operation checks
24337            # that the input has no internal overlap
24338            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),)
24339    ),
24340    PythonRefInfo(
24341        "_refs.index_copy",
24342        torch_opinfo_name="index_copy",
24343        # empty_strided
24344        skips=(
24345            # no _refs support for Tensor.__setitem__
24346            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
24347        ),
24348    ),
24349    PythonRefInfo(
24350        "_refs.index_add",
24351        torch_opinfo_name="index_add",
24352        # empty_strided
24353        skips=(
24354            # no _refs support for Tensor.__setitem__
24355            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
24356            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
24357        ),
24358    ),
24359    PythonRefInfo(
24360        "_refs.index_fill",
24361        torch_opinfo_name="index_fill",
24362        # empty_strided
24363        skips=(
24364            # no _refs support for Tensor.__setitem__
24365            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
24366    ),
24367    #
24368    # Test-related functions
24369    #
24370    PythonRefInfo(
24371        "_refs.allclose",
24372        torch_opinfo_name="allclose",
24373    ),
24374    #
24375    # Misc functions
24376    #
24377    PythonRefInfo(
24378        "_refs.stft",
24379        torch_opinfo_name="stft",
24380        skips=[
24381            # RuntimeError: no _refs support for aten.pad
24382            DecorateInfo(
24383                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
24384            ),
24385        ],
24386    ),
24387    PythonRefInfo(
24388        "_refs.istft",
24389        torch_opinfo_name="istft",
24390        skips=[
24391            # RuntimeError: no _refs support for aten.unfold_backward
24392            DecorateInfo(
24393                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
24394            ),
24395            DecorateInfo(
24396                unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"),
24397                'TestCommon',
24398                'test_python_ref_executor',
24399                dtypes=(torch.complex64, torch.complex128),
24400            ),
24401        ],
24402    ),
24403    PythonRefInfo(
24404        "_refs.view_as_complex",
24405        torch_opinfo_name="view_as_complex",
24406    ),
24407]
24408python_ref_db += opinfo.definitions.python_ref_db
24409
24410# Common operator groupings
24411ops_and_refs = op_db + python_ref_db
24412unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
24413binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
24414binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
24415spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
24416sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
24417sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
24418sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
24419shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)]
24420reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)]
24421reference_filtered_ops = [op for op in reduction_ops if op.ref is not None]
24422reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')]
24423sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')]
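
# Illustrative sketch (not part of this module): the groupings above are typically
# consumed by device-type test classes through the `ops` decorator from
# torch.testing._internal.common_device_type, roughly like so (class and test
# names below are placeholders):
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests, ops)
#
#     class TestUnaryExample(TestCase):
#         @ops(unary_ufuncs, allowed_dtypes=(torch.float32,))
#         def test_samples_run(self, device, dtype, op):
#             # each OpInfo knows how to generate its own sample inputs
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestUnaryExample, globals())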
24424
24425# TODO: review porting these to make_tensor
24426def index_variable(shape, max_indices, device=torch.device('cpu')):
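    """Return a random LongTensor of the given shape with values drawn uniformly from [0, max_indices)."""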
24427    if not isinstance(shape, tuple):
24428        shape = (shape,)
24429    index = torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().long()
24430    return index
24431
24432def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')):
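    """Return a 2-D LongTensor of `shape` usable as a gather/scatter index along `index_dim`.

    Each 1-D slice taken along `index_dim` holds distinct values drawn from
    [0, max_indices); with duplicate=True the first two slices along the other
    dimension are made identical.
    """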
24433    assert len(shape) == 2
24434    assert index_dim < 2
24435    batch_dim = 1 - index_dim
24436    index = torch.zeros(*shape, dtype=torch.long, device=device)
24437    for i in range(shape[index_dim]):
24438        index.select(index_dim, i).copy_(
24439            torch.randperm(max_indices, device=device)[:shape[batch_dim]])
24440    if duplicate:
24441        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
24442    return index
24443
24444def bernoulli_scalar():
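    """Return a 0-dim bool tensor sampled from a Bernoulli(0.5) distribution."""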
24445    return torch.tensor(0, dtype=torch.bool).bernoulli_()
24446
24447def mask_not_all_zeros(shape):
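    """Return a random bool mask of `shape` that contains at least one True element."""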
24448    assert len(shape) > 0
24449    while True:
24450        result = torch.randn(shape).gt(0)
24451        if result.sum() > 0:
24452            return result
24453
24454# Copied from functorch
24455def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
24456    return (op_name, variant_name, device_type, dtypes, True)
24457
24458
24459def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
24460    return (op_name, variant_name, device_type, dtypes, False)
24461
24462
24463def skipOps(test_case_name, base_test_name, to_skip):
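    """Attach expectedFailure/skip DecorateInfos to the OpInfos in op_db named in to_skip.

    Each entry of to_skip is an (op_name, variant_name, device_type, dtypes,
    expected_failure) tuple as produced by xfail() / skip() above; every matching
    OpInfo gets a DecorateInfo scoped to (test_case_name, base_test_name).
    Returns a pass-through decorator so it can be applied to the test itself.
    """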
24464    all_opinfos = op_db
24465    for xfail_entry in to_skip:
24466        op_name, variant_name, device_type, dtypes, expected_failure = xfail_entry
24467        matching_opinfos = [o for o in all_opinfos
24468                            if o.name == op_name and o.variant_test_name == variant_name]
24469        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail_entry}"
24470        for op in matching_opinfos:
24471            decorators = list(op.decorators)
24472            if expected_failure:
24473                decorator = DecorateInfo(unittest.expectedFailure,
24474                                         test_case_name, base_test_name,
24475                                         device_type=device_type, dtypes=dtypes)
24476                decorators.append(decorator)
24477            else:
24478                decorator = DecorateInfo(unittest.skip("Skipped!"),
24479                                         test_case_name, base_test_name,
24480                                         device_type=device_type, dtypes=dtypes)
24481                decorators.append(decorator)
24482            op.decorators = tuple(decorators)
24483
24484    # This decorator doesn't modify fn in any way
24485    def wrapped(fn):
24486        return fn
24487    return wrapped
24488
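# Illustrative sketch (not part of this module): xfail()/skip() entries are paired
# with skipOps as a decorator on a device-type test; assuming the usual `ops`
# decorator and TestCase harness, usage looks roughly like this (the test and op
# names are placeholders):
#
#     class TestExample(TestCase):
#         @ops(op_db)
#         @skipOps('TestExample', 'test_op', [
#             xfail('median', device_type='cuda'),
#             skip('nn.functional.conv2d', dtypes=(torch.float16,)),
#         ])
#         def test_op(self, device, dtype, op):
#             ...
#
# skipOps mutates the decorators of the matching OpInfos in op_db and returns the
# test function unchanged.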