# mypy: ignore-errors

from functools import wraps, partial
from itertools import product, chain, islice
import itertools
import functools
import copy
import operator
import random
import unittest
import math
import enum

import torch
import numpy as np
from torch import inf, nan

from typing import Any, Dict, List, Tuple, Union, Sequence
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
    floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
    all_types, empty_types, complex_types_and, integral_types, custom_types,
)
from torch.testing._internal.common_device_type import \
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
    _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
    make_fullrank_matrices_with_distinct_singular_values,
    TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
    torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN,
    GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
    TEST_WITH_TORCHINDUCTOR
)
from torch.testing._utils import wrapper_set_seed

import torch._refs as refs  # noqa: F401
import torch._refs.nn.functional
import torch._refs.special
import torch._refs.linalg
import torch._prims as prims  # noqa: F401
from torch.utils import _pytree as pytree


from packaging import version

from torch.testing._internal.opinfo.core import (  # noqa: F401
    L,
    M,
    S,
    XS,
    _NOTHING,
    _getattr_qual,
    DecorateInfo,
    SampleInput,
    ErrorInput,
    AliasInfo,
    NumericsFilter,
    OpInfo,
    _generate_reduction_inputs,
    _generate_reduction_kwargs,
    sample_inputs_reduction,
    ReductionOpInfo,
    reference_inputs_elementwise_binary,
    make_error_inputs_elementwise_binary,
    generate_elementwise_binary_tensors,
    generate_elementwise_binary_arbitrarily_strided_tensors,
    generate_elementwise_binary_small_value_tensors,
    generate_elementwise_binary_large_value_tensors,
    generate_elementwise_binary_extremal_value_tensors,
    generate_elementwise_binary_broadcasting_tensors,
    generate_elementwise_binary_with_scalar_samples,
    generate_elementwise_binary_with_scalar_and_type_promotion_samples,
    generate_elementwise_binary_noncontiguous_tensors,
    sample_inputs_elementwise_binary,
    BinaryUfuncInfo,
    sample_inputs_elementwise_unary,
    generate_elementwise_unary_tensors,
    generate_elementwise_unary_small_value_tensors,
    generate_elementwise_unary_large_value_tensors,
    generate_elementwise_unary_extremal_value_tensors,
    reference_inputs_elementwise_unary,
    UnaryUfuncInfo,
    sample_inputs_spectral_ops,
    SpectralFuncType,
    SpectralFuncInfo,
    ShapeFuncInfo,
    sample_inputs_foreach,
    ForeachFuncInfo,
    gradcheck_wrapper_hermitian_input,
    gradcheck_wrapper_triangular_input,
    gradcheck_wrapper_triangular_input_real_positive_diagonal,
    gradcheck_wrapper_masked_operation,
    gradcheck_wrapper_masked_pointwise_operation,
    clone_sample,
)
from torch.testing._internal.opinfo.refs import (  # NOQA: F401
    _find_referenced_opinfo,
    _inherit_constructor_args,
    PythonRefInfo,
    ReductionPythonRefInfo,
    ElementwiseUnaryPythonRefInfo,
    ElementwiseBinaryPythonRefInfo,
)
from torch.testing._internal.opinfo.utils import (
    np_unary_ufunc_integer_promotion_wrapper,
    reference_reduction_numpy,
    prod_numpy
)
from torch.testing._internal import opinfo
from torch.testing._internal.opinfo.definitions.linalg import (
    sample_inputs_linalg_cholesky,
    sample_inputs_linalg_cholesky_inverse,
    sample_inputs_cross,
    sample_inputs_linalg_qr_geqrf,
    sample_inputs_linalg_invertible,
    sample_inputs_lu_solve,
    sample_inputs_legacy_solve,
    sample_inputs_svd,
    sample_inputs_linalg_det_logdet_slogdet,
    sample_inputs_linalg_lu,
    sample_inputs_diagonal_diag_embed,
    error_inputs_diagonal_diag_embed,
)
from torch.testing._internal.opinfo.definitions.special import (
    sample_inputs_i0_i1,
    sample_inputs_polygamma,
    reference_polygamma,
)
from torch.testing._internal.opinfo.definitions._masked import (
    sample_inputs_softmax_variant,
)
from torch.testing._internal.opinfo.definitions.sparse import (
    error_inputs_sparse_like_fns,
    sample_inputs_sparse_like_fns,
    error_inputs_sparse_mul,
    sample_inputs_sparse_mul,
    error_inputs_sparse_reduction_sum,
    sample_inputs_sparse_reduction_sum
)

if TEST_SCIPY:
    from scipy import stats
    import scipy.spatial
    import scipy.special


# test if a tensor is close to an integer
def close_to_int(x, eps=0.1):
    if x.is_complex():
        y = torch.abs(torch.view_as_complex(torch.frac(torch.view_as_real(x))))
    else:
        y = torch.abs(torch.frac(x))
    return (y < eps) | (y > (1 - eps))


def sample_inputs_slice(op_info, device, dtype, requires_grad, **kwargs):

    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    yield SampleInput(make_input(3), 0)

    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2)

    yield SampleInput(make_input(20, 30, 40), dim=1, start=1, end=-2, step=3)

    yield SampleInput(make_input(20, 30, 40), dim=0, start=-10, end=-2, step=2)


def sample_inputs_tensor_split(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype,
                         low=None, high=None, requires_grad=requires_grad)

    args_cases = (
        # Cases with tensor indices.
        (torch.tensor([1, 2, 3]),),
        (torch.tensor(1),),
        (torch.tensor([1, 2, 3]), 1),
        (torch.tensor([1, 4, 2, 5, 3, 6])[::2], 1),
        # Cases with list of indices.
        ((2, 4),),
        ((2, 4), 1),
        ((2, 4), -1),
        # Cases with integer section.
        (3,),
        (3, 1),
        (3, -1),
    )

    for args in args_cases:
        yield SampleInput(make_input((S, S, S)), args=args)


def sample_inputs_hsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_vsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(6, S), 2)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])

def sample_inputs_dsplit(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, S, S), [1, 2, 3])
    yield SampleInput(make_arg(S, S, 6), 2)

def error_inputs_hsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.hsplit requires a tensor with at least 1 dimension, "
                "but got a tensor with 0 dimensions!")
    yield ErrorInput(SampleInput(make_arg(()), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.hsplit attempted to split along dimension 1, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg((S, S, S)), 0), error_regex=err_msg2)

    # Incorrect type for indices_or_section argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(
        SampleInput(make_arg((S, S, S)), "abc"),
        error_type=TypeError, error_regex=err_msg3)

def error_inputs_vsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.vsplit requires a tensor with at least 2 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.vsplit attempted to split along dimension 0, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0),
                     error_regex=err_msg2)

    # Incorrect type for indices_or_section argument
    err_msg3 = ("received an invalid combination of arguments.")
    yield ErrorInput(SampleInput(make_arg(S, S, S), "abc"),
                     error_type=TypeError, error_regex=err_msg3)

def error_inputs_dsplit(op_info, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device)
    err_msg1 = ("torch.dsplit requires a tensor with at least 3 dimension, "
                "but got a tensor with 1 dimensions!")
    yield ErrorInput(SampleInput(make_arg(S), 0), error_regex=err_msg1)

    err_msg2 = (f"torch.dsplit attempted to split along dimension 2, "
                f"but the size of the dimension {S} "
                f"is not divisible by the split_size 0!")
    yield ErrorInput(SampleInput(make_arg(S, S, S), 0), error_regex=err_msg2)


def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = (
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
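        # 4-D views of a flat 16-element input; the (1, 1, 1, 1) strides make the view's elements overlap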
        ((16,), (2, 2, 2, 2), (1, 1, 1, 1), 0),
        ((16,), (2, 1, 1, 2), (1, 7, 7, 1), 0),
    )

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        kwargs = dict(storage_offset=storage_offset)
        yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)

def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
    def make_arg():
        base = make_tensor((20,), device=device, dtype=dtype)
        return base[5:15].requires_grad_(requires_grad)

    # as_strided on offset, partial views
    yield SampleInput(make_arg(), (2, 2), (1, 2))
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
    yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)

def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input shape, output shape, output stride, output storage offset
    test_cases = [
        ((1,), (), (), 0),
        ((1,), (1,), (1,), 0),
        ((3, 3), (2, 2), (1, 2), 0),
        ((3, 3), (2, 2), (1, 2), 1),
        ((3, 3), (2, 2), (2, 1), 0),
        # Scatter to larger dimensions
        ((16,), (2, 2, 2, 2), (8, 4, 2, 1), 0),
        # Scatter to larger dimensions with strides inverted
        ((16,), (2, 1, 1, 2), (1, 2, 4, 8), 0),
    ]

    for input_shape, output_shape, stride, storage_offset in test_cases:
        input_t = make_arg(input_shape)
        input_src = make_arg(output_shape)
        yield SampleInput(input_t, input_src, output_shape, stride, storage_offset=storage_offset)


def error_inputs_as_strided_scatter(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)

    # Create a small tensor and try to scatter it out of bounds
    input_t = make_arg([4, 4])
    input_src = make_arg([2, 2])
    yield ErrorInput(
        SampleInput(input_t, input_src, [2, 2], [200, 200], storage_offset=0),
        error_regex="itemsize 4 requiring a storage size of 1604 are out of bounds for storage of size 64"
    )


def sample_inputs_combinations(op_info, device, dtype, requires_grad, **kwargs):
    inputs = (
        (0,),
        (0, 1),
        (0, 1, 2, 3),
    )

    rvals = [1, 2, 4]

    products = product(inputs, rvals, [False, True])

    for input_data, r, with_replacement in products:
        input_t = torch.tensor(input_data, device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(input_t, r=r, with_replacement=with_replacement)

def sample_inputs_cartesian_prod(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(torch.tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # constructs 1-D tensors with varying number of elements
    a = make_arg((0,))
    b = make_arg((0, 1))
    c = make_arg((0, 1, 2, 3))

    # sample with only 1 tensor
    yield SampleInput(a)

    # sample with 2 tensors
    yield SampleInput(a, b)

    # sample with 3 tensors
    yield SampleInput(a, b, c)

def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input_shape, dict of dim and eps
    cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
        ((S, S), {'dim': 1}),
        ((S, 2), {'dim': -1}),
        ((S,), {'dim': 0, 'eps': 0.5}),
        ((), {'dim': 0}),
        ((S, S, M), {'dim': 2}),
        ((S, S), {})
    )

    for input_shape, kwargs in cases:
        yield SampleInput(make_arg(input_shape), args=(make_arg(input_shape),), kwargs=kwargs)
    # Test for Broadcasting
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
    yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
    yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})


def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)

    cases = (
        (),
        (()),
        (1),
        ((1,)),
    )

    for shape in cases:
        yield SampleInput(make_arg(shape))

def error_inputs_item(op, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)

    cases = (
        (M),
        ((S,)),
        (S, S),
        (S, M, L),
    )

    for shape in cases:
        yield ErrorInput(
            SampleInput(make_arg(shape)), error_type=RuntimeError,
            error_regex="elements cannot be converted to Scalar")


def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    # Ordered as: input shape, kwargs for training, momentum, eps
    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
        ((S, S, S), {'training': True, 'momentum': 0.5, 'eps': 0.6}),
        ((3, 2, 4), {'training': False, 'momentum': -1.2}),
        ((3, 1), {'training': True, 'momentum': 0.0}),
        ((0,), {'training': True}),
        ((0,), {'training': False}),
        ((3, 2, 3, 4), {'training': True, 'momentum': -1.0, 'eps': 0.5}),
        ((3, 2, 3, 4), {'training': False, 'momentum': -1.0, 'eps': 0.5}),
        ((2, 1), {}),
    )

    for input_shape, kwargs in cases:
        # args: running mean, running var, weight and bias should necessarily be of shape: (channels,)
        channels = input_shape[1] if len(input_shape) > 1 else 0
        weight = make_arg(channels) if channels > 0 else None
        bias = make_arg(channels) if channels > 0 else None
        running_mean = make_arg_without_requires_grad(channels, low=0)
        running_var = make_arg_without_requires_grad(channels, low=0)

        yield SampleInput(
            make_arg(input_shape),
            args=(
                running_mean,
                running_var,
                weight,
                bias
            ),
            kwargs=kwargs
        )

        # Checking for permutations of weights and biases as `None`
        weights = [channels, None, None]
        biases = [None, channels, None]
        is_training = [True, False, False]

        for weight, bias, training in zip(weights, biases, is_training):
            yield SampleInput(
                make_arg(input_shape),
                args=(
                    running_mean,
                    running_var,
                    make_arg(channels),
                    make_arg(channels)
                ),
                kwargs={'training': training}
            )

    # Test case for no optional kwargs
    # running_mean and running_var are required in evaluation mode (training: False) but not in training mode
    yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})

def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    cases = [
        ((S,), 0),
        ((S, S), 0),
        ((S, M, S), -1),
    ]
    input_dtypes = [dtype]
    if dtype == torch.float and device == 'cuda':
        input_dtypes += [torch.float16]

    for (shape, dim), input_dtype in product(cases, input_dtypes):
        input = make_arg(shape)
        output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype)
        yield SampleInput(make_arg(shape), output, dim, input_dtype)

def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))


def sample_inputs__native_batch_norm_legit(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        training = sample.kwargs.get('training', True)
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        if args[0] is not None and args[1] is not None:
            yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
        else:
            yield SampleInput(sample.input, args=(args[2], args[3], training, momentum, eps))

def sample_inputs__batch_norm_with_update(op_info, device, dtype, requires_grad, **kwargs):
    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
    for sample in samples:
        # torch.native_batch_norm does not support 0 numel tensors
        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
        if sample.input.numel() == 0:
            continue
        args = sample.args
        momentum = sample.kwargs.get('momentum', 0.5)
        eps = sample.kwargs.get('eps', 1e-5)
        if any(args[i] is None for i in range(4)):
            continue
        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], momentum, eps))

def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )

    for shape in cases:
        yield SampleInput(make_arg(shape))

def sample_inputs_prelu(op_info, device, dtype, requires_grad, **kwargs):
    op_kwargs = op_info.sample_kwargs(device, dtype, None)[0]
    yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad,
                                               op_kwargs=op_kwargs)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        (()),
        ((S, )),
        ((S, S)),
        ((S, M, S))
    )

    for shape in cases:
        for weight in [-1., 0., 0.8, 1.]:
            weight_tensor = torch.tensor(weight, device=device, dtype=dtype, requires_grad=requires_grad)
            yield SampleInput(make_arg(shape), args=(weight_tensor,))

        channel_size = shape[1] if len(shape) >= 2 else 1
        yield SampleInput(make_arg(shape), args=(make_arg((channel_size,)),))

    weight_tensor = torch.tensor(1., device=device, dtype=dtype, requires_grad=requires_grad)

    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=weight_tensor,))
    yield SampleInput(make_arg((S, S)), kwargs=dict(weight=make_arg((S,)),))

def reference_inputs_prelu(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_prelu(op, device, dtype, requires_grad, **kwargs)
    yield from reference_inputs_elementwise_unary(op, device, dtype, requires_grad, **kwargs)

def sample_kwargs_prelu_scalar_weight(device, dtype, input):
    weight = torch.rand((), device=device, dtype=dtype)
    # NumPy does not support bfloat16, so we default to float32 (only for NumPy) in that case
    if dtype == torch.bfloat16:
        weight_cpu = weight.to(dtype=torch.float32, device="cpu")
    else:
        weight_cpu = weight.cpu()
    np_weight = weight_cpu.numpy()
    return ({'weight': weight}, {'weight': np_weight})

def error_inputs_prelu(op, device):
    # Weight has numel != 1, but self.ndim is zero-dim tensor
    inp = make_tensor((), device=device, dtype=torch.float32)
    weight = make_tensor((2,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Not allow zero-dim input tensor.")

    # Weight has numel != 1, but numel does not match channel size
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((9,), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="Mismatch of parameter numbers and input channel size.")

    # Weight is neither a scalar nor 1-D tensor
    inp = make_tensor((2, 8, 3), device=device, dtype=torch.float32)
    weight = make_tensor((2, 4), device=device, dtype=torch.float32)
    yield ErrorInput(SampleInput(inp, kwargs={'weight': weight}),
                     error_regex="prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = 2")

# src and index tensors must have the same # of dimensions
def sample_inputs_norm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # ord = inf is tested in inputs_norm_inf as it fails on some tests
    cases = [
        ((S, S), (2,), '2'),
        ((S, S), (0,), '0'),
        ((S, S), (0.5,), '0_5'),
        ((S, S), (1,), '1'),
        ((S, S), (3,), '3'),
        ((S, S), (-1,), 'neg_1'),
        ((S, S), (-2,), 'neg_2'),
        ((S, S), (-0.5,), 'neg_0_5'),
        ((S, S), (-1.5,), 'neg_1_5'),
    ]

    cases_nonzero_input = (
        ((S, S, S), (1.5,), '1_5_default'),
        ((S, S, S), (1.5, 1), '1_5_dim'),
        ((S, S, S), (1.5, -1), '1_5_neg_dim'),
        ((S, S, S), (1.5, 1, True), 'keepdim_1_5_dim'),
        ((S, S, S), (1.5, -1, True), 'keepdim_1_5_neg_dim'),
    )

    cases_posdim = (
        ((S, S), (-2, 1,), 'neg_2_dim'),
        ((S, S), (-1, 1,), 'neg_1_dim'),
        ((S, S), (0, 1,), '0_dim'),
        ((S, S), (1, 1,), '1_dim'),
        ((S, S), (2, 1,), '2_dim'),
        ((S, S), (3, 1,), '3_dim'),
        ((S, S, S), (2, 1), '2_dim'),
        ((S, S, S), (3, 1), '3_dim'),
        ((S, S, S), (2, 1, True), 'keepdim_2_dim'),
        ((S, S, S), (3, 1, True), 'keepdim_3_dim'),
        ((), (2, 0), '2_dim_scalar'),
        ((), (3, 0), '3_dim_scalar'),
        ((), (2, 0, True), 'keepdim_2_dim_scalar'),
        ((), (3, 0, True), 'keepdim_3_dim_scalar'),
    )

    cases_negdim = ((shape, args[:1] + (-args[1],) + args[2:], name.replace("_dim", "_neg_dim"))
                    for shape, args, name in cases_posdim)

    for shape, args, name in itertools.chain(cases, cases_posdim, cases_negdim):
        yield SampleInput(make_arg(shape), args=args, name=name)

    for shape, args, name in cases_nonzero_input:
        yield SampleInput(make_arg(shape, exclude_zero=True), args=args, name=name)


def sample_inputs_norm_fro(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), (), 'default'),
        ((S, S), ('fro',), 'fro_default'),
        ((S, S), ('fro', [0, 1],), 'fro'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_norm_nuc(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), ('nuc',), 'nuc'),
        ((S, S, S), ('nuc', [1, 2]), 'nuc_batched'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_norm_inf(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((S, S), (-inf,), '-inf'),
        ((S, S), (inf,), 'inf'),
        ((S, S), (inf, 1,), 'inf_2_dim'),
        ((S, S), (inf, -1,), 'inf_2_neg_dim'),
    )

    for shape, args, name in cases:
        yield SampleInput(make_arg(shape), args=args, name=name)


def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        ((), ()),
        ((S,), ()),
        ((), (S,)),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, S), (S, S))
    )

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)
        rhs = make_arg(shape_rhs)
        broadcasts_input = shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)

        yield SampleInput(lhs, args=(rhs,), broadcasts_input=broadcasts_input)
        if shape_lhs == shape_rhs:
            yield SampleInput(lhs, args=(lhs.clone().detach_(),))


def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )

    num_inputs = kwargs.get('num_inputs')
    sample_kwargs = kwargs.get('sample_kwargs', {})

    for shape_lhs, shape_rhs in shapes:
        lhs = make_arg(shape_lhs)

        args = []
        for i in range(num_inputs - 1):
            args.append(make_arg(shape_rhs))
        broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs))

        yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input)

def sample_inputs_broadcast_shapes(op, device, dtype, requires_grad, **kwargs):
    shapes = (
        ((), ()),
        ((S,), ()),
        ((S, 1), (S,)),
        ((S, 1), S),
        ((M, S), ()),
        ((S, M, S), (M, S)),
        ((S, M, S), (S, M, S)),
        ((M, 1, S), (M, S)),
        ((M, 1, S), (1, M, S)),
        ((0, 1, 3), (0, 10, 3))
    )

    for shape in shapes:
        inp, *arg0 = shape
        yield SampleInput(inp, args=tuple(arg0))

def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)

    # Adds alpha kwarg cases
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': 2})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': True})
    neg_alpha = -3.125 if (dtype.is_floating_point or dtype.is_complex) else -3
    lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
    rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)
    if dtype is not torch.bool:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': neg_alpha})
    else:
        yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})

def error_inputs_arange(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
    yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
    yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')

def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
    int_samples = (
        # positive direction
        (-1, 2, 2),
        # negative direction
        (2, -3, -1),
        # start == end
        (1, 1, 1),
        (1, 1, -1),
        # divides evenly
        (0, -8, -4),
        (1, 5, 2),
        # bool
        (False, True, True),
        # default step
        (0, 1, None),
        # default start
        (None, 3, None),
    )

    def to_float(start, end, step):
        start = start + 0.1 if start is not None else None
        end = end + 0.1
        step = float(step) if step is not None else None
        return start, end, step

    float_samples = (
        # includes endpoint
        (0., -8. - 1e-6, -4.),
        (1., 5. + 1e-6, 2.),
        (0., -8., -4.),
        (1., 5., 2.),
        *(to_float(start, end, step) for (start, end, step) in int_samples),
    )

    large_samples = (
        (0, 10000, None),
    )

    samples = int_samples + float_samples
    if dtype not in (torch.int8, torch.uint8):
        samples += large_samples

    for start, end, step in samples:
        if start is None:
            assert step is None
            # Pass end as positional arg
            yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
            # (Similar to) calling torch.arange(end=3)
            yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
        elif step is None:
            yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, step), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(2)
    yield SampleInput(1, args=(3, 1))

def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)

    shapes = (
        (M,),
        (S, S)
    )

    for shape in shapes:
        yield SampleInput(input=shape, kwargs=dict(dtype=dtype, device=device, requires_grad=requires_grad))

def sample_inputs_normal(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((S, S), 0, 5),
        ((S, S, S), -2, 0.5),
    )
    for shape, mean, std in samples:
        yield SampleInput(make_arg(shape), args=(mean, std))

def error_inputs_normal(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_std = -1
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_std)),
        error_type=RuntimeError,
        error_regex=fr"normal expects std >= 0.0, but found std {invalid_std}",
    )

def sample_inputs_cauchy(op, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.5),
        ((S, S), 0, 1),
        ((S, S, S), -2, 1),
    )
    for shape, median, gamma in samples:
        yield SampleInput(make_arg(shape), args=(median, gamma))


def error_inputs_cauchy(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_scale = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_scale,)),
        error_type=RuntimeError,
        error_regex=fr"cauchy_ expects sigma > 0.0, but found sigma={invalid_scale}",
    )


def sample_inputs_exponential(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.5),
        ((S, S), 1),
        ((S, S, S), 1.5),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))


def error_inputs_exponential(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_rate = 0
    yield ErrorInput(
        SampleInput(t, args=(invalid_rate,)),
        error_type=RuntimeError,
        error_regex=fr"exponential_ expects lambda > 0.0, but found lambda={invalid_rate}",
    )


def sample_inputs_geometric(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0.2),
        ((S, S), 0.5),
        ((S, S, S), 0.8),
    )
    for shape, rate in samples:
        yield SampleInput(make_arg(shape), args=(rate,))


def error_inputs_geometric(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    neg_prob = -1
    yield ErrorInput(
        SampleInput(t, args=(neg_prob,)),
        error_type=RuntimeError,
        error_regex=fr"geometric_ expects p to be in \(0, 1\), but got p={neg_prob}",
    )


def sample_inputs_log_normal(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), 0, 0.25),
        ((S, S), 0.5, 1),
        ((S, S, S), 0, 0.5),
    )
    for shape, mean, std in samples:
        yield SampleInput(make_arg(shape), args=(mean, std))


def error_inputs_log_normal(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    invalid_std = 0
    yield ErrorInput(
        SampleInput(t, args=(0, invalid_std)),
        error_type=RuntimeError,
        error_regex=fr"log_normal_ expects std > 0.0, but found std={invalid_std}",
    )


def sample_inputs_uniform(op, device, dtype, requires_grad, **kwargs):

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
    samples = (
        ((M,), -100, 100),
        ((S, S), 0, 1),
        ((S, S, S), 1, 2),
    )
    for shape, hi, lo in samples:
        yield SampleInput(make_arg(shape), args=(hi, lo))

def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs):
    # this is a bit messy, as we want the args to be tuples
    # so if we pass size as a tuple, we have a tuple containing a tuple
    sizes = (
        (M,),
        (S, S),
    )
    for size in sizes:
        yield SampleInput(size, kwargs={'dtype': dtype, 'device': device})

def sample_inputs_full(op, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], dtype=dtype, device="cpu").item()

    sizes = (
        (M,),
        (S, S),
    )
    fill_values = [get_val(dtype), get_val(torch.int)]

    for size, fill_value in product(sizes, fill_values):
        yield SampleInput(size, fill_value, dtype=dtype, device=device)


def error_inputs_uniform(op, device, **kwargs):
    t = torch.zeros([10], device=device)
    yield ErrorInput(
        SampleInput(t, args=(3, -1)),
        error_type=RuntimeError,
        error_regex=r"uniform_ expects to return a \[from, to\) range, but found from=3 > to=-1",
    )


def error_inputs_linspace(op, device, **kwargs):
    yield ErrorInput(SampleInput(0, args=(3, -1)), error_type=RuntimeError, error_regex='number of steps must be non-negative')
    yield ErrorInput(
        SampleInput(0, args=(3, 1.)),
        error_type=TypeError,
        error_regex="received an invalid combination of arguments - got \\(int, int, float",
    )
    yield ErrorInput(
        SampleInput(torch.tensor([1, 1], device=device), args=(torch.tensor([3, 3], device=device), 1)),
        error_type=RuntimeError,
        error_regex="only supports 0-dimensional start and end tensors"
    )


def sample_inputs_linspace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1, 4, 50)
    starts = (-2., 0, 4.3, 50)
    nsteps = (0, 1, 50)
    # Extra case to replicate off-by-one issue on CUDA
    cases = list(product(starts, ends, nsteps)) + [(0, 7, 50)]
    for start, end, nstep in cases:
        if dtype == torch.uint8 and (end < 0 or start < 0):
            continue
        yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(1, args=(3, 1))


def sample_inputs_linspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1, 4, 50)
    starts = (-2., 0, 4.3, 50)
    nsteps = (0, 1, 50)
    is_start_end_tensors = ((True, True), (True, False), (False, True))
    make_arg = partial(torch.tensor, device=device, requires_grad=False)

    # Extra case to replicate off-by-one issue on CUDA
    cases = list(product(starts, ends, nsteps, is_start_end_tensors)) + [(0, 7, 50, (True, True))]
    for start, end, nstep, (is_start_tensor, is_end_tensor) in cases:
        if dtype == torch.uint8 and (end < 0 or start < 0):
            continue

        tensor_options = {"dtype": dtype, "device": device}
        if is_start_tensor:
            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
        if is_end_tensor:
            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)

        yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)

    yield SampleInput(1, args=(3, 1))


def sample_inputs_logspace(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1.2, 2, 4)
    starts = (-2., 0, 1, 2, 4.3)
    nsteps = (0, 1, 2, 4)
    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
    for start, end, nstep, base in product(starts, ends, nsteps, bases):
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
            # https://github.com/pytorch/pytorch/issues/82242
            continue
        if base is None:
            yield SampleInput(start, args=(end, nstep), kwargs={"dtype": dtype, "device": device})
        else:
            yield SampleInput(start, args=(end, nstep, base), kwargs={"dtype": dtype, "device": device})

    yield SampleInput(1, args=(3, 1, 2.))


def sample_inputs_logspace_tensor_overload(op, device, dtype, requires_grad, **kwargs):
    ends = (-3, 0, 1.2, 2, 4)
    starts = (-2., 0, 1, 2, 4.3)
    nsteps = (0, 1, 2, 4)
    bases = (2., 1.1) if dtype in (torch.int8, torch.uint8) else (None, 2., 3., 1.1, 5.)
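    # The smaller bases for int8/uint8 presumably keep the generated logspace values within the 8-bit range.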
    is_start_end_tensors = ((True, True), (True, False), (False, True))
    make_arg = partial(torch.tensor, device=device)
    for start, end, nstep, base, (is_start_tensor, is_end_tensor) in product(starts, ends, nsteps, bases, is_start_end_tensors):
        if dtype == torch.uint8 and end < 0 or start < 0:
            continue
        if nstep == 1 and isinstance(start, float) and not (dtype.is_complex or dtype.is_floating_point):
            # https://github.com/pytorch/pytorch/issues/82242
            continue

        tensor_options = {"dtype": dtype, "device": device}

        if (is_start_tensor):
            start = make_arg(start, dtype=torch.float32 if isinstance(start, float) else torch.int64)
        if (is_end_tensor):
            end = make_arg(end, dtype=torch.float32 if isinstance(end, float) else torch.int64)

        if base is None:
            yield SampleInput(start, args=(end, nstep), kwargs=tensor_options)
        else:
            yield SampleInput(start, args=(end, nstep, base), kwargs=tensor_options)

    yield SampleInput(1, args=(3, 1, 2.))


def sample_inputs_isclose(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs)

    # Creates additional inputs to test the rtol, atol, and equal_nan params
    rtols = [0., 1e-7]
    atols = [0., 1e-7]
    equal_nans = [False, True]

    products = product(rtols, atols, equal_nans)

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    for rtol, atol, equal_nan in products:
        lhs = make_arg((S, S), **op.lhs_make_tensor_kwargs)
        rhs = make_arg((S, S), **op.rhs_make_tensor_kwargs)

        yield SampleInput(lhs, args=(rhs,),
                          kwargs=dict(rtol=rtol, atol=atol, equal_nan=equal_nan))


def error_inputs_isclose(op, device, **kwargs):
    make_float_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)

    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'rtol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='rtol must be greater than or equal to zero')

    yield ErrorInput(SampleInput(make_float_arg(()), args=(make_float_arg(()),), kwargs={'atol': -0.4}),
                     error_type=RuntimeError,
                     error_regex='atol must be greater than or equal to zero')


def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg((1, 2)))
    yield SampleInput(make_arg((2,)))
    yield SampleInput(make_arg(()))


def sample_inputs_mm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)

    first_shape, second_shape = (S, M), (M, S)

    yield SampleInput(make_arg(first_shape), args=(make_arg(second_shape),))

    if dtype.is_complex:
        yield SampleInput(make_arg(first_shape), args=(make_arg_conj(second_shape),))

    # Matmul of empty matrices
    yield SampleInput(make_arg((0, S)), args=(make_arg(S, M),))
    yield SampleInput(make_arg((S, 0)), args=(make_arg(0, M),))


def sample_inputs_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha_val = kwargs.get('alpha', 2 + 3j if dtype.is_complex else 0.6)
    beta_val = kwargs.get('beta', 1 + 2j if dtype.is_complex else 0.2)
    tests_list = [
        ((2, 3), (2, 2), (2, 3), False),
        ((3, 3), (3, 3), (3, 3), False),
    ]
    tests_with_lhs_broadcasting = [
        ((1,), (2, 2), (2, 3), True),
        ((), (2, 2), (2, 3), True),
    ]
    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]

    kwargs = dict(alpha=alpha_val, beta=beta_val)
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape_a, shape_b, shape_c, broadcasts_input in test_cases:
        yield SampleInput(
            make_arg(shape_a),
            make_arg(shape_b),
            make_arg(shape_c),
            **kwargs,
        ).with_metadata(broadcasts_input=broadcasts_input)

    if dtype.is_complex:
        shape = (3, 3)
        yield SampleInput(
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            make_arg(shape),
            **kwargs,
        )
        yield SampleInput(
            make_arg(shape),
            make_arg(shape),
            make_arg(shape, requires_grad=False).mH.requires_grad_(requires_grad),
            **kwargs,
        )
    # addmm of empty matrices
    if dtype.is_floating_point:
        yield SampleInput(make_arg(S, M), make_arg(S, 0), make_arg(0, M), **kwargs)
        # empty matmul with broadcastable input
        yield SampleInput(make_arg(M), make_arg(S, 0), make_arg(0, M), **kwargs).with_metadata(broadcasts_input=True)

def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **kwargs):
    alpha = 2 + 3j if dtype.is_complex else 0.6
    beta = 1 + 2j if dtype.is_complex else 0.2
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # sparse.sampled_addmm performs: alpha * (A @ B) * sparse_ones_like(C) + beta * C
    for m, n, k in itertools.product([0, 5], repeat=3):
        yield SampleInput(
            torch.eye(m, n, device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            make_arg((k, n)),
            alpha=alpha,
            beta=beta,
        )

def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    reductions = ["sum", "mean", "amax", "amin"]
    for m, k, reduce in product([5, 7], [3, 11], reductions):
        yield SampleInput(
            torch.eye(m, m)
            .to(device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            reduce,
        )


def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, M), make_arg(M))

def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(M, S, M), make_arg(M, M, S))

def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    def make_arg_conj(size):
        return make_arg(size).conj().requires_grad_(requires_grad)

    yield SampleInput(make_arg((S, )), make_arg((S, )))
    if dtype.is_complex:
        # dot/vdot for (conj(input), conj(arg_tensor)) and (conj(input), arg_tensor)
        # is tested in test_conj_view (which tests operations with only conjugated input tensor
        # -- not conjugated arg tensors)
        yield SampleInput(make_arg((S, )), make_arg_conj((S, )))


def error_inputs_dot_vdot(op_info, device, is_ref=False, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)

    if not is_ref:
        yield ErrorInput(SampleInput(make_input(1), args=(make_input(3, dtype=torch.float16),)),
                         error_regex='dot : expected both vectors to have same dtype')
    yield ErrorInput(SampleInput(make_input(1, 1), args=(make_input(3),)),
                     error_regex='1D tensors expected')
    yield ErrorInput(SampleInput(make_input(9), args=(make_input(3),)),
                     error_regex='inconsistent tensor size')
    if device != "cpu" and not is_ref:
        yield ErrorInput(SampleInput(make_input(3), args=(make_input(3, device="cpu"),)),
                         error_regex='Expected all tensors to be on the same device')


def sample_inputs_addmv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    test_cases = (((S,), (S, M), (M,), 1, 1, False),
                  ((S,), (S, M), (M,), 0.2, 0.6, False),
                  )

    test_cases_with_broadcast = (((1,), (S, M), (M,), 1, 1, True),
                                 ((1,), (S, M), (M,), 0.2, 0.6, True),
                                 ((), (S, M), (M,), 1, 1, True),
                                 ((), (S, M), (M,), 0.2, 0.6, True),
                                 )

    cases = test_cases + test_cases_with_broadcast

    # addmv performs: beta * M + alpha * (mat @ vec)
    for size, mat, vec, beta, alpha, broadcasts_input in cases:
        yield SampleInput(make_arg(size), args=(make_arg(mat), make_arg(vec)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=broadcasts_input)

def sample_inputs_addbmm(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # input_shape, batch1_shape, batch2_shape, beta_val, alpha_val, is_broadcasting
    test_cases = [((S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]

    for input_shape, batch1_shape, batch2_shape, beta, alpha, is_broadcasting in test_cases:
        if dtype.is_complex:
            beta_complex, alpha_complex = beta * (1 + 2j), alpha * (2 + 3j)
            yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                              kwargs=dict(beta=beta_complex, alpha=alpha_complex), broadcasts_input=is_broadcasting)
        yield SampleInput(make_arg(input_shape), args=(make_arg(batch1_shape), make_arg(batch2_shape)),
                          kwargs=dict(beta=beta, alpha=alpha), broadcasts_input=is_broadcasting)

def sample_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    test_cases = [(((S, S), (S, S), (S, S)), False),
                  (((S, S), (S, 1), (1, S)), False),
                  (((1,), (S, S, 1), (1, S)), True),
                  (((), (), ()), False),
                  (((S, S), (), ()), True),
                  (((), (S, S, 1), (1, S)), True)
                  ]

    for input_args, broadcasts_input in test_cases:
        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(*args).with_metadata(broadcasts_input=broadcasts_input)

        # addcdiv should accept inputs with zero value
        # Currently, it throws ZeroDivisionError when the denominator is zero
        # TODO: exclude_zeros can be removed after https://github.com/pytorch/pytorch/issues/73638 is fixed
        args = tuple(make_arg(arg, exclude_zero=True) if isinstance(arg, tuple) else arg
                     for arg in input_args)
        yield SampleInput(
            *args, value=3.14 if dtype.is_floating_point or dtype.is_complex else 3
        ).with_metadata(broadcasts_input=broadcasts_input)

def reference_inputs_addcmul_addcdiv(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_addcmul_addcdiv(
        op_info, device, dtype, requires_grad, **kwargs)

    # type promotion cases
    supported_dtypes = op_info.supported_dtypes(device)
    make_arg = partial(make_tensor, device=device, requires_grad=requires_grad)

    types = (
        (torch.float64, torch.complex128),
        (torch.bfloat16, torch.float32),
    )

    values = (
        None,
        True, False,
        3.14, 3,
        1.0, 1,
        0.0, 0,
        -3.14, -3,
        3.14 + 2.71j,
    )

    for (type2, type3), value in product(types, values):
        if (type2 not in supported_dtypes or
                type3 not in supported_dtypes):
            continue

        # RuntimeError: value cannot be converted without overflow
        if (type(value) is complex and
                type2 is not torch.complex128):
            continue

        arg1 = make_arg([5, 5], dtype=dtype)
        arg2 = make_arg([5, 5], dtype=type2)
        arg3 = make_arg([1, 5], dtype=type3)

        # TypeError: addcdiv(): argument 'value' must be Number, not NoneType
        if value is not None:
            yield SampleInput(arg1, args=(arg2, arg3), kwargs=dict(value=value))
        else:
            yield SampleInput(arg1, args=(arg2, arg3))

def sample_inputs_baddbmm(op_info, device, dtype, requires_grad, **kwargs):
    test_cases = [((S, S, M), (S, S, S), (S, S, M), 1, 1, False),
                  ((1,), (S, S, S), (S, S, M), 1, 1, True),
                  ((S, S, M), (S, S, S), (S, S, M), 0.6, 0.2, False),
                  ((1,), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ((), (S, S, S), (S, S, M), 1, 1, True),
                  ((), (S, S, S), (S, S, M), 0.6, 0.2, True),
                  ]
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    for (input_shape, batch1_shape, batch2_shape, alpha, beta, broadcasts_input) in test_cases:
        yield SampleInput(
            make_arg(input_shape),
            make_arg(batch1_shape),
            make_arg(batch2_shape),
            beta=beta,
            alpha=alpha
        ).with_metadata(broadcasts_input=broadcasts_input)

        if dtype.is_complex:
            yield SampleInput(
                make_arg(input_shape),
                make_arg(batch1_shape),
                make_arg(batch2_shape),
                beta=beta * (1 + 2j),
                alpha=alpha * (2 + 3j),
            ).with_metadata(broadcasts_input=broadcasts_input)

    if dtype.is_complex:
        shapes = [(S, S, S), (S, M, S), (S, S, M)]
        args = tuple(make_arg(s) for s in shapes)
        yield SampleInput(
            args[0].transpose_(-1, 1),
            args[1].transpose(-1, 1).conj().requires_grad_(requires_grad),
            args[2].transpose(-1, 1).conj().requires_grad_(requires_grad),
            beta=beta * (1 + 2j),
            alpha=alpha * (2 + 3j),
        )

# TODO: add reduction kwargs
def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        (S,),
        (S, S),
    )

    for shape in shapes:
        # Produce one with weight and one without.
        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={})
        yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),),
                          kwargs={'weight': _make_tensor(shape, requires_grad=False)})

def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None
    )
    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M))

    yield SampleInput(make_arg(), make_arg(S), make_arg(M)).with_metadata(broadcasts_input=True)

    if dtype.is_complex:
        alpha, beta = 0.1 + 0.3j, 0.4 + 0.6j
    elif dtype.is_floating_point:
        alpha, beta = 0.2, 0.6
    else:
        alpha, beta = 2, 3

    yield SampleInput(make_arg(S, M), make_arg(S), make_arg(M), beta=beta, alpha=alpha)

    yield SampleInput(
        make_arg(),
        make_arg(S),
        make_arg(M),
        beta=beta,
        alpha=alpha,
    ).with_metadata(broadcasts_input=True)

    # These samples fail gradcheck
    if dtype.is_floating_point and not requires_grad:
        tensor_options = dict(device=device, dtype=dtype, requires_grad=requires_grad)
        yield SampleInput(
            torch.tensor([[math.nan]], **tensor_options),
            torch.tensor([0.0], **tensor_options),
            torch.tensor([0.0], **tensor_options),
            beta=0.0,
            alpha=0.0,
        ).with_metadata(broadcasts_input=True)

        yield SampleInput(
            torch.tensor([[0.0]], **tensor_options),
            torch.tensor([math.nan], **tensor_options),
            torch.tensor([math.nan], **tensor_options),
            beta=0.0,
            alpha=0.0,
        ).with_metadata(broadcasts_input=True)

def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = ((), (S, S, S), (S,))

    for shape in cases:
        yield SampleInput(make_arg(shape))

def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
    make_weight = partial(_make_tensor, requires_grad=False)

    inputs = (
        ((), make_target([], low=0, high=1), {}),
        ((S,), make_target([], low=0, high=S), {"p": 1}),
        ((S,), make_target([1], low=0, high=S), {"p": 2}),
        ((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
        ((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
        ((M, S), make_target([M], low=0, high=S), {"weight": None}),
        ((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
        ((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
        ((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
        ((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
    )

    for input_shape, target, kwargs in inputs:
        yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)


def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
    make_weight = partial(_make_tensor, requires_grad=False)

    inputs = (
        ((), make_target([], low=0, high=1)),
        ((S,), make_target([], low=0, high=S)),
        ((S,), make_target([1], low=0, high=S)),
        ((M, S), make_target([M], low=0, high=S)),
    )
    ps = (1, 2)
    margins = (0, 7, -3.14)
    weights = (False, True)
    reductions = (None, "none", "mean", "sum")

    for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
        input = _make_tensor(input_shape)
        weight_shape = [input.size(-1)] if input.ndim > 0 else [1]
        weight = make_weight(weight_shape, low=-10., high=10.) if weight else None
        kwargs = {"p": p, "margin": margin, "weight": weight}
        if reduction is not None:
            kwargs["reduction"] = reduction
        yield SampleInput(input, args=(target,), kwargs=kwargs)


def error_inputs_multi_margin_loss(op, device, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)
    # invalid reduction
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}),
                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
    # invalid input
    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}),
                     error_type=RuntimeError,
                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}),
                     error_type=RuntimeError,
                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
    # invalid target
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}),
                     error_type=RuntimeError, error_regex=r'inconsistent target size, expected 5 but got \[5, 4\]')
    # invalid target dtype
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
                     error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
    # invalid weight
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
                     error_type=ValueError, error_regex='weight must be one-dimensional')
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
                     error_type=ValueError, error_regex='weight must be one-dimensional')
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
                     error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
    # invalid p
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
                     error_type=ValueError, error_regex='only p == 1 and p == 2 supported')


def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
    inputs = (
        ((), (0,), True),
        ((S, S), (1,), True),
        ((S, S), (1,), False),
        ((S, S), (-2,), False),
        ((S, S), (0, 1), False),
    )
    # Test large inputs to check numerical stability
    lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,)
    for low in lows:
        high = low * 2 if low is not None else None
        for shape, dim, keepdim in inputs:
            t = make_tensor(shape, dtype=dtype, device=device,
                            low=low, high=high,
                            requires_grad=requires_grad)
            yield SampleInput(t, dim, keepdim)

def reference_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_logsumexp(op, device, dtype, requires_grad, **kwargs)

    # https://github.com/pytorch/pytorch/issues/91843
    t = torch.tensor([20, 30, 100], dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(t, 0, False)

    t = torch.tensor((), dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(t, 0, False)

    # tests masking
    # https://github.com/pytorch/pytorch/pull/91860#pullrequestreview-1241344073
    t = torch.tensor(float("inf"))
    yield SampleInput(t, 0, True)

def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
    inputs = [
        ((), {}),
        ((S, S), {}),
        ((0, S, 0), {}),
        ((S,), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), {'dtype': torch.double}),
        ((S,), {'device': 'cpu'}),
        ((S,), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), {'device': 'cuda'}))

    for shape, kwargs in inputs:
        t = make_tensor(shape, dtype=dtype, device=device,
                        low=None, high=None,
                        requires_grad=requires_grad)
        yield SampleInput(t, **kwargs)

def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_like_fns(op, device, dtype, requires_grad, **kwargs)

    # shape
    cases = (
        (), (0,), (1, 0), (1, 1, 4, 5), (5, 3, 0, 1), (1, 4, 3, 1, 1)
    )

    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    for shape in cases:
        yield SampleInput(make_arg(shape))
        yield SampleInput(make_arg(shape).transpose(0, -1))
        yield SampleInput(make_arg(shape, noncontiguous=True))
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))

def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)

    inputs = (
        ([], make_target([], low=0, high=1), {}),
        ([S], make_target([S], low=0, high=S), {}),
        ([M, S], make_target([M, S], low=0, high=S), {}),
        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}),
        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}),
        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}),
    )

    for shape, target, kwargs in inputs:
        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)


def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
    _make_tensor = partial(make_tensor, device=device, dtype=dtype,
requires_grad=requires_grad) 1656 make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False) 1657 make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False) 1658 1659 inputs = ( 1660 # random tests including -1 target labels 1661 ([], make_target([], low=-1, high=1)), 1662 ([S], make_target([S], low=-1, high=S)), 1663 ([M, S], make_target([M, S], low=-1, high=S)), 1664 # repeated target labels and -1 (labels after the first -1 are ignored) 1665 ([], make_target_tensor(-1)), 1666 ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])), 1667 ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])), 1668 ) 1669 reductions = (None, "none", "mean", "sum") 1670 1671 for (shape, target), reduction in product(inputs, reductions): 1672 kwargs = {} 1673 if reduction is not None: 1674 kwargs["reduction"] = reduction 1675 yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs) 1676 1677 1678def error_inputs_multilabel_margin_loss(op, device, **kwargs): 1679 make_input = partial(make_tensor, device=device, dtype=torch.float32) 1680 # invalid reduction 1681 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), 1682 error_type=ValueError, error_regex='abc is not a valid value for reduction') 1683 # invalid input 1684 yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}), 1685 error_type=RuntimeError, 1686 error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]') 1687 yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}), 1688 error_type=RuntimeError, 1689 error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]') 1690 # invalid target 1691 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}), 1692 error_type=RuntimeError, 1693 error_regex=r'inconsistent target size: \[4\] for input of size: \[5, 4\]') 1694 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}), 1695 error_type=RuntimeError, 1696 error_regex=r'inconsistent target size: \[\] for input of size: \[5, 4\]') 1697 1698 1699def get_independent_tensor(tensor): 1700 return tensor.clone().requires_grad_(tensor.requires_grad) 1701 1702def sample_inputs_randint(self, device, dtype, requires_grad, **kwargs): 1703 low = 2 1704 high = 10 1705 1706 for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): 1707 sample.kwargs.setdefault('device', device) 1708 # With high 1709 yield SampleInput(high, sample.input.shape, *sample.args, **sample.kwargs) 1710 # With low and high 1711 yield SampleInput(low, high, sample.input.shape, *sample.args, **sample.kwargs) 1712 1713def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs): 1714 low = 2 1715 high = 10 1716 1717 for sample in sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): 1718 # With high 1719 yield SampleInput( 1720 sample.input, 1721 high, 1722 *sample.args, 1723 **sample.kwargs) 1724 # With low and high 1725 yield SampleInput( 1726 get_independent_tensor(sample.input), 1727 low, 1728 high, 1729 *sample.args, 1730 **sample.kwargs) 1731 1732def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs): 1733 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 1734 1735 shapes = ( 1736 (), 1737 (S,), 1738 (S, S), 1739 (S, S, S), 
    )

    margins = (0., 1.)
    reductions = ('sum', 'mean', 'none')

    for shape in shapes:
        for margin, reduction in product(margins, reductions):
            kwargs = {'margin': margin, 'reduction': reduction}
            yield SampleInput(_make_tensor(shape),
                              args=(_make_tensor(shape, requires_grad=False),
                                    _make_tensor(shape, requires_grad=False)),
                              kwargs=kwargs)

def reference_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_margin_ranking_loss(op, device, dtype, requires_grad, **kwargs)
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    for reduction in ('sum', 'mean', 'none'):
        if dtype.is_floating_point:  # NaN and Inf are only meaningful for floating point inputs
            # NaN propagation
            inp1 = make_input((10,))
            inp1[2] = float('nan')
            inp2 = make_input((10,))
            inp2[4] = float('nan')
            target = make_input((10,))
            target[9] = float('nan')
            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})

            # Inf handling
            inp1 = make_input((10,))
            inp1[1] = float('inf')
            inp2 = make_input((10,))
            inp2[4] = float('inf')
            target = make_input((10,))
            target[7] = float('inf')
            yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})

        # Broadcasting
        inp1 = make_input((5, 2))
        inp2 = make_input((5, 1))
        target = make_input((1, 2))
        yield SampleInput(inp1, args=(inp2, target), kwargs={'reduction': reduction})

def error_inputs_margin_ranking_loss(op, device, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=torch.float32)
    # invalid reduction value.
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5, 4),), kwargs={'reduction': 'abc'}),
                     error_type=ValueError, error_regex='is not a valid value')
    # invalid input shapes
    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4), make_input(5,),)),
                     error_regex='margin_ranking_loss : All input tensors should')

def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=False, **kwargs):
    # input_shape, output_shape, strides, kwargs
    # lengths of output_shape and strides must be equal
    inputs = [
        ((), (), (), {}),
        ((S, S), (2, 0), (3, 4), {}),
        ((0, S, 0), (3, 2, 2), (1, 2, 3), {}),
        ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices.
We want to test cases where the 1801 # (dtype, device) is different from the input's (dtype, device) 1802 ((S,), (10,), (S,), {'dtype': torch.double}), 1803 ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), 1804 ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), 1805 ] 1806 if torch.cuda.is_available(): 1807 inputs.append(((S,), (7, 2), (3, 4), {'device': 'cuda'})) 1808 1809 for input_shape, output_shape, strides, kwargs in inputs: 1810 t = make_tensor(input_shape, dtype=dtype, device=device, 1811 low=None, high=None, 1812 requires_grad=requires_grad) 1813 if is_strided: 1814 yield SampleInput(t, output_shape, strides, **kwargs) 1815 else: 1816 yield SampleInput(t, output_shape, **kwargs) 1817 1818def sample_inputs_empty_strided(op, device, dtype, requires_grad=False, **kwargs): 1819 1820 inputs = [ 1821 ((), (), {'dtype': dtype, 'device': device}), 1822 ((S,), (4,), {'dtype': dtype, 'device': device}), 1823 ((S, S), (2, 1), {'dtype': dtype, 'device': device}), 1824 ((S, S, S), (2, 0, 1), {'dtype': dtype, 'device': device}), 1825 ] 1826 1827 for shape, strides, kwargs in inputs: 1828 yield SampleInput(shape, strides, requires_grad=requires_grad, **kwargs) 1829 1830def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): 1831 # shape 1832 cases = ( 1833 (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), 1834 ) 1835 1836 for case in cases: 1837 yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) 1838 1839def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs): 1840 # shape 1841 cases = ( 1842 (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1), 1843 ) 1844 1845 for case in cases: 1846 for layout in itertools.permutations(range(len(case))): 1847 yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad) 1848 1849def error_inputs_empty_permuted(op_info, device, **kwargs): 1850 yield ErrorInput( 1851 SampleInput((2,), args=((0, 1),)), 1852 error_type=RuntimeError, 1853 error_regex="Number of dimensions in size does not match the length of the physical_layout" 1854 ) 1855 yield ErrorInput( 1856 SampleInput((2,), args=((3,),)), 1857 error_type=RuntimeError, 1858 error_regex="Dimension out of range" 1859 ) 1860 yield ErrorInput( 1861 SampleInput((2, 3), args=((0, 0),)), 1862 error_type=RuntimeError, 1863 error_regex="Duplicate dim not allowed" 1864 ) 1865 1866def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): 1867 # Not including a scalar tensor in vals because meta tests start failing due to 1868 # lack of meta support for _local_scalar_dense 1869 # torch.tensor(2, device=device) 1870 vals = (-5, 0, 1) 1871 1872 for item in vals: 1873 yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) 1874 1875def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): 1876 # only ints >= 0 are allowed for both arguments, unless m is omitted 1877 sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) 1878 1879 for n, m in product(sizes, sizes): 1880 if n is None: 1881 continue 1882 1883 # TODO: no layout 1884 _kwargs = {'device': device, 'dtype': dtype, 'requires_grad': requires_grad} 1885 if m is None: 1886 yield SampleInput(n, args=(), kwargs=_kwargs) 1887 else: 1888 yield SampleInput(n, args=(m,), kwargs=_kwargs) 1889 1890def error_inputs_eye(op_info, device, **kwargs): 1891 # TODO: no layout 1892 _kwargs = {'device': device, 'dtype': torch.float32} 1893 1894 yield ErrorInput( 1895 SampleInput(-1, args=(), kwargs=_kwargs), 1896 error_regex="n must be 
greater or equal to 0, got -1" 1897 ) 1898 1899 yield ErrorInput( 1900 SampleInput(-7, args=(42,), kwargs=_kwargs), 1901 error_regex="n must be greater or equal to 0, got -7" 1902 ) 1903 1904 yield ErrorInput( 1905 SampleInput(0, args=(-3,), kwargs=_kwargs), 1906 error_regex="m must be greater or equal to 0, got -3" 1907 ) 1908 1909 1910def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs): 1911 def get_val(dtype): 1912 return make_tensor([], dtype=dtype, device="cpu").item() 1913 1914 for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs): 1915 # The scalar we are passing to new_full must be the same dtype 1916 # as the one of the resulting tensor 1917 use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype 1918 yield SampleInput( 1919 sample.input, *sample.args, get_val(use_dtype), **sample.kwargs) 1920 1921def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): 1922 def get_val(dtype): 1923 return make_tensor([], dtype=dtype, device="cpu").item() 1924 1925 inputs = [ 1926 ((), get_val(dtype), {}), 1927 ((S, S), get_val(dtype), {}), 1928 ((0, S, 0), get_val(dtype), {}), 1929 ((S,), get_val(dtype), {'dtype': dtype, 'device': device}), 1930 # Hard-code some dtypes/devices. We want to test cases where the 1931 # (dtype, device) is different from the input's (dtype, device) 1932 ((S,), get_val(torch.double), {'dtype': torch.double}), 1933 ((S,), get_val(dtype), {'device': 'cpu'}), 1934 ((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}), 1935 ] 1936 if torch.cuda.is_available(): 1937 inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) 1938 1939 for shape, fill_value, kwargs in inputs: 1940 t = make_tensor(shape, dtype=dtype, device=device, 1941 low=None, high=None, 1942 requires_grad=requires_grad) 1943 yield SampleInput(t, fill_value, **kwargs) 1944 1945def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs): 1946 cases = [ 1947 ([3], 3, {}), 1948 ([10], 3, {}), 1949 ([3, 10], 3, {}), 1950 ([3], 3, dict(replacement=False)), 1951 ([3], 3, dict(replacement=True)), 1952 ([3, 4], 4, dict(replacement=True)), 1953 ([3, 4], 4, dict(replacement=False)), 1954 ] 1955 1956 for shape, num_samples, kwargs in cases: 1957 t = make_tensor(shape, dtype=dtype, device=device, 1958 low=0, high=None, 1959 requires_grad=requires_grad) 1960 yield SampleInput(t, num_samples, **kwargs) 1961 1962def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs): 1963 def get_value_or_make_tensor(value_or_shape): 1964 if isinstance(value_or_shape, list): 1965 return make_tensor(value_or_shape, dtype=dtype, device=device, 1966 low=0, high=None, 1967 requires_grad=requires_grad) 1968 return value_or_shape 1969 1970 for value_or_mean_shape, value_or_std_shape, kwargs in cases: 1971 mean = get_value_or_make_tensor(value_or_mean_shape) 1972 std = get_value_or_make_tensor(value_or_std_shape) 1973 yield SampleInput(mean, std, **kwargs) 1974 1975def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs): 1976 # value_or_size, value_or_size, kwargs 1977 cases = [ 1978 ([], [], {}), 1979 ([3], [3], {}), 1980 ([3, 4, 2], [3, 4, 2], {}), 1981 ([2, 3], 1.1, {}), 1982 ([1, 2, 3], [5, 2, 3], {}), # broadcasting 1983 ] 1984 1985 return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs) 1986 1987def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs): 1988 yield SampleInput(1.6, 0.3, [2, 3], dtype=dtype, 
device=device) 1989 yield SampleInput(1.6, 0.3, [2, 2, 2], dtype=dtype, layout=torch.strided, device=device) 1990 yield SampleInput(2.7, make_tensor([4, 3], dtype=dtype, device=device, low=0, high=None, requires_grad=requires_grad)) 1991 1992def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs): 1993 shapes = [ 1994 [3], 1995 [], 1996 [0, 3], 1997 [2, 3, 4], 1998 ] 1999 2000 for shape in shapes: 2001 t = make_tensor(shape, dtype=dtype, device=device, 2002 low=0, high=1, 2003 requires_grad=requires_grad) 2004 yield SampleInput(t) 2005 2006def error_inputs_bernoulli(op_info, device, **kwargs): 2007 # more than one element of the written-to tensor refers to a single memory location 2008 x = torch.rand((1,), device=device).expand((6,)) 2009 err_msg = 'unsupported operation' 2010 yield ErrorInput(SampleInput(torch.rand_like(x), kwargs={'out': x}), 2011 error_regex=err_msg) 2012 2013def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs): 2014 inputs = ( 2015 ((S, S, S), 0), 2016 ((S, S, S), 1), 2017 ((), 0), 2018 ) 2019 2020 for large_number in (True, False): 2021 for shape, dim in inputs: 2022 t = make_tensor(shape, dtype=dtype, device=device, 2023 low=None, high=None, 2024 requires_grad=requires_grad) 2025 2026 if large_number and t.dim() > 0: 2027 t[0] = 10000 2028 yield SampleInput(t, dim) 2029 2030def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): 2031 yield SampleInput( 2032 make_tensor((S, S), dtype=dtype, device=device, 2033 low=None, high=None, 2034 requires_grad=requires_grad)) 2035 2036 2037def error_inputs_trace(op, device): 2038 yield ErrorInput(SampleInput(make_tensor((3, 4, 5), dtype=torch.float32, device=device)), error_regex="expected a matrix") 2039 2040 2041def sample_inputs_renorm(self, device, dtype, requires_grad, **kwargs): 2042 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2043 cases = (((S, S, S), (2, 1, 0.5)), 2044 ((S, S, S), (2, -1, 0.5)), 2045 ((S, S, S), (1, 2, 3)), 2046 ((S, S, S), (float('inf'), 2, 0.5)), 2047 ) 2048 2049 for shape, args in cases: 2050 yield SampleInput(make_arg(shape), args=args) 2051 2052 2053def sample_inputs_transpose_swapdims(self, device, dtype, requires_grad, **kwargs): 2054 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2055 2056 cases = (((1, 2, 3), (-1, -2)), 2057 ((1, 2, 3), (-1, 2)), 2058 ((1, 2, 3), (1, -2)), 2059 ((1, 2, 3), (1, 2)), 2060 ((), (0, 0)), 2061 ((1, ), (0, 0)), 2062 ((M, M), (0, 1)), 2063 ((S, S, S), (2, 0)), ) 2064 2065 for shape, args in cases: 2066 yield SampleInput(make_arg(shape), args=args) 2067 2068def _numpy_ref_transpose(a, dim0, dim1): 2069 if a.ndim <= 1: 2070 return a 2071 2072 return np.swapaxes(a, dim0, dim1) 2073 2074def sample_inputs_adjoint(self, device, dtype, requires_grad, **kwargs): 2075 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2076 2077 shapes = ((1, 2, 3), (M, M), (S, S, S), (S, M, S), (M, S, M, S)) 2078 return (SampleInput(make_arg(shape)) for shape in shapes) 2079 2080def sample_inputs_T(self, device, dtype, requires_grad, **kwargs): 2081 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2082 2083 shapes = ((M, M), (M, L)) 2084 return (SampleInput(make_arg(shape)) for shape in shapes) 2085 2086def error_inputs_T(self, device, has_ndims_error=False): 2087 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 2088 2089 # Deprecated behavior in regular 
PyTorch, but throws an error in primTorch:
    # https://github.com/pytorch/pytorch/issues/86968
    if has_ndims_error:
        # ndims == 1
        yield ErrorInput(SampleInput(make_arg(M)),
                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
                                      r'to reverse their shape is not supported\.'))

        # ndims > 2
        yield ErrorInput(SampleInput(make_arg(M, S, L)),
                         error_regex=(r'The use of `x\.T` on tensors of dimension other than 0 or 2 '
                                      r'to reverse their shape is not supported\.'))


def sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad=False):
    """
    This function produces two tensors of shape (*, m, k) and (*, n, k) with k <= min(m, n).
    Their matrix product can be used to generate a tensor of shape (*, m, n) of rank k.
    """

    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    batches = [(), (2,)]
    size = [3, 4]
    for batch, m, n in product(batches, size, size):
        k = 2
        a = make_arg((*batch, m, k))
        b = make_arg((*batch, n, k))
        yield a, b


def sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
    # Function that's well defined on the outputs for complex inputs
    def fn(usv):
        U, S, V = usv
        return U @ V.mH, S

    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
        *batch, m, k = a.shape
        n = b.shape[-2]

        # NOTE: since svd_lowrank relies on a non rank-revealing SVD,
        # it inherits the problem of unstable behavior with repeated
        # singular values, including zeros.
        # Since we want to avoid (repeated) zeros as singular values,
        # we can only use k for q.
        # This issue could be resolved by using a rank-revealing SVD
        # which does not include "zero" singular values.
        yield SampleInput(a, b, q=k, M=None).with_metadata(output_process_fn_grad=fn)

    for (a, b) in sample_inputs_singular_matrix_factors(op_info, device, dtype, requires_grad):
        *batch, m, k = a.shape
        n = b.shape[-2]
        M = make_tensor((*batch, m, n), dtype=dtype, device=device, requires_grad=requires_grad)
        yield SampleInput(a, b, q=k, M=M).with_metadata(output_process_fn_grad=fn)

def chunk_iter(iterable, size):
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, size))
        if not chunk:
            break
        yield chunk

def sample_inputs_pca_lowrank(op_info, device, dtype, requires_grad=False, **kwargs):
    # We reuse samples from svd_lowrank, which come in groups of two: one with
    # kwargs['M'] = None and one with kwargs['M'] = <some tensor>
    samples = sample_inputs_svd_lowrank(op_info, device, dtype, requires_grad, **kwargs)
    for s1, s2 in chunk_iter(samples, 2):
        del s1.kwargs['M']
        del s2.kwargs['M']
        s1.kwargs['center'] = False
        s2.kwargs['center'] = True
        yield s1
        yield s2

def np_sinc_with_fp16_as_fp32(x):
    # Wraps numpy's sinc function so that fp16 values are promoted to fp32
    # before sinc is invoked. Context: numpy's sinc returns NaN when evaluated
    # at 0 for fp16.
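    # Illustrative sketch of the mismatch this wrapper papers over (values taken
    # from the note above, shown only as an assumed example):
    #   np.sinc(np.array([0.0], dtype=np.float16))  # -> nan on the fp16 path
    #   np.sinc(np.array([0.0], dtype=np.float32))  # -> 1.0
    # torch.sinc defines sinc(0) == 1, so the fp16 reference is computed in fp32.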
2168 if x.dtype == np.float16: 2169 return np.sinc(x.astype(np.float32)) 2170 else: 2171 return np.sinc(x) 2172 2173def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs): 2174 test_cases = ( 2175 ((S, 1, 1), (S, S, S)), 2176 ((S, 1, S), (S, S, S)), 2177 ((S, 1), (S, S, S)), 2178 ((1,), (S, S, S)), 2179 ((1, S), (1, 1, S)), 2180 ((), ()), 2181 ((), (1, 3, 2)), 2182 ) 2183 2184 return ( 2185 SampleInput( 2186 make_tensor(size, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad), 2187 shape, 2188 ) for size, shape in test_cases) 2189 2190def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs): 2191 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2192 test_cases: Tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),) 2193 2194 for shape, *other_shapes in test_cases: 2195 yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) 2196 2197def reference_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs): 2198 yield from sample_inputs_broadcast_tensors(op, device, dtype, requires_grad, **kwargs) 2199 2200 m = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2201 n = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) 2202 2203 cases = ( 2204 ((), (1, 1), (1, 1, 7, 1), (3, 1, 1)), 2205 ((3, 5, 6), (1, 3, 5, 6), (1, 1, 1, 1, 6), (8, 3, 5, 6)) 2206 ) 2207 2208 for a, b, c, d in cases: 2209 yield SampleInput(m(a), args=(m(b), m(c), m(d))) 2210 yield SampleInput(n(a), args=(n(b), n(c), n(d))) 2211 2212def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs): 2213 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2214 test_cases: Tuple[tuple] = ( 2215 ((1, S), (2, S), (3, S),), 2216 ((S, 1), (S, 2), (S, 3),), 2217 ((1,), (2,), (3,),), 2218 ((2, S), (S,)) 2219 ) 2220 2221 for shape, *other_shapes in test_cases: 2222 yield SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)) 2223 # We also want to test mixed complex-non-complex inputs to block_diag 2224 if dtype == torch.complex32 or dtype == torch.complex64: 2225 non_complex_dtype = torch.float32 if dtype == torch.complex32 else torch.float64 2226 make_arg_non_complex = partial(make_tensor, dtype=non_complex_dtype, device=device, requires_grad=requires_grad) 2227 yield SampleInput(make_arg_non_complex(shape), args=tuple(make_arg(s) for s in other_shapes)) 2228 2229def sample_inputs_cdist(op_info, device, dtype, requires_grad, **kwargs): 2230 small_S = 2 2231 test_cases = ( 2232 ((S, S, 2), (S, S + 1, 2)), 2233 ((S, S), (S, S)), 2234 ((S, S, S), (S, S, S)), 2235 ((3, 5), (3, 5)), 2236 ((2, 3, 5), (2, 3, 5)), 2237 ((1, 2, 3), (1, 2, 3)), 2238 ((1, 1), (S, 1)), 2239 ((0, 5), (4, 5)), 2240 ((4, 5), (0, 5)), 2241 ((0, 4, 5), (3, 5)), 2242 ((4, 5), (0, 3, 5)), 2243 ((0, 4, 5), (1, 3, 5)), 2244 ((1, 4, 5), (0, 3, 5)), 2245 # Using S here would make this one test take 9s 2246 ((small_S, small_S, small_S + 1, 2), (small_S, small_S, small_S + 2, 2)), 2247 ((small_S, 1, 1, small_S), (1, small_S, small_S)), 2248 ((1, 1, small_S), (small_S, 1, small_S, small_S)), 2249 ) 2250 2251 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2252 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 2253 # FIXME add an override for JIT and revert 0. 
back to 0 2254 # since it's accepted by eager 2255 for p in [0., 1., 2., 3., 0.5, 1.5, 2.5, float("inf")]: 2256 for t1_size, t2_size in test_cases: 2257 # The args should never be non-contiguous as this is not supported in the backward 2258 yield SampleInput(make_arg(t1_size), make_arg(t2_size), p, cm) 2259 2260def _fill_np(a, value): 2261 a = a.copy() 2262 a.fill(value) 2263 return a 2264 2265def _fill_sample_kwargs(device, dtype, input): 2266 if dtype is torch.bool: 2267 value = True 2268 else: 2269 value = 3 2270 2271 return ({'value': value}, {'value': value}) 2272 2273def sample_inputs_comparison_ops(op, device, dtype, requires_grad, **kwargs): 2274 yield from sample_inputs_elementwise_binary(op, device, dtype, requires_grad, **kwargs) 2275 2276 # Adds a sample input where both tensors have the same values 2277 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2278 2279 lhs = make_arg((S, S)) 2280 yield SampleInput(lhs, args=(lhs.clone(),)) 2281 2282def sample_inputs_stack(op_info, device, dtype, requires_grad, **kwargs): 2283 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2284 2285 # shape x number of tensors 2286 cases = ( 2287 ((3, 4), 1), 2288 ((1, 2, 1, 4), 3), 2289 ((0, 1, 0), 2),) 2290 2291 for shape, num_tensors in cases: 2292 tensors = [] 2293 for _ in range(num_tensors): 2294 tensors.append(make_arg(shape)) 2295 for dim in range(-1, len(shape) - 1): 2296 yield SampleInput(tensors, args=(dim,)) 2297 2298 2299def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs): 2300 # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. 2301 # If all input tensors have the same ndims, we support both negative and non-negative dim. 2302 # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. 2303 # No requirements for (wrapped_dim, ...)-th dimension. 2304 # 3. Expect positive num_chunks 2305 # 4. 
Expect non-empty input tensor list and each input tensor should have at least 1 element 2306 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2307 same_ndim_cases = ( 2308 ( 2309 [ 2310 torch.Size([1, 2, 3]), 2311 torch.Size([1, 2, 3]), 2312 ], -1, 5 2313 ), 2314 ( 2315 [ 2316 torch.Size([1, 2, 129]), 2317 torch.Size([1, 2, 297]), 2318 ], -1, 5 2319 ), 2320 ( 2321 [ 2322 torch.Size([1, 2, 3]), 2323 torch.Size([1, 2, 3]), 2324 ], 1, 5 2325 ), 2326 ( 2327 [ 2328 torch.Size([3, 3, 2, 1]), 2329 torch.Size([1, 4, 2, 2]), 2330 torch.Size([2, 1, 3, 3]), 2331 ], 0, 2 2332 ), 2333 ) 2334 for sizes, dim, num_chunks in same_ndim_cases: 2335 tensors = [] 2336 for size in sizes: 2337 tensors.append(make_arg(size)) 2338 yield SampleInput(tensors, args=(dim, num_chunks)) 2339 2340 different_ndim_case = [ 2341 torch.Size([2, 3, 3]), 2342 torch.Size([2, 3, 1, 2]), 2343 torch.Size([2, 3]), 2344 torch.Size([2, 3, 2]), 2345 torch.Size([2, 3, 271]), 2346 ] 2347 max_dim, num_chunks = 2, 3 2348 for dim in range(max_dim): 2349 tensors = [] 2350 for size in different_ndim_case: 2351 tensors.append(make_arg(size)) 2352 yield SampleInput(tensors, args=(dim, num_chunks)) 2353 2354 2355def error_inputs_chunk_cat(op_info, device, **kwargs): 2356 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 2357 2358 # input tensors have different ndims but dim is negative 2359 sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3 2360 tensors = [make_arg(size) for size in sizes] 2361 yield ErrorInput( 2362 SampleInput(tensors, args=(dim, num_chunks)), 2363 error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims', 2364 ) 2365 2366 # input tensors have different ndims but dim >= ndim of some input tensors 2367 sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3 2368 tensors = [make_arg(size) for size in sizes] 2369 yield ErrorInput( 2370 SampleInput(tensors, args=(dim, num_chunks)), 2371 error_regex='_chunk_cat expects dim < ndim for all input tensors', 2372 ) 2373 2374 # some tensors have different sizes for 0, ..., dim-1 dimensions. 
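    # Illustrative reading of the case built below: with dim=1, all inputs must
    # agree on dimension 0, but torch.Size([2, 3, 4]) and torch.Size([4, 3])
    # disagree there (2 vs. 4), so _chunk_cat is expected to raise.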
    sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3
    tensors = [make_arg(size) for size in sizes]
    yield ErrorInput(
        SampleInput(tensors, args=(dim, num_chunks)),
        error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors',
    )

    # negative num_chunks
    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1
    tensors = [make_arg(size) for size in sizes]
    yield ErrorInput(
        SampleInput(tensors, args=(dim, num_chunks)),
        error_regex='_chunk_cat expects positive num_chunks',
    )

    # zero as num_chunks
    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0
    tensors = [make_arg(size) for size in sizes]
    yield ErrorInput(
        SampleInput(tensors, args=(dim, num_chunks)),
        error_regex='_chunk_cat expects positive num_chunks',
    )

    # empty input tensor list
    dim, num_chunks = 0, 1
    yield ErrorInput(
        SampleInput([], args=(dim, num_chunks)),
        error_regex='_chunk_cat expects a non-empty input tensor list',
    )

    # empty input tensor with 0 elements
    sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1
    tensors = [make_arg(size) for size in sizes]
    yield ErrorInput(
        SampleInput(tensors, args=(dim, num_chunks)),
        error_regex='_chunk_cat expects non-empty tensor',
    )


def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases: Tuple[tuple, tuple, dict] = (  # type: ignore[assignment]
        ((S, S), (S, S), {'dim': -1}),
        ((S, S), (S, S), {'dim': 1}),
        ((M, S), (S, S), {'dim': 0}),  # different shapes
        ((1, 2, 3), (1, 2, 3), {'dim': -2}),
        ((0,), (0,), {'dim': 0}),  # empty tensor
        ((0,), (S, S), {'dim': 1}),  # empty tensor with a non-empty one and dim=1 (special case for legacy_cat_wrap_dim)
        ((0, S), (S, S), {'dim': 0}),
        ((1,), (1,), {})  # dim not passed, fallback to default
    )

    for input_shape1, input_shape2, kwargs in cases:
        yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs)

    # from coat_lite_mini
    yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),)

def error_inputs_cat(op_info, device, **kwargs):

    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for more than one element of the written-to tensor referring to a single memory location
    yield ErrorInput(SampleInput([make_arg((S, S)), make_arg((S, S))],
                                 kwargs={'out': make_arg((1, S)).expand((2 * S, S))}),
                     error_regex='unsupported operation')

    # error inputs for empty tensors
    yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
                     error_regex='non-empty list of Tensors')

    # error inputs for different sizes
    yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
                     error_regex='Sizes of tensors must match except in dimension')
    yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S, S, L, L))], kwargs={'dim': 1}),
                     error_regex='Sizes of tensors must match except in dimension')

    # error inputs for different dimensions
    yield ErrorInput(SampleInput([make_arg((S - 1, 0)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
                     error_regex='Tensors must have same number of dimensions')
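    # The same mismatch is exercised with the operand order swapped below; illustratively,
    # torch.cat([torch.ones(2, 3), torch.ones(3)], dim=1) trips the same
    # "same number of dimensions" check whichever tensor comes first.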
2456 yield ErrorInput(SampleInput([make_arg((S, 0, L - 1, L)), make_arg((S - 1, 0))], kwargs={'dim': 1}), 2457 error_regex='Tensors must have same number of dimensions') 2458 2459 # error inputs for same memory locations 2460 x = torch.zeros((0), device=device) 2461 y = torch.randn((4, 6), device=device) 2462 2463 err_msg = "the written-to tensor refer to a single memory location" 2464 2465 yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': x}), 2466 error_regex=err_msg) 2467 yield ErrorInput(SampleInput((x, y), kwargs={'dim': 0, 'out': y}), 2468 error_regex=err_msg) 2469 2470 z = torch.zeros((4, 6), device=device) 2471 yield ErrorInput(SampleInput((y, z), kwargs={'out': z[:2, :]}), 2472 error_regex=err_msg) 2473 2474 # error inputs for different devices 2475 if torch.device(device).type == 'cuda': 2476 x_cuda = make_tensor((3, 3), device=device, dtype=torch.float32) 2477 y_cpu = make_tensor((3, 3), device='cpu', dtype=torch.float32) 2478 yield ErrorInput(SampleInput((x_cuda, y_cpu)), 2479 error_regex='Expected all tensors to be on the same device') 2480 2481 # error inputs for different input sizes for more than 2 tensors 2482 yield ErrorInput(SampleInput([make_arg((L, 1)), make_arg((L, 1, 1)), make_arg((L, 1, 1))]), 2483 error_regex='Tensors must have same number of dimensions') 2484 2485 yield ErrorInput(SampleInput([make_arg((S, 1, M)), make_arg((S, 1, 1)), make_arg((S, M, 1))], 2486 kwargs={'dim': 1}), 2487 error_regex='Sizes of tensors must match') 2488 2489 # error inputs for None input 2490 yield ErrorInput(SampleInput((make_arg((S, 1, 1)), None)), error_type=TypeError, 2491 error_regex='got None') 2492 2493 # error inputs for zero-dimensional tensors 2494 yield ErrorInput(SampleInput([make_arg(()), make_arg(())]), 2495 error_regex='zero-dimensional.*cannot be concatenated') 2496 2497 # error inputs for different dtype of out tensors 2498 d = make_tensor((2, 3), device=device, dtype=torch.double) 2499 x = make_tensor((2, 3), device=device, dtype=torch.float32) 2500 yield ErrorInput(SampleInput(x, kwargs={'out': d}), error_type=TypeError, 2501 error_regex='invalid combination of arguments') 2502 2503def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): 2504 yield from sample_inputs_cat_concat(op, device, dtype, requires_grad, **kwargs) 2505 2506 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2507 2508 # Noncontiguous type promoting tensors 2509 a = make_arg((3, 4, 2)) 2510 b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) 2511 c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) 2512 2513 yield SampleInput((a, b, c), kwargs={'dim': 1}) 2514 2515 # Special 1D tensor with dim length of 0 case 2516 a = make_arg((0,)) 2517 b = make_arg((3, 2, 2)) 2518 2519 yield SampleInput((a, b, a)) 2520 yield SampleInput((a, a, a)) 2521 2522def _elementwise_type_promo_np(*args, type_promotion_kind): 2523 def _maybe_torch(x): 2524 if isinstance(x, np.ndarray): 2525 return torch.from_numpy(x) 2526 return x 2527 2528 flattened = pytree.arg_tree_leaves(*args) 2529 transformed = tuple(_maybe_torch(a) for a in flattened) 2530 result_dtype, _ = prims.utils.elementwise_dtypes( 2531 *transformed, 2532 type_promotion_kind=type_promotion_kind) 2533 return torch_to_numpy_dtype_dict[result_dtype] 2534 2535def _cat_np(input_seq, dim=0): 2536 inputs = tuple(a for a in input_seq if not (a.ndim == 1 and a.size == 0)) 2537 2538 if len(inputs) == 0: 2539 np_dtype = _elementwise_type_promo_np( 2540 input_seq, 2541 
type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH) 2542 return np.empty(0, dtype=np_dtype) 2543 2544 return np.concatenate(inputs, axis=dim) 2545 2546def _floor_divide_np(a, b): 2547 dtype = _elementwise_type_promo_np( 2548 a, 2549 b, 2550 type_promotion_kind=prims.utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) 2551 if isinstance(a, np.ndarray): 2552 a = a.astype(dtype) 2553 if isinstance(b, np.ndarray): 2554 b = b.astype(dtype) 2555 return np.floor_divide(a, b) 2556 2557def sample_inputs_hstack_dstack_vstack(op_info, device, dtype, requires_grad, **kwargs): 2558 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 2559 tensor_shapes = ( 2560 # First Tensor being 1-D is special 2561 # case for hstack 2562 ((S,), (S,), (S,)), 2563 ((S, S), (S, S), (S, S)), 2564 ) 2565 for s1, s2, s3 in tensor_shapes: 2566 tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) 2567 yield SampleInput(tensors) 2568 2569def error_inputs_hstack_dstack_vstack(op, device): 2570 make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) 2571 tensor_shapes = ( 2572 ((S,), (S, S, S, S), (S,)), 2573 ) 2574 for s1, s2, s3 in tensor_shapes: 2575 tensors = (make_arg(s1,), make_arg(s2,), make_arg(s3)) 2576 # Different dimension tensor 2577 yield ErrorInput(SampleInput(tensors), error_regex="Tensors must have same number of dimensions") 2578 2579 # empty tensor list 2580 yield ErrorInput(SampleInput(()), error_regex="expects a non-empty TensorList") 2581 2582def sample_inputs_unbind(op_info, device, dtype, requires_grad, **kwargs): 2583 # Note: we don't do any tests where we unbind along 0-length dims 2584 # because in that case unbind returns and empty tuple, and that breaks 2585 # some assumptions in some backward tests in test_ops.py 2586 shape_dims = (((S,), 0), 2587 ((S, S), 0), 2588 ((S, S), 1), 2589 ((S, S), -1), 2590 ((S, 0, S), 0), 2591 ((S, S, S), 1), 2592 ) 2593 for shape, dim in shape_dims: 2594 yield SampleInput(make_tensor(shape, dtype=dtype, device=device, 2595 requires_grad=requires_grad), 2596 args=(dim,)) 2597 2598def error_inputs_unbind(op_info, device): 2599 make_arg = partial(make_tensor, dtype=torch.int32, device=device, requires_grad=False) 2600 yield ErrorInput(SampleInput(make_arg(()), args=(0,)), error_type=IndexError, 2601 error_regex="Dimension specified as 0 but tensor has no dimensions") 2602 yield ErrorInput(SampleInput(make_arg((2,)), args=(2,)), error_type=IndexError, 2603 error_regex="Dimension out of range") 2604 2605def reference_unbind(t, dim): 2606 """A numpy implementation of torch.unbind""" 2607 return tuple(s.squeeze(dim) for s in np.split(t, t.shape[dim], dim)) 2608 2609def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs): 2610 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) 2611 yield SampleInput( 2612 make_arg((M, S)), 2613 0, 2614 gather_variable((S, S), 1, M, True, device=device)) 2615 yield SampleInput( 2616 make_arg((M, S)), 2617 1, 2618 gather_variable((M, S // 2), 0, S, True, device=device)) 2619 # Empty index tensor case, see: https://github.com/pytorch/pytorch/pull/65006 2620 yield SampleInput( 2621 make_arg((S,)), 2622 0, 2623 torch.tensor([], dtype=torch.uint8, device=device)) 2624 yield SampleInput( 2625 make_arg((S,)), 2626 0, 2627 torch.tensor([[], []], dtype=torch.uint8, device=device)) 2628 # 0D tensor case 2629 yield SampleInput( 2630 make_arg(()), 2631 0, 2632 torch.tensor([0], 
dtype=torch.int64, device=device)) 2633 yield SampleInput( 2634 make_arg(()), 2635 0, 2636 torch.tensor(0, dtype=torch.int64, device=device)) 2637 2638def _fill_indices(idx, dim, dim_size, elems_per_row, m, n, o): 2639 for i in range(1 if dim == 0 else m): 2640 for j in range(1 if dim == 1 else n): 2641 for k in range(1 if dim == 2 else o): 2642 ii = [i, j, k] 2643 ii[dim] = slice(0, idx.size(dim) + 1) 2644 idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] 2645 2646def error_inputs_gather(op_info, device, **kwargs): 2647 # src is [1, 2] 2648 # [3, 4] 2649 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) 2650 2651 # idx is [0, 0] 2652 # [1, 0] 2653 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) 2654 2655 # Index should be smaller than self except on dimension 1 2656 bad_src = make_tensor((1, 1), device=device, dtype=torch.float32) 2657 yield ErrorInput(SampleInput(bad_src, args=(1, idx,)), 2658 error_regex="Size does not match at dimension 0") 2659 2660 # Index must have long dtype 2661 bad_idx = idx.to(torch.int32) 2662 yield ErrorInput(SampleInput(src, args=(1, bad_idx)), 2663 error_regex="Expected dtype int64 for index") 2664 2665 # TODO: FIXME 2666 # out.dtype must match src.dtype 2667 # Creates new src & idx since SampleInputs can't share tensors 2668 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) 2669 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) 2670 out = torch.empty((2, 2), device=device, dtype=torch.float64) 2671 yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), 2672 error_regex="Expected out tensor to have dtype") 2673 2674 # src and index tensors must have the same # of dimensions 2675 # idx too few dimensions 2676 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) 2677 idx = torch.tensor((0, 0), device=device, dtype=torch.long) 2678 yield ErrorInput(SampleInput(src, args=(1, idx)), 2679 error_regex="Index tensor must have the same number of dimensions") 2680 2681 # src too few dimensions 2682 src = torch.tensor((1, 2), device=device, dtype=torch.float32) 2683 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) 2684 yield ErrorInput(SampleInput(src, args=(0, idx)), 2685 error_regex="Index tensor must have the same number of dimensions") 2686 2687 # index out of bounds 2688 # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices 2689 if torch.device(device).type == 'cpu': 2690 src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) 2691 idx = torch.tensor(((0, 23), (1, 0)), device=device, dtype=torch.long) 2692 yield ErrorInput(SampleInput(src, args=(1, idx,)), 2693 error_regex="index 23 is out of bounds for dimension") 2694 2695 x = torch.rand((1,), device=device).expand((3,)) 2696 src = torch.rand((6,), device=device) 2697 ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) 2698 2699 yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=x)), 2700 error_type=RuntimeError, 2701 error_regex='unsupported operation') 2702 2703 yield ErrorInput(SampleInput(src, args=(0, ind,), kwargs=dict(out=src)), 2704 error_type=RuntimeError, 2705 error_regex='unsupported operation') 2706 2707 yield ErrorInput(SampleInput(ind.clone(), args=(0, ind[1:],), kwargs=dict(out=ind[:1])), 2708 error_type=RuntimeError, 2709 error_regex='unsupported operation') 2710 2711def error_inputs_take(op_info, device, **kwargs): 2712 x = torch.rand((1,), 
device=device).expand((3,)) 2713 src = torch.rand((6,), device=device) 2714 ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) 2715 2716 yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=x)), 2717 error_type=RuntimeError, 2718 error_regex='unsupported operation') 2719 2720 yield ErrorInput(SampleInput(src, args=(ind,), kwargs=dict(out=src)), 2721 error_type=RuntimeError, 2722 error_regex='unsupported operation') 2723 2724 yield ErrorInput(SampleInput(ind.clone(), args=(ind[1:],), kwargs=dict(out=ind[:-1])), 2725 error_type=RuntimeError, 2726 error_regex='unsupported operation') 2727 2728# Error inputs for scatter 2729def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): 2730 # Error when self.dtype != src.dtype (and src is not a scalar) 2731 src = make_tensor((2, 5), device=device, dtype=torch.float32) 2732 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) 2733 dst = torch.zeros((3, 5), device=device, dtype=torch.double) 2734 yield ErrorInput(SampleInput(dst, args=(0, idx, src)), 2735 error_regex="Expected self.dtype to be equal to src.dtype") 2736 2737 # Index dtype must be long 2738 src = make_tensor((2, 5), device=device, dtype=torch.float32) 2739 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32) 2740 dst = torch.zeros((3, 5), device=device, dtype=torch.float32) 2741 yield ErrorInput(SampleInput(dst, args=(0, idx, src)), 2742 error_regex="Expected dtype int64 for index") 2743 2744 # Index and destination must have the same number of dimensions 2745 src = make_tensor((2, 5), device=device, dtype=torch.float32) 2746 idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) 2747 dst = torch.zeros((3, 5, 3), device=device, dtype=torch.float32) 2748 yield ErrorInput(SampleInput(dst, args=(0, idx, src)), 2749 error_regex="Index tensor must have the same number of dimensions as self tensor") 2750 2751 # Index and src must have the same number of dimensions when src is not a scalar 2752 src = make_tensor((2, 5, 2), device=device, dtype=torch.float32) 2753 idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) 2754 dst = torch.zeros((3, 5), device=device, dtype=torch.float32) 2755 yield ErrorInput(SampleInput(dst, args=(0, idx, src)), 2756 error_regex="Index tensor must have the same number of dimensions as src tensor") 2757 2758 # Index out of bounds 2759 # NOTE: this ErrorInput is guarded because bounds checking does not occur on CUDA devices 2760 if torch.device(device).type == 'cpu': 2761 src = make_tensor((2, 5), device=device, dtype=torch.float32) 2762 idx = torch.tensor(((34, 1), (1, 2)), device=device, dtype=torch.long) 2763 dst = torch.zeros((3, 5), device=device, dtype=torch.float32) 2764 yield ErrorInput(SampleInput(dst, args=(0, idx, src)), 2765 error_regex="index 34 is out of bounds for dimension 0 with size 3") 2766 2767def error_inputs_renorm(op_info, device, **kwargs): 2768 zero_d = torch.randn((), device=device) 2769 yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError, 2770 error_regex="needs at least 2 dimensions, got 0 dimensions") 2771 2772 2773def error_inputs_ormqr(op_info, device, **kwargs): 2774 zero_d = torch.randn((), device=device) 2775 yield ErrorInput(SampleInput(zero_d, args=(zero_d, zero_d)), error_type=RuntimeError, 2776 error_regex="input must have at least 2 dimensions") 2777 2778 # https://github.com/pytorch/pytorch/issues/85218 2779 tensor_0 = torch.full((5, 0,), 1, device=device) 2780 tensor_1 = torch.full((5,), 1, 
device=device) 2781 tensor_2 = torch.full((5, 5,), 1, device=device) 2782 bool_3 = True 2783 bool_4 = True 2784 yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, bool_3, bool_4)), error_type=RuntimeError, 2785 error_regex=r"tau.shape\[-1\] must be less than or equal to input.shape\[-1\]") 2786 2787 2788def error_inputs_diag(op_info, device, **kwargs): 2789 zero_d = torch.randn((), device=device) 2790 yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, 2791 error_regex="1D or 2D") 2792 zero_d = torch.randn(1, 1, 1, device=device) 2793 yield ErrorInput(SampleInput(zero_d, args=(0,)), error_type=RuntimeError, 2794 error_regex="1D or 2D") 2795 2796def error_inputs_embedding(op_info, device, **kwargs): 2797 indices = torch.rand(2, 2, device=device).long() 2798 weights = [ 2799 torch.tensor(1.0, device=device), 2800 torch.tensor(1.0, device=device).reshape(1, 1, 1), 2801 ] 2802 2803 for weight in weights: 2804 yield ErrorInput(SampleInput(weight, args=(indices,)), error_type=RuntimeError, 2805 error_regex="'weight' must be 2-D") 2806 2807 2808def error_inputs_t(op_info, device, **kwargs): 2809 yield ErrorInput( 2810 SampleInput(torch.randn(2, 3, 4, 5, device=device)), 2811 error_regex="expects a tensor with <= 2", 2812 ) 2813 2814 2815def error_inputs_multinomial(op_info, device, **kwargs): 2816 x = torch.empty(1, 2, 3, dtype=torch.double, device=device) 2817 yield ErrorInput(SampleInput(x, args=(2,)), 2818 error_regex="prob_dist must be 1 or 2 dim") 2819 2820 x = torch.empty(1, 2, dtype=torch.long, device=device) 2821 yield ErrorInput(SampleInput(x, args=(2,)), 2822 error_regex="multinomial only supports floating-point dtypes for input") 2823 2824 x = torch.empty(1, 2, dtype=torch.double, device=device) 2825 y = torch.empty(1, 2, dtype=torch.double, device=device) 2826 yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), 2827 error_regex="multinomial expects Long tensor out") 2828 2829 x = torch.empty(2, dtype=torch.double, device=device) 2830 yield ErrorInput(SampleInput(x, args=(0,)), 2831 error_regex="cannot sample n_sample <= 0 samples") 2832 2833 x = torch.empty(2, dtype=torch.double, device=device) 2834 yield ErrorInput(SampleInput(x, args=(-1,)), 2835 error_regex="cannot sample n_sample <= 0 samples") 2836 2837 x = torch.empty(2, dtype=torch.double, device=device) 2838 yield ErrorInput(SampleInput(x, args=(3, False,)), 2839 error_regex="cannot sample n_sample > prob_dist") 2840 2841 x = torch.empty(16777217, dtype=torch.double, device=device) 2842 yield ErrorInput(SampleInput(x, args=(3,)), 2843 error_regex="number of categories cannot exceed") 2844 2845 inputs = ((1., -1., 1.), (1., inf, 1.), (1., -inf, 1.), (1., 1., nan)) 2846 2847 err_msg1 = "probability tensor contains either `inf`, `nan` or element < 0" 2848 err_msg2 = "invalid multinomial distribution" 2849 2850 rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) 2851 2852 if torch.device(device).type == 'cpu': 2853 for rep in rep_arg: 2854 kwargs = {'num_samples': 2, 'replacement': rep} 2855 2856 for shape in inputs: 2857 # error case when input tensor contains `inf`, `nan` or negative element 2858 yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), 2859 error_regex=err_msg1 if rep is False else err_msg2) 2860 2861 # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input 2862 x = torch.zeros(3, device=device) 2863 yield ErrorInput(SampleInput(x, kwargs=kwargs), 2864 error_regex=err_msg2) 2865 2866 # error case 
for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input 2867 x = torch.zeros(3, 3, device=device) 2868 yield ErrorInput(SampleInput(x, kwargs=kwargs), 2869 error_regex=err_msg2) 2870 2871 # error case for the invalid multinomial distribution 2872 x[1, :] = 1 2873 yield ErrorInput(SampleInput(x, kwargs=kwargs), 2874 error_regex=err_msg2) 2875 2876def error_inputs_gradient(op_info, device, **kwargs): 2877 for dtype in [torch.long, torch.float32, torch.complex64]: 2878 t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device, dtype=dtype) 2879 2880 dim = (1, 0) 2881 spacing = [0.1] 2882 yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), 2883 error_type=RuntimeError, 2884 error_regex='torch.gradient expected spacing to be unspecified, a scalar ') 2885 2886 yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=3)), 2887 error_type=RuntimeError, 2888 error_regex='torch.gradient only supports edge_order=1 and edge_order=2.') 2889 2890 dim = (1, 1) 2891 spacing = 0.1 2892 yield ErrorInput(SampleInput(t, kwargs=dict(spacing=spacing, dim=dim, edge_order=1)), 2893 error_type=RuntimeError, 2894 error_regex='dim 1 appears multiple times in the list of dims') 2895 2896 dim = (0, 1) 2897 coordinates = [torch.tensor([1, 2, 4], device='cpu'), torch.tensor([1, 2, 4], device='meta')] 2898 yield ErrorInput(SampleInput(t, kwargs=dict(spacing=coordinates, dim=dim, edge_order=1)), 2899 error_type=RuntimeError, 2900 error_regex='torch.gradient expected each tensor to be on the same device,') 2901 2902 yield ErrorInput(SampleInput(t, kwargs=dict(dim=3)), 2903 error_type=IndexError, error_regex='') 2904 2905 t = torch.tensor([[1], [2], [3]]) 2906 yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=1)), 2907 error_type=RuntimeError, 2908 error_regex='torch.gradient expected each dimension size to be at least') 2909 2910 t = torch.tensor([[1, 2], [3, 4]]) 2911 yield ErrorInput(SampleInput(t, kwargs=dict(edge_order=2)), 2912 error_type=RuntimeError, 2913 error_regex='torch.gradient expected each dimension size to be at least') 2914 2915def sample_inputs_rrelu(op_info, device, dtype, requires_grad, **kwargs): 2916 yield from sample_inputs_elementwise_unary( 2917 op_info, device, dtype, requires_grad, op_kwargs=dict(lower=0., upper=1., training=True)) 2918 2919 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 2920 yield SampleInput(make_arg(S)) 2921 yield SampleInput(make_arg(S), training=False) 2922 2923def error_inputs_rrelu(op_info, device, **kwargs): 2924 input = make_tensor((S, S), device=device, dtype=torch.float32) 2925 yield ErrorInput(SampleInput(input, kwargs={'lower': 0.3, 'upper': 0.1}), 2926 error_regex='Lower bound should be less than or equal to the upper bound') 2927 2928def error_inputs_masked_select(op_info, device, **kwargs): 2929 x = torch.rand((1,), device=device).expand((3,)) 2930 y = torch.rand((6,), device=device) 2931 mask = torch.tensor([True, False, True, True, False, False], device=device) 2932 2933 yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=x)), 2934 error_type=RuntimeError, 2935 error_regex='unsupported operation') 2936 2937 yield ErrorInput(SampleInput(y, args=(mask,), kwargs=dict(out=y)), 2938 error_type=RuntimeError, 2939 error_regex='unsupported operation') 2940 2941 yield ErrorInput(SampleInput(mask.clone(), args=(mask,), kwargs=dict(out=mask)), 2942 error_type=RuntimeError, 2943 error_regex='unsupported operation') 2944 2945def error_inputs_median(op_info, 
                        device, **kwargs):
    x = torch.tensor([[[[[[[[[[[[[[[[[[[[[[[[[nan],
                        [nan]]]]]]]]]]]]]]]]]]]]]]]]], device=device)
    if device == 'cuda':
        yield ErrorInput(SampleInput(x, kwargs=dict(dim=(-1))),
                         error_type=RuntimeError,
                         error_regex='CUDA Tensors cannot have more than 25 dimensions')
    else:
        return


def error_inputs_index_select(op_info, device, **kwargs):
    x = torch.rand((1, 6), device=device).expand((2, 6))
    y = torch.rand((3, 6), device=device)
    ind = torch.tensor([0, 1], dtype=torch.int64, device=device)

    yield ErrorInput(SampleInput(y, args=(1, ind,), kwargs=dict(out=x)),
                     error_type=RuntimeError,
                     error_regex='unsupported operation')

def error_inputs_index_add(op_info, device, **kwargs):
    result = torch.tensor([[1., 2.], [4., 5.], [7., 8.]])
    source = torch.tensor([2., 4.])

    yield ErrorInput(SampleInput(result, args=(0, torch.tensor([0, 2]), source)),
                     error_type=RuntimeError,
                     error_regex=r'source tensor shape must match self tensor shape, '
                                 r'excluding the specified dimension. Got self.shape = \[3, 2\] source.shape = \[2\]')

def error_inputs_logcumsumexp(op_info, device, **kwargs):
    dim = 3
    srcs = [torch.randn(5, 2, device=device), torch.randn(0, 2, device=device)]
    for src in srcs:
        yield ErrorInput(SampleInput(src, args=(dim,)),
                         error_type=IndexError,
                         error_regex='Dimension out of range')

def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    yield SampleInput(
        make_arg((S, S)), gather_variable((S, S), 1, S, True, device=device), 0)

    # `indices` broadcast
    yield SampleInput(
        make_arg((S, S)), gather_variable((1, S // 2), 0, S, True, device=device), 1)

    # `self` broadcast
    yield SampleInput(
        make_arg((1, S)), gather_variable((S, S // 2), 0, S, True, device=device), 1)

    # without `dim` arg
    yield SampleInput(
        make_arg((S, S)), gather_variable((S, S // 2), 0, S, True, device=device))


def error_inputs_aminmax_amax_amin(op_info, device, is_ref=False, **kwargs):

    # Error inputs for tensors with a zero-sized dimension, when 'dim' arg is not provided.
    shape = (S, 0, S)
    err_msg_amax_amin = "reduction"
    err_msg_aminmax = "cannot compute aminmax over an empty dimension as the operation has no identity"
    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_amax_amin)
    elif op_info.name in ['aminmax']:
        yield ErrorInput(SampleInput(torch.rand(shape, device=device)), error_regex=err_msg_aminmax)

    # Error inputs for tensors with more than 64 dimensions
    sizes = [1] * 65
    err_msg1 = "only tensors with up to 64 dims are supported"
    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': -1}),
                     error_regex=err_msg1)
    yield ErrorInput(SampleInput(torch.randn(sizes, device=device), kwargs={'dim': 64}),
                     error_regex=err_msg1)

    # Error inputs for repeated 'dim'
    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
        dims = [(0, 0), (0, -4)]
        err_msg2 = "in the list of dims"
        x = torch.randn(S, S, S, S, device=device)
        for dim in dims:
            yield ErrorInput(SampleInput(x, kwargs={'dim': dim}), error_regex=err_msg2)

    # Error input for illegal dtype
    input5 = torch.randn(L, L, dtype=torch.float32, device=device)
    max_values = torch.empty(L, dtype=torch.float32, device=device)
    min_values = torch.empty(L, dtype=torch.double, device=device)
    illegal_values = torch.empty(L, dtype=torch.int, device=device)

    # Unlike regular PyTorch, amax and amin refs don't require input and out
    # dtypes to match exactly:
    # https://github.com/pytorch/pytorch/pull/87765#pullrequestreview-1162023824
    if is_ref:
        err_msg_amax_amin2 = ("Attempting to cast from torch.float32 to out tensor with dtype "
                              "torch.int32, but this can't be cast because it is not safe!")
    else:
        err_msg_amax_amin2 = ("Expected the dtype for input and out to match, but got Float "
                              "for input's dtype and Int for out's dtype.")
    err_msg_aminmax2 = "Expected out tensor to have dtype float, but got double instead"

    if op_info.name in ['amax', 'amin', '_refs.amax', '_refs.amin']:
        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': illegal_values}),
                         error_regex=err_msg_amax_amin2)
    elif op_info.name in ['aminmax']:
        yield ErrorInput(SampleInput(input5, kwargs={'dim': 0, 'out': (max_values, min_values)}),
                         error_regex=err_msg_aminmax2)

    # Error inputs when the specified reduction dim has zero size
    err_msg3 = "reduction"
    # FIXME: eager and ref impl throw different types of errors
    error_type = IndexError if 'refs' not in op_info.name else RuntimeError
    yield ErrorInput(SampleInput(torch.rand(shape, device=device), kwargs={'dim': 1}),
                     error_type=error_type, error_regex=err_msg3)

def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs):
    test_cases: Tuple[tuple, dict] = (  # type: ignore[assignment]
        ((S, S, S), {}),
        ((S, S, S), {'dim': 1}),
        ((S, S, S), {'dim': 1, 'keepdim': True}),
        ((), {'dim': 0}),
        ((), {}),
        ((), {'dim': 0, 'keepdim': True}),
        ((S, 0, S), {'dim': 0}),
    )

    for shape, kwargs in test_cases:
        yield SampleInput(
            make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad),
            **kwargs)

def error_inputs_diff(op_info, device, **kwargs):
    t = torch.rand((1, 3), device=device)
    n = -1
    yield ErrorInput(SampleInput(t, args=(n, ), kwargs=kwargs),
                     error_type=RuntimeError,
                     error_regex=f'order must be non-negative but got {n}')

def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    test_cases = (
        ((1,), 0, None, None),
        ((S,), 0, None, None),
        ((S, 1), 0, None, None),
        ((S, 1), 1, None, None),
        ((S, S), 0, None, None),
        ((S, S), 1, None, None),
        ((S, S), 0, (1, S), (2, S)),
        ((S, S), 0, None, (2, S)),
        ((XS, XS, XS), 1, None, None),
        ((XS, XS, XS), 2, None, None),
        ((XS, XS, XS), 1, (XS, 1, XS), (XS, 1, XS)),
        ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)),
        ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),)

    sample_inputs = []
    for size, dim, size_prepend, size_append in test_cases:
        prepend_size = 0 if (size_prepend is None) else size_prepend[dim]
        append_size = 0 if (size_append is None) else size_append[dim]
        dim_size = size[dim] + prepend_size + append_size
        for n in range(dim_size):
            input_tensor = make_arg(size)
            prepend = make_arg(size_prepend) if size_prepend else None
            append = make_arg(size_append) if size_append else None
            yield SampleInput(input_tensor, n, dim, prepend, append)

    # add some samples with n > dim_size
    yield SampleInput(make_arg((XS, XS, XS)), S + 1, 1)
    yield SampleInput(make_arg((XS, XS, XS)), S * 3 + 2, 2, make_arg((XS, XS, XS)), make_arg((XS, XS, XS)))

def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))

    for size, bin_ct, weighted, density in product(sizes, range(1, 5), [False, True], [False, True]):
        input_tensor = make_arg(size)
        weight_tensor = make_arg(size) if weighted else None

        yield SampleInput(input_tensor, bin_ct,
                          weight=weight_tensor, density=density)

        bins_tensor = make_arg((bin_ct + 1,))
        sorted_bins, bins_indices = torch.sort(bins_tensor)
        yield SampleInput(input_tensor, sorted_bins,
                          weight=weight_tensor, density=density)

def sample_inputs_histogramdd(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    sizes = ((S, S), (S, S, S), (S, 1, S), (S, 0, S))
    bin_ct_patterns = ((1, 1, 1, 1, 1), (2, 3, 2, 3, 2), (3, 2, 3, 2, 3))

    for size, bin_ct_pattern, weighted, density in product(sizes, bin_ct_patterns, [False, True], [False, True]):
        input_tensor = make_arg(size)
        bin_ct = bin_ct_pattern[:size[-1]]
        weight_tensor = make_arg(size[:-1]) if weighted else None

        yield SampleInput(input_tensor, bin_ct,
                          weight=weight_tensor, density=density)

        bins_tensor = [make_arg(ct + 1) for ct in bin_ct]
        yield SampleInput(input_tensor, bins_tensor,
                          weight=weight_tensor, density=density)

def error_inputs_histogramdd(opinfo, device, **kwargs):
    invalid_bins = [1, 1, 1, 1, 1]
    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
    msg = "histogramdd: The size of bins must be equal to the innermost dimension of the input."
    yield ErrorInput(SampleInput(make_arg(5, 6), invalid_bins), error_regex=msg)

def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))

    for size, min, max in product(sizes, [0, -10], [0, 10]):
        # construct sample input omitting bins arg
        yield SampleInput(make_arg(size), min=min, max=max)

        # construct sample inputs with a few different bins values
        for bins in [1, 3, 10]:
            yield SampleInput(make_arg(size), bins=bins, min=min, max=max)

def sample_inputs_bincount(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    for size, weighted in product((S, M), [False, True]):
        input_tensor = torch.randint(0, size, (size,), dtype=dtype, device=device)
        weight_tensor = make_arg((size,)) if weighted else None

        max_val = int(input_tensor.max().item())

        for minlength in [0, max_val // 2, max_val, 2 * max_val]:
            yield SampleInput(
                input_tensor, weights=weight_tensor, minlength=minlength)

def sample_inputs_bucketize(op_info, device, dtype, requires_grad, reference_inputs_mode=False, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    sizes = (((), S), ((S,), S), ((S, S), S), ((S, S, S), S), ((S, 1, S), S), ((S, 0, S), S))

    if reference_inputs_mode:
        sizes += (((256,), 128), ((128,), 256), ((32, 32), 11), ((32, 4, 32), 33))

    for (input_shape, nb), out_int32, right in product(sizes, [False, True], [False, True]):
        input_tensor = make_arg(input_shape)
        boundaries = make_arg(nb).msort()

        yield SampleInput(input_tensor, boundaries,
                          out_int32=out_int32, right=right)

reference_inputs_bucketize = partial(sample_inputs_bucketize, reference_inputs_mode=True)

def error_inputs_bucketize(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, dtype=torch.float, device=device, requires_grad=False)
    yield ErrorInput(SampleInput(make_arg((S, S, S)), make_arg((S, S))),
                     error_regex="boundaries tensor must be 1 dimension")

def sample_inputs_searchsorted(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    # (unsorted tensor size, (input sizes,), is_scalar)
    sizes = (
        ((0,), ((0,),), False),
        ((M,), ((), (M,), (M, M)), False),
        ((0, 0), ((0, 0),), False),
        ((M, M), ((M, M),), False),
        ((0, 0, 0), ((0, 0, 0),), False),
        ((M, M, M), ((M, M, M),), False),
        ((L,), ((),), True),
    )

    for (size, input_sizes, is_scalar), noncontiguous, out_int32, right in product(
        sizes, [False, True], [False, True], [False, True]
    ):
        unsorted_tensor = make_arg(size, noncontiguous=noncontiguous)
        for input_size in input_sizes:
            input = make_arg(input_size, noncontiguous=noncontiguous)
            if is_scalar:
                input = input.item()
            if np.prod(size) == 0:
                boundary_tensor = unsorted_tensor
                sorter = make_tensor(size, dtype=torch.int64, device=device, noncontiguous=noncontiguous)
            else:
                boundary_tensor, sorter = torch.sort(unsorted_tensor)
            side = "right" if right else "left"

            yield SampleInput(boundary_tensor, input,
                              out_int32=out_int32, right=right)
            yield SampleInput(boundary_tensor, input, out_int32=out_int32, side=side)

            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, right=right, sorter=sorter)
            yield SampleInput(unsorted_tensor, input, out_int32=out_int32, side=side, sorter=sorter)

def sample_inputs_gradient(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None)
    test_cases_float = (
        ((S,), None, None, 1),
        ((S,), 2., None, 1),
        ((S, S), None, None, 2),
        ((S, S), [2.0, 2.1], None, 1),
        ((S, S), [2.0, 2.1], (0, 1), 1),
        ((4, 4, 4), [2., 1.], (0, 1), 2),
    )
    for size, spacing, dim, edge_order in test_cases_float:
        t = make_arg(size)
        yield SampleInput(t, dim=dim, spacing=spacing, edge_order=edge_order)

    test_cases_tensor = (
        ((3, 3, 3), ((1.1, 2.0, 3.5), (4.0, 2, 6.0)), (0, -1), 1),
        ((3, 3, 3), ((1.0, 3.0, 2.0), (8.0, 6.0, 1.0)), (0, 1), 2),
    )
    for size, coordinates, dim, edge_order in test_cases_tensor:
        t = make_arg(size)
        coordinates_tensor_list = []
        for coords in coordinates:
            # `coords` will always contain floating point values and Python 3.10 does not support this
            # implicit conversion to an integer using `__int__`
            # TODO: this can be simplified after https://github.com/pytorch/pytorch/issues/69316 is fixed
            a = torch.tensor(coords, device=device)
            coordinates_tensor_list.append(a.to(dtype))
        yield SampleInput(t, dim=dim, spacing=coordinates_tensor_list, edge_order=edge_order)

def sample_inputs_getitem(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    test_args = [
        ([1, 2],),
        (slice(0, 3),),
        ([slice(0, 3), 1],),
        ([[0, 2, 3], [1, 3, 3], [0, 0, 2]],),
        ([[0, 0, 3], [1, 1, 3], [0, 0, 2]],),
        ([slice(None), slice(None), [0, 3]],),
        ([slice(None), [0, 3], slice(None)],),
        ([[0, 3], slice(None), slice(None)],),
        ([[0, 3], [1, 2], slice(None)],),
        ([[0, 3], ],),
        ([[0, 3], slice(None)],),
        ([[0, 3], Ellipsis],),
        ([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])],),
        (index_variable(2, S, device=device),),
        (mask_not_all_zeros((S,)),),
    ]

    for args in test_args:
        yield SampleInput(make_arg((S, S, S)), args=args)

    yield SampleInput(make_arg((S, S, S, S)), args=([slice(None), [0, 1], slice(None), [0, 1]],))

def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

    for accumulate in [False, True]:
        # Test with indices arg
        yield SampleInput(
            make_arg((S, S,)),
            (index_variable(2, S, device=device),),
            make_arg((2, S)),
            accumulate=accumulate)

        # Test with mask arg
        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
        yield SampleInput(
            make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)

def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
    def small_3d_unique():
        res = torch.randperm(S * S * S, dtype=torch.int64, device=device).view(S, S, S)
        res = res.to(dtype).requires_grad_(requires_grad)
        return res

    def large_1d_unique():
        res = torch.randperm(L * L * L,
                             dtype=torch.int64, device=device)
        res = res.to(dtype).requires_grad_(requires_grad)
        return res

    # Test case for large tensor.
    yield SampleInput(large_1d_unique())

    # Test cases for small 3d tensors.
    # Imitates legacy tests from test/test_torch.py
    dims = range(-3, 3)
    flag = [True, False]
    for dim, descending, stable in product(dims, flag, flag):
        # default schema without stable sort
        yield SampleInput(small_3d_unique(), dim, descending)
        # schema with stable sort, no CUDA support yet
        if torch.device(device).type == 'cpu':
            yield SampleInput(
                small_3d_unique(), dim=dim, descending=descending, stable=stable)

    # Test cases for scalar tensor
    tensor_opt = dict(dtype=dtype, device=device, requires_grad=requires_grad)
    yield SampleInput(torch.tensor(1, **tensor_opt))
    yield SampleInput(torch.tensor(1, **tensor_opt), 0)
    yield SampleInput(torch.tensor(1, **tensor_opt), 0, True)

    # Test cases for empty tensor
    yield SampleInput(torch.tensor((), **tensor_opt))
    yield SampleInput(torch.tensor((), **tensor_opt), 0)
    yield SampleInput(torch.tensor((), **tensor_opt), 0, True)

    # Test cases for stable sort
    yield SampleInput(small_3d_unique(), stable=True)
    yield SampleInput(small_3d_unique(), dim=0, stable=True)
    yield SampleInput(small_3d_unique(), dim=0, descending=True, stable=True)

def sample_inputs_threshold(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
    sizes = ((), (S,), (S, S), (S, S, S))
    for x_size in sizes:
        # threshold and values args must be numbers
        yield SampleInput(make_arg(x_size), make_arg(()).item(), make_arg(()).item())

def sample_inputs_unique(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S))

    for shape, sorted, return_inverse, return_counts, dim in \
            product(sizes, [False, True], [False, True], [False, True], [None, -2, -1, 0, 1, 2]):
        # torch.unique cannot be called if the input tensor has a zero dimension which isn't the selected dim
        if 0 in shape and shape.index(0) != dim:
            continue

        # skip invalid dim args
        if dim is not None and (dim < -len(shape) or dim >= len(shape)):
            continue

        kwargs = dict(sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim)

        # construct a test case with only one distinct value
        input_t = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad)
        yield SampleInput(input_t, **kwargs)

        # construct a test case with mixed 0s and 1s
        input_t = make_arg(shape, dtype=torch.bool, requires_grad=False)\
            .to(dtype).requires_grad_(requires_grad)
        yield SampleInput(input_t, **kwargs)

        # construct a test case with many different values
        yield SampleInput(make_arg(shape), **kwargs)

def sample_inputs_unique_consecutive(*args, **kwargs):
    for sample_input in sample_inputs_unique(*args, **kwargs):
        if not sample_input.kwargs["sorted"]:
            sample_input.kwargs.pop("sorted")
            yield sample_input

def sample_inputs_adaptive_avg_pool1d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        ((0, 8, 8), (5,)),
        ((3, 8, 8), 5),
        ((3, 8, 8), 1)
    )

    for input_shape, output_size in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(output_size,))
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))


def error_inputs_adaptive_avg_pool1d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
                     error_regex="'output_size' should contain one int")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
                     error_regex="elements of output_size must be greater than or equal to 0")


def sample_inputs_adaptive_avg_pool2d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        ((1, 8, 8, 8), (5, 7)),
        ((2, 8, 8, 8), (None, 7)),
        ((1, 8, 4, 3), (5, None)),
        ((1, 8, 4, 3), (None, None)),
        ((1, 8, 4, 3), (5)),
    )

    for input_shape, output_size in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(output_size,))
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))


def error_inputs_adaptive_avg_pool2d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for incorrect input dimension
    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
                     error_type=ValueError, error_regex="Input dimension should be at least 3")

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
                     error_regex="output_size must be 2")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
                     error_regex="elements of output_size must be greater than or equal to 0")


def sample_inputs_adaptive_avg_pool3d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        ((0, 8, 8, 8, 8), (5, 7, 4)),
        ((1, 8, 4, 3, 7), (None, None, None)),
        ((1, 8, 4, 3, 7), (1, 1, 1)),
        ((3, 3, 8, 8, 6), (5, 7, None)),
        ((1, 3, 8, 8, 6), (5, None, 2)),
        ((3, 3, 8, 8, 6), (None, 3, 2)),
    )

    for input_shape, output_size in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(output_size,))
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(output_size,))


def error_inputs_adaptive_avg_pool3d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for incorrect input dimension
    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
                     error_type=ValueError, error_regex="Input dimension should be at least 4")

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
                     error_regex="output_size must be 3")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
                     error_regex="elements of output_size must be greater than or equal to 0")


def sample_inputs_adaptive_max_pool1d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        # ((0, 8, 8), (5,)),
        # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
        ((3, 4, 4), 3),
        ((3, 4, 4), 1)
    )

    for shapes, return_idx in product(cases, (True, False)):
        # Batched
        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
        # Unbatched
        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))


def error_inputs_adaptive_max_pool1d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3)), output_size=()),
                     error_regex="'output_size' should contain one int")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1)), output_size=(-1,)),
                     error_regex="Trying to create tensor with negative dimension")

def sample_inputs_adaptive_max_pool2d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        # ((0, 8, 8, 8), (5, 7)),
        # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
        ((1, 4, 4, 4), (2, 3)),
        ((2, 4, 4, 4), (None, 3)),
        ((2, 4, 4, 4), (1, 1)),
        ((1, 4, 4, 3), (3, None)),
        ((1, 4, 4, 3), (None, None)),
        ((1, 4, 4, 3), (3)),
    )

    for shapes, return_idx in product(cases, (True, False)):
        # Batched
        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
        # Unbatched
        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))

def error_inputs_adaptive_max_pool2d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for incorrect input dimension
    yield ErrorInput(SampleInput(make_arg((2, 2)), output_size=(2, 2)),
                     error_type=ValueError, error_regex="Input dimension should be at least 3")

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
                     error_regex="internal error")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1)), output_size=(-1, 0)),
                     error_regex="Trying to create tensor with negative dimension")


def sample_inputs_adaptive_max_pool3d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as (input shape, output size)
    cases = (
        # ((0, 8, 8, 8, 8), (5, 7, 4)),
        # 0 batch size doesn't work, cannot reshape tensor of 0 elements into shape [0, 8, -1]
        ((1, 4, 4, 3, 5), (None, None, None)),
        ((1, 4, 4, 3, 5), (1, 1, 1)),
        ((3, 3, 4, 4, 6), (2, 3, None)),
        ((1, 3, 4, 4, 6), (3, None, 2)),
        ((3, 3, 4, 4, 6), (None, 3, 2)),
    )

    for shapes, return_idx in product(cases, (True, False)):
        # Batched
        yield SampleInput(make_arg(shapes[0]), args=(shapes[1], return_idx))
        # Unbatched
        yield SampleInput(make_arg(shapes[0][1:]), args=(shapes[1], return_idx))

def error_inputs_adaptive_max_pool3d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32)

    # error inputs for incorrect input dimension
    yield ErrorInput(SampleInput(make_arg((2, 2, 2)), output_size=(2, 2, 2)),
                     error_type=ValueError, error_regex="Input dimension should be at least 4")

    # error inputs for empty output
    yield ErrorInput(SampleInput(make_arg((1, 2, 3, 4)), output_size=()),
                     error_regex="internal error")

    # error inputs for output_size less than 0
    yield ErrorInput(SampleInput(make_arg((1, 1, 1, 1, 1)), output_size=(-1, 0, 2)),
                     error_regex="Trying to create tensor with negative dimension")


class _TestParamsMaxPoolBase:

    def __init__(self) -> None:
        self.kwargs = {
            'kernel_size': [3],
            'stride': [2, None],
            'ceil_mode': [True, False],
            'padding': [0, 1],
            'dilation': [1],
            'return_indices': [True, False]
        }

        self.shapes = [
            [1, 2, None],  # batch
            [2],  # channels
            [3, 6]  # signal
        ]

    def _gen_shape(self):
        for shape in product(*self.shapes):
            # shape[0] is None indicates a missing batch dimension
            if shape[0] is None:
                shape = shape[1:]

            yield shape, torch.contiguous_format
            # only 2d (N, C, H, W) rank 4 tensors support channels_last memory format
            if len(self.shapes) == 4 and len(shape) == 4:
                yield shape, torch.channels_last

    def _gen_kwargs(self):
        keys = self.kwargs.keys()
        for values in product(*self.kwargs.values()):
            yield dict(zip(keys, values))

    def gen_input_params(self):
        yield from product(self._gen_shape(), self._gen_kwargs())

class _TestParamsMaxPool1d(_TestParamsMaxPoolBase):

    def __init__(self) -> None:
        super().__init__()
        self.kwargs['kernel_size'] += [(3,)]
        self.kwargs['stride'] += [(2,)]
        self.kwargs['padding'] += [(1,)]
        self.kwargs['dilation'] += [(1,)]

class _TestParamsMaxPool2d(_TestParamsMaxPoolBase):

    def __init__(self) -> None:
        super().__init__()
        self.kwargs['kernel_size'] += [(3, 2)]
        self.kwargs['stride'] += [(2, 1)]
        self.kwargs['padding'] += [(1, 1)]
        self.kwargs['dilation'] += [(1, 2)]

        self.shapes.append([6])

class _TestParamsMaxPool3d(_TestParamsMaxPoolBase):

    def __init__(self) -> None:
        super().__init__()
        self.kwargs['kernel_size'] += [(3, 2, 3)]
        self.kwargs['stride'] += [(2, 1, 2)]
        self.kwargs['dilation'] += [(1, 2, 1)]

        self.shapes.append([6])
        self.shapes.append([5])

def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

    params_generator_type_dict = {
        'nn.functional.max_pool1d': _TestParamsMaxPool1d,
        'nn.functional.max_pool2d': _TestParamsMaxPool2d,
        'nn.functional.max_pool3d': _TestParamsMaxPool3d,
        'max_pool2d_with_indices_backward': _TestParamsMaxPool2d,
    }

    params_generator = params_generator_type_dict[op_info.name]()
    for (shape, memory_format), kwargs in params_generator.gen_input_params():
        arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad)
        yield SampleInput(arg, kwargs=kwargs)

def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs):
    out, indices = torch.nn.functional.max_pool2d_with_indices(
        *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True)
    grad_out = torch.ones_like(out)
    if stride is None:
        stride = kernel_size
    out_b = torch.ops.aten.max_pool2d_with_indices_backward.default(
        grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices)
    return out_b

def error_inputs_max_pool1d(op_info, device, **kwargs):
    # Toggle requires_grad because `max_pool1d` takes a different path
    # based on whether `requires_grad` is set or not.
    for requires_grad in (True, False):
        make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=requires_grad)
        # error inputs when pad is negative
        x = make_arg((0, 1, 49))
        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
                         error_regex='pad must be non-negative')

        # error inputs when pad > kernel_size / 2
        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
                         error_regex='pad should be at most half of effective kernel size')

        # error inputs when pad > ((kernel_size - 1) * dilation + 1) / 2, when dilation is not default
        yield ErrorInput(SampleInput(x,
                                     kwargs={'kernel_size': 3, 'dilation': 2, 'stride': 1, 'padding': 3, 'return_indices': True}),
                         error_regex='pad should be at most half of effective kernel size')

        # error inputs for input tensor
        error_msg = r'Expected 2D or 3D \(batch mode\) tensor with optional 0 dim batch size for input'
        yield ErrorInput(SampleInput(make_arg((), requires_grad=requires_grad), kwargs={'kernel_size': 1}),
                         error_regex=error_msg)

        # error inputs for empty input
        yield ErrorInput(SampleInput(torch.tensor([], device=device, requires_grad=requires_grad),
                                     kwargs={'kernel_size': 1}),
                         error_regex=error_msg)

        # error: unbatched input with 0 sized non-batch dims.
        yield ErrorInput(SampleInput(make_arg((0, 10), requires_grad=requires_grad),
                                     kwargs={'kernel_size': 1}),
                         error_regex=error_msg)

        # error: batched input with 0 sized non-batch dims.
        yield ErrorInput(SampleInput(make_arg((1, 10, 0), requires_grad=requires_grad),
                                     kwargs={'kernel_size': 1}),
                         error_regex=error_msg)

        # error inputs when stride=0
        error_msg = 'stride must be greater than zero, but got 0'
        yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
                         error_regex=error_msg)

        # error inputs when dilation=0
        error_msg = 'dilation must be greater than zero, but got 0'
        yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
                                     kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
                         error_regex=error_msg)

        # error inputs for invalid output size
        error_msg = 'Invalid computed output size: -2'
        yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
                                     kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
                         error_regex=error_msg)

        # error inputs when kernel_size=0
        error_msg = 'kernel_size must be greater than zero'
        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
                         error_regex=error_msg)

        # error inputs when stride is not greater than zero
        error_msg = 'stride must be greater than zero'
        yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
                         error_regex=error_msg)


def error_inputs_max_pool2d(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
    # error inputs when pad is negative
    x = make_arg((0, 1, 49))
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
                     error_regex='pad must be non-negative')
    # 2-dimensional kernel
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1, 'return_indices': True}),
                     error_regex='pad must be non-negative')

    # error inputs when pad > kernel_size / 2 (kernel_size : int)
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
                     error_regex='pad should be at most half of effective kernel size')

    # error inputs when pad > kernel_size / 2 (kernel_size : tuple)
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4, 'return_indices': True}),
                     error_regex='pad should be at most half of effective kernel size')

    # error: unbatched input with 0 sized non-batch dims.
    err_msg = r'Expected 3D or 4D \(batch mode\) tensor with optional 0 dim batch size for input'
    yield ErrorInput(SampleInput(make_arg((1, 0, 10)),
                                 kwargs={'kernel_size': 1}),
                     error_regex=err_msg)

    # error: batched input with 0 sized non-batch dims.
    yield ErrorInput(SampleInput(make_arg((2, 1, 10, 0)),
                                 kwargs={'kernel_size': 1}),
                     error_regex=err_msg)


def error_inputs_max_pool3d(op_info, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
    # error inputs when pad is negative
    x = make_arg((0, 1, 49, 50))
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1, 'return_indices': True}),
                     error_regex='pad must be non-negative')
    # 3-dimensional kernel
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
                                            'padding': -1, 'return_indices': True}),
                     error_regex='pad must be non-negative')

    # error inputs when pad > kernel_size / 2 (kernel_size: int)
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4, 'return_indices': True}),
                     error_regex='pad should be at most half of effective kernel size')

    # error inputs when pad > kernel_size / 2 (kernel_size: tuple)
    yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50,
                                            'padding': 4, 'return_indices': True}),
                     error_regex='pad should be at most half of effective kernel size')

    # error: unbatched input with 0 sized non-batch dims.
    err_msg = r'Expected input\'s non-batch dimensions to have positive length'
    yield ErrorInput(SampleInput(make_arg((0, 1, 2, 10)),
                                 kwargs={'kernel_size': 1}),
                     error_regex=err_msg)

    # error: batched inputs with 0 sized non-batch dims.
    yield ErrorInput(SampleInput(make_arg((2, 1, 0, 1, 2)),
                                 kwargs={'kernel_size': 1}),
                     error_regex=err_msg)


def sample_inputs_normalize(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, low=-1, high=1, device=device, dtype=dtype, requires_grad=requires_grad)

    cases: Tuple[Tuple[int], dict] = (  # type: ignore[assignment]
        ((2, 1, 4, 5), {'p': 1., 'dim': 2}),
        ((2, 3, 4, 5), {'p': 2., 'dim': 1}),
        ((1, 2, 4, 5), {'p': 0.5, 'dim': 0}),
        ((1, 3, 4, 5), {'p': -1., 'dim': 1}),
        ((1, 3, 4, 5), {'p': 0., 'dim': -1}),
        ((), {'p': 1.2, 'dim': 0}),
        ((2, 3, 4, 5), {}),
        ((2, 3, 4, 5), {'eps': 1e-4}))

    for input_shape, kwargs in cases:
        yield SampleInput(make_arg(input_shape), kwargs=kwargs)


def complex_conv(fn, input_size, weight, grad_output, stride, padding, dilation, groups):
    # conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
    # a = conv(Wr, xr, br),
    # b = conv(Wi, xi, 0),
    # c = conv(Wr + Wi, xr + xi, br + bi)
    # conv(W, x, b) = a - b + i(c - a - b)

    grad_output_ = torch.view_as_real(grad_output)
    grad_output_r = grad_output_[..., 0]
    grad_output_i = grad_output_[..., 1]

    weight_ = torch.view_as_real(weight)
    weight_r = weight_[..., 0]
    weight_i = weight_[..., 1]

    a = fn(input_size, weight_r, grad_output_r, stride, padding, dilation, groups)
    b = fn(input_size, weight_i, grad_output_i, stride, padding, dilation, groups)
    c = fn(input_size, weight_r + weight_i, grad_output_r + grad_output_i, stride, padding, dilation, groups)

    return (a - b) + 1j * (c - a - b)


def conv_transpose_ref(input, weight, bias, stride=1, padding=0,
                       output_padding=0, dilation=1, groups=1,
                       fn=None):
    # Derivative of `conv` is `conv_transpose`.
    # To verify the correctness of `conv_transpose`,
    # we rely on the `torch.nn.grad` implementation (which is tested in test_nn.py)
    # for floating dtypes.

    assert fn is not None

    grad_fn_map = {torch.nn.functional.conv_transpose1d: torch.nn.grad.conv1d_input,
                   torch.nn.functional.conv_transpose2d: torch.nn.grad.conv2d_input,
                   torch.nn.functional.conv_transpose3d: torch.nn.grad.conv3d_input}
    batched_dim_map = {torch.nn.functional.conv_transpose1d: 3,
                       torch.nn.functional.conv_transpose2d: 4,
                       torch.nn.functional.conv_transpose3d: 5}

    # Input for `ref` is ndarray.
    input, weight = torch.from_numpy(input), torch.from_numpy(weight)

    is_batched = len(input.shape) == batched_dim_map[fn]
    if not is_batched:
        input = input.unsqueeze(0)

    if bias is not None:
        bias = torch.from_numpy(bias)
        unsqueeze_dims = input.ndim - 2
        for _ in range(unsqueeze_dims):
            bias = bias.unsqueeze(1)

    grad_output = input
    # Get the input shape for grad_fn.
    conv_transpose_output = fn(grad_output.to('meta'), weight.to('meta'), None,
                               stride=stride, padding=padding, output_padding=output_padding,
                               groups=groups, dilation=dilation)
    input_size = conv_transpose_output.shape

    grad_fn = grad_fn_map[fn]
    if weight.dtype.is_complex:
        out = complex_conv(grad_fn, input_size, weight, grad_output, stride, padding, dilation, groups)
    else:  # Floating
        out = grad_fn(input_size, weight, grad_output, stride, padding, dilation, groups)

    if bias is not None:
        out = out + bias

    return out.squeeze(0) if not is_batched else out


def sample_inputs_conv_transpose1d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias
    # and a dict of values of (stride, padding, output_padding, groups, dilation)
    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 3, 4), (3, 3, 3), (3,),
         {'stride': (2,), 'padding': 2, 'output_padding': (1,), 'groups': 1}),
        ((2, 2, 4), (2, 2, 4), (4,),
         {'stride': (3,), 'padding': (1,), 'output_padding': (2,), 'groups': 2, 'dilation': (4,)}),
        ((1, 1, 4), (1, 1, 4), (1,),
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2,)}),
        ((1, 1, 4), (1, 2, 3), None,
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
        ((1, 4, 5), (4, 8, 3), None,
         {})
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)


def sample_inputs_conv_transpose2d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias
    # and a dict of values of (stride, padding, output_padding, groups, dilation)
    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
         {'stride': (2, 2), 'padding': 2, 'output_padding': (1, 1), 'groups': 1}),
        ((2, 2, 4, 4), (2, 2, 4, 5), (4,),
         {'stride': (3, 2), 'padding': (1, 2), 'output_padding': (2, 3), 'groups': 2, 'dilation': (4, 4)}),
        ((1, 1, 4, 5), (1, 1, 4, 3), (1,),
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3)}),
        ((1, 1, 4, 3), (1, 2, 3, 4), None,
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
        ((2, 4, 4, 4), (4, 1, 3, 3), None, {'groups': 4}),
        ((1, 2, 5, 5), (2, 4, 3, 3), None, {})
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)

def sample_inputs_conv_transpose3d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias
    # and a dict of values of (stride, padding, output_padding, groups, dilation)
    cases: Tuple[Tuple[int], Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 3, 4, 4, 4), (3, 3, 3, 3, 3), (3,),
         {'stride': (2, 2, 2), 'padding': 2, 'output_padding': (1, 1, 1), 'groups': 1}),
        ((2, 2, 4, 4, 4), (2, 2, 4, 5, 6), (4,),
         {'stride': (3, 2, 1), 'padding': (1, 2, 3), 'output_padding': (2, 3, 1), 'groups': 2, 'dilation': (4, 4, 4)}),
        ((1, 1, 4, 5, 2), (1, 1, 4, 3, 1), (1,),
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1, 'dilation': (2, 3, 2)}),
        ((1, 1, 4, 3, 4), (1, 2, 3, 4, 5), None,
         {'stride': 2, 'padding': 1, 'output_padding': 1, 'groups': 1}),
        ((1, 4, 5, 5, 5), (4, 8, 3, 3, 3), None,
         {})
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)


def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias,
    # and a dict of values of (stride, padding, dilation, groups)
    cases: Tuple = (
        ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
        ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
        ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
        ((2, 2, 4), (2, 1, 4), (2,), {'stride': (1,), 'padding': 'same', 'groups': 2, 'dilation': (2,)}),
        # With defaults
        ((1, 4, 5), (3, 4, 3), None, {}),
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)


def error_inputs_conv1d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device,
                       dtype=torch.float64)
    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_int_arg((1, 1, 4)), args=(make_int_arg((1, 1, 2)), make_arg((1,)))),
        error_regex="should be the same")

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_complex_arg((1,)))),
        error_regex="should be the same")

    # error inputs for negative strides
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")

    # error inputs for negative padding
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2, 2)), make_arg((1,))),
                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")

    # error inputs for negative dilation
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 2)), make_arg((1,))),
                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")

    # FIXME: https://github.com/pytorch/pytorch/issues/85656
    # error inputs for bias shape not equal to the output channels
    # yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 1, 3)), make_arg((2,)))),
    #                  error_regex="expected bias to be 1-dimensional with 1 elements")

    # error inputs for input.ndim != weight.ndim
    yield ErrorInput(SampleInput(make_arg((1, 1, 4)), args=(make_arg((1, 2)), make_arg((1,)))),
                     error_regex="weight should have at least three dimensions")

    # error inputs when weight[0] is less than the number of groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs when weight[0] is less than the number of groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs for invalid groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")

    # error inputs for invalid groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4)), args=(make_arg((2, 2, 2)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")


def error_inputs_conv2d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_int_arg((2, 4, 4)), args=(make_int_arg((3, 2, 3, 3)), make_arg((3,)))),
        error_regex="should be the same")

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_arg((2, 4, 4)), args=(make_arg((3, 2, 3, 3)), make_complex_arg((3,)))),
        error_regex="should be the same")

    # error inputs for negative strides
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 2, 2, 3)), make_arg((1,))),
                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")

    # error inputs for negative padding
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2, 4)), make_arg((1,))),
                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")

    # error inputs for negative dilation
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 2)), args=(make_arg((1, 1, 2, 5)), make_arg((1,))),
                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")

    # FIXME: https://github.com/pytorch/pytorch/issues/85656
    # error inputs for bias shape not equal to the output channels
    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4)), args=(make_arg((1, 1, 3, 2)), make_arg((2,)))),
    #                  error_regex="expected bias to be 1-dimensional with 1 elements")

    # error inputs for input.ndim != weight.ndim
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 3)), args=(make_arg((1, 2, 2)), make_arg((1,))),
                    kwargs={'padding': 'same'}), error_regex="Expected 3-dimensional input for 3-dimensional weight")

    # error inputs when weight[0] is less than the number of groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
                    kwargs={'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs when weight[0] is less than the number of groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 1, 3)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': 3}), error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs for invalid groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4, 5)), args=(make_arg((2, 2, 1, 4)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': -1}), error_regex="non-positive groups is not supported")

    # error inputs for invalid groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 4, 3)), args=(make_arg((2, 2, 4, 3)), make_arg((2,))),
                    kwargs={'padding': 'same', 'groups': 0}), error_regex="non-positive groups is not supported")


def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=False, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias
    # and a dict of values of (stride, padding, groups, dilation)
    cases: Tuple = (
        ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
         {'stride': (2, 2), 'padding': 2, 'groups': 1}),
        ((2, 4, 8, 8), (2, 2, 3, 3), (2,),
         {'stride': (3, 2), 'padding': (2, 1), 'groups': 2, 'dilation': (4, 4)}),
        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
         {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
         {'stride': 2, 'padding': 1, 'groups': 1, 'dilation': (2, 3)}),
        ((1, 2, 4, 3), (4, 2, 3, 4), None,
         {'stride': 2, 'padding': 1, 'groups': 1}),
        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
         {'stride': 2, 'padding': "valid"}),
        ((1, 4, 5, 5), (1, 4, 2, 3), (1,),
         {'stride': 1, 'padding': "same",
          'dilation': 3}),
        # Below are the group related samples from common_nn.py
        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4}),
        ((2, 4, 6, 6), (8, 1, 3, 3), (8,), {'groups': 4}),
        ((2, 4, 6, 6), (8, 1, 3, 3), None, {'groups': 4}),
        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'stride': (3, 2)}),
        ((2, 4, 6, 6), (4, 1, 3, 3), (4,), {'groups': 4, 'padding': (1, 1)}),
        ((2, 4, 5, 5), (4, 1, 2, 2), (4,), {'groups': 4, 'dilation': (2, 2)}),
        ((2, 4, 6, 5), (6, 2, 3, 2), (6,), {'groups': 2}),
        # With defaults
        ((1, 4, 5, 5), (3, 4, 3, 3), None, {}),
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)


def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as shapes for input, weight, bias
    # and dict of values of (stride, padding, dilation, groups)
    cases: Tuple = (
        ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}),
        ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}),
        ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}),
        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same'}),
        ((1, 1, 10, 11, 12), (1, 1, 1, 2, 5), None, {'padding': 'same', 'dilation': 2}),
        ((1, 1, 10, 11, 12), (1, 1, 4, 4, 4), None, {'padding': 'same', 'dilation': 3}),
        ((1, 1, 1, 1, 10), (1, 1, 1, 1, 4), None, {'padding': 'valid'}),
        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'groups': 3}),
        ((3, 9, 3, 1, 9), (3, 3, 3, 1, 9), (3,), {'stride': (2, 2, 2), 'dilation': 1, 'groups': 3}),
    )

    for input_shape, weight, bias, kwargs in cases:
        # Batched
        yield SampleInput(make_arg(input_shape), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)
        # Unbatched
        yield SampleInput(make_arg(input_shape[1:]), args=(
            make_arg(weight),
            make_arg(bias) if bias is not None else bias
        ), kwargs=kwargs)


def error_inputs_conv3d(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float64)
    make_int_arg = partial(make_tensor, device=device, dtype=torch.int64)
    make_complex_arg = partial(make_tensor, device=device, dtype=torch.complex128)

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_int_arg((1, 1, 4, 4, 4)), args=(make_int_arg((1, 1, 2, 2, 2)), make_arg((1,)))),
        error_regex="should be the same")

    # error inputs for different dtypes of input tensor and bias
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_complex_arg((1,)))),
        error_regex="should be the same")

    # error inputs for negative strides
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
                    kwargs={'stride': (-1,)}), error_regex="non-positive stride is not supported")

    # error inputs for negative padding
    yield ErrorInput(
SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
                    kwargs={'padding': (-1,)}), error_regex="negative padding is not supported")

    # error inputs for negative dilation
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 2, 2, 2)), make_arg((1,))),
                    kwargs={'dilation': (-1,)}), error_regex="dilation should be greater than zero")

    # FIXME: https://github.com/pytorch/pytorch/issues/85656
    # error inputs for bias shape not equal to the output channels
    # yield ErrorInput(SampleInput(make_arg((1, 1, 4, 4, 4)), args=(make_arg((1, 1, 3, 3, 3)), make_arg((2,)))),
    #                  error_regex="expected bias to be 1-dimensional with 1 elements")

    # error inputs for input.ndim != weight.ndim
    yield ErrorInput(
        SampleInput(make_arg((1, 1, 3, 4, 5)), args=(make_arg((1, 1, 4, 3)), make_arg((1,))),
                    kwargs={'padding': 'same'}), error_regex="Expected 4-dimensional input for 4-dimensional weight")

    # error inputs where weight.shape[0] is less than the number of groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
                    make_arg((2,))), kwargs={'groups': 3}),
        error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs where weight.shape[0] is less than the number of groups, with padding='same'
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 3}),
        error_regex="expected weight to be at least 3 at dimension 0")

    # error inputs for invalid groups
    yield ErrorInput(
        SampleInput(make_arg((2, 2, 3, 4, 5)), args=(make_arg((2, 2, 4, 3, 3)),
                    make_arg((2,))), kwargs={'padding': 'same', 'groups': 0}),
        error_regex="non-positive groups is not supported")

    # error inputs for padding='same' not supported by strided convolutions
    yield ErrorInput(
        SampleInput(make_arg((18, 27, 9, 1, 9)), args=(make_arg((9, 9, 9, 1, 9)),
                    make_arg((9,))), kwargs={'stride': 2, 'padding': 'same', 'groups': 3}),
        error_regex="padding='same' is not supported for strided convolutions")


def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input shape, num groups, and a kwarg dict for eps
    cases: Tuple[Tuple[int], int, dict] = (  # type: ignore[assignment]
        ((1, 6, 3), 2, {'eps' : 0.5}),
        ((2, 6, 3), 2, {'eps' : -0.5}),
        ((1, 3), 1, {'eps' : 1e-5}),
        ((0, 2), 1, {'eps' : 1e-5}),
        ((S, S, S), 1, {'eps' : 0.5}),
    )

    # num_channels is inferred to be input.shape[1] dimension
    for input_shape, num_groups, kwargs in cases:
        # Shape of weight and bias should be the same as num_channels
        channels = input_shape[1] if len(input_shape) > 1 else 0
        weight_tensor = make_arg(channels)
        bias_tensor = make_arg(channels)

        # Checking for permutations of weights and biases as `None`
        weights = [weight_tensor, None]
        biases = [bias_tensor, None]
        for weight, bias in itertools.product(weights, biases):
            # Build a fresh dict each iteration so the previous iteration's
            # weight/bias are not re-applied through the **kwargs spread.
            sample_kwargs = {
                **kwargs,
                'weight': weight,
                'bias': bias,
            }
            yield SampleInput(make_arg(input_shape), num_groups, **sample_kwargs)

    # Without any optional args
    yield SampleInput(make_arg((1, 2)), args=(1,))

def reference_inputs_group_norm(op_info, device,
dtype, requires_grad, **kwargs): 4318 yield from sample_inputs_group_norm( 4319 op_info, device, dtype, requires_grad, **kwargs) 4320 4321 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4322 4323 # Ordered as input shape, num groups, and kwargs for eps 4324 cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment] 4325 ((20, 6, 10, 10), 3, {'eps' : 1e-5}), 4326 # equivalent with InstanceNorm 4327 # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) 4328 ((20, 6, 10, 10), 6, {'eps' : 1e-5}), 4329 # equivalent with LayerNorm 4330 # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) 4331 ((20, 6, 10, 10), 1, {'eps' : 1e-5}), 4332 ) 4333 4334 # num_channels is inferred to be input.shape[1] dimension 4335 for input_shape, num_groups, kwargs in cases: 4336 # Shape of weight and bias should be the same as num_channels 4337 channels = input_shape[1] if len(input_shape) > 1 else 0 4338 input_tensor = make_arg(input_shape) 4339 weight_tensor = make_arg(channels) 4340 bias_tensor = make_arg(channels) 4341 4342 # Checking for permutations of weights and biases as `None` 4343 weights = [weight_tensor, None] 4344 biases = [bias_tensor, None] 4345 for weight, bias in itertools.product(weights, biases): 4346 kwargs = { 4347 'weight': weight, 4348 'bias': bias, 4349 **kwargs 4350 } 4351 yield SampleInput(input_tensor, num_groups, **kwargs) 4352 4353 4354def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): 4355 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4356 make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 4357 4358 # Ordered as: input shape, kwargs for momentum, eps 4359 cases: Tuple[Tuple[int], dict] = ( # type: ignore[assignment] 4360 ((S, S, S), {'momentum': 0.5, 'eps': 0.6}), 4361 ((S, S, S), {'momentum': 0.5, 'eps': 0.6, 'use_input_stats': True}), 4362 ((3, 2, 4), {'momentum': -1.2}), 4363 ((3, 2, 4), {'momentum': 0.0}), 4364 ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), 4365 ((3, 2, 3, 4), {'momentum': -1.0, 'eps': 0.5}), 4366 ) 4367 4368 for input_shape, kwargs in cases: 4369 # args: running mean, running var, weight and bias should necessarily be of shape: (channels,) 4370 channels = input_shape[1] 4371 weight = make_arg(channels) 4372 bias = make_arg(channels) 4373 running_mean = make_arg_without_requires_grad(channels, low=0) 4374 running_var = make_arg_without_requires_grad(channels, low=0) 4375 new_kwargs = { 4376 'running_mean': running_mean, 4377 'running_var': running_var, 4378 'weight': weight, 4379 'bias': bias, 4380 **kwargs 4381 } 4382 4383 yield SampleInput( 4384 make_arg(input_shape), 4385 args=(), 4386 kwargs=new_kwargs 4387 ) 4388 4389 # Checking for permutations of weights and biases as `None` 4390 # instance_norm assumes that if there's a bias, there's a weight 4391 weights = [channels, None] 4392 biases = [None, None] 4393 4394 for weight_channels, bias_channels in zip(weights, biases): 4395 running_mean = make_arg_without_requires_grad(channels, low=0) 4396 running_var = make_arg_without_requires_grad(channels, low=0) 4397 yield SampleInput( 4398 make_arg(input_shape), 4399 args=(), 4400 kwargs={ 4401 'running_mean': running_mean, 4402 'running_var': running_var, 4403 'weight': make_arg(weight_channels) if weight_channels is not None else None, 4404 'bias': make_arg(bias_channels) if bias_channels is not None else None 4405 } 4406 ) 4407 
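    # Note: when no optional kwargs are passed at all, F.instance_norm computes the
    # statistics from the input itself (use_input_stats defaults to True) and applies
    # no affine transform, so the sample below exercises the default code path.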
4408 # Test case for no optional kwargs 4409 yield SampleInput(make_arg((1, 2, 3)), kwargs={}) 4410 4411def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs): 4412 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) 4413 4414 def make_bool_mask(*shape): 4415 return torch.randint(0, 2, shape, device=device, dtype=torch.bool) 4416 4417 def mask_two_rows(rows, cols): 4418 mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device) 4419 mask_two_rows[rows - 1] = False 4420 mask_two_rows[rows - 3] = False 4421 return mask_two_rows 4422 4423 def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor: 4424 return torch.where(~mask, float('-inf'), 0.0) 4425 4426 def with_requires_grad(tensor): 4427 return tensor.requires_grad_(requires_grad) 4428 4429 def generate_input_from_mask(mask_shape, dim): 4430 mask = make_bool_mask(*mask_shape) 4431 input_tensor = make_arg(mask_shape) 4432 masked_input = input_tensor + convert_to_float_mask(mask) 4433 return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim}) 4434 4435 samples = [ 4436 # Basic 3D tensor with mask 4437 generate_input_from_mask((2, 3, 4), dim=1), 4438 # 2D tensor with mask, testing different dim 4439 generate_input_from_mask((5, 5), dim=0), 4440 # 4D tensor, testing with a different dim 4441 generate_input_from_mask((2, 3, 4, 5), dim=2), 4442 # Edge case: 1D tensor 4443 generate_input_from_mask((10,), dim=0), 4444 # Edge case: tensor with one dimension of size 1 4445 generate_input_from_mask((1, 5, 5), dim=1), 4446 # Testing with all elements masked 4447 SampleInput( 4448 with_requires_grad( 4449 make_arg((3, 3)) 4450 + convert_to_float_mask( 4451 torch.zeros((3, 3), dtype=torch.bool, device=device) 4452 ) 4453 ), 4454 kwargs={"dim": 1}, 4455 ), 4456 # Testing with no elements masked 4457 SampleInput( 4458 with_requires_grad( 4459 make_arg((3, 3)) 4460 + convert_to_float_mask( 4461 torch.ones((3, 3), dtype=torch.bool, device=device) 4462 ) 4463 ), 4464 kwargs={"dim": 1}, 4465 ), 4466 # Testing with two rows masked 4467 SampleInput( 4468 with_requires_grad( 4469 make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3)) 4470 ), 4471 kwargs={"dim": 1}, 4472 ), 4473 ] 4474 yield from samples 4475 4476def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): 4477 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4478 4479 # Ordered as input shape, normalized_shape and a kwarg dict for eps 4480 cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment] 4481 ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), 4482 ((2, 2, 3), (2, 3), {'eps': -0.5}), 4483 ((1,), (1,), {}), 4484 ((1, 2), (2,), {}), 4485 ((0, 1), (1,), {}), 4486 ) 4487 4488 for input_shape, normalized_shape, kwargs in cases: 4489 # Shape of weight and bias should be the same as normalized_shape 4490 weight = make_arg(normalized_shape) 4491 bias = make_arg(normalized_shape) 4492 yield SampleInput( 4493 make_arg(input_shape), 4494 args=(normalized_shape, weight, bias), 4495 kwargs=kwargs 4496 ) 4497 # Without any optional args 4498 yield SampleInput(make_arg((1, 2)), args=((2,),)) 4499 4500 # TODO: @krshrimali, once to_numpy method in SampleInput class is modified to take None inputs, 4501 # enable these inputs; see https://github.com/pytorch/pytorch/pull/63276#discussion_r691950400 4502 4503 # With weight and a `None` bias 4504 # yield SampleInput(make_arg((1, 2)), args=((2,), make_arg((2,)), None)) 4505 4506 # With `None` weight and 
bias (tests failing for this, see the link above) 4507 # yield SampleInput(make_arg((1, 2)), args=((2,), None, make_arg((2,)))) 4508 4509 4510def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs): 4511 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4512 4513 # Ordered as input shape, normalized_shape, eps 4514 cases: Tuple[Tuple[int], Tuple[int], float] = ( # type: ignore[assignment] 4515 ((1, 2, 3), (1, 2, 3), 0.5), 4516 ((2, 2, 3), (2, 3), -0.5), 4517 ((1,), (1,), 1e-5), 4518 ((1, 2), (2,), 1e-5), 4519 ((0, 1), (1,), 1e-5), 4520 ) 4521 4522 for input_shape, normalized_shape, eps in cases: 4523 # Shape of weight and bias should be the same as normalized_shape 4524 weight = make_arg(normalized_shape) 4525 bias = make_arg(normalized_shape) 4526 yield SampleInput( 4527 make_arg(input_shape), 4528 args=(normalized_shape, weight, bias, eps), 4529 ) 4530 yield SampleInput( 4531 make_arg(input_shape), 4532 args=(normalized_shape, None, bias, eps), 4533 ) 4534 yield SampleInput( 4535 make_arg(input_shape), 4536 args=(normalized_shape, weight, None, eps), 4537 ) 4538 yield SampleInput( 4539 make_arg(input_shape), 4540 args=(normalized_shape, None, None, eps), 4541 ) 4542 4543def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): 4544 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4545 4546 # Ordered as input shape, normalized_shape and a kwarg dict for eps 4547 cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment] 4548 ((1, 2, 3), (1, 2, 3), {'eps': 0.5}), 4549 ((2, 2, 3), (2, 3), {'eps': -0.5}), 4550 ((1,), (1,), {}), 4551 ((1, 2), (2,), {}), 4552 ((0, 1), (1,), {}), 4553 ) 4554 4555 for input_shape, normalized_shape, kwargs in cases: 4556 # Shape of weight and bias should be the same as normalized_shape 4557 weight = make_arg(normalized_shape) 4558 yield SampleInput( 4559 make_arg(input_shape), 4560 args=(normalized_shape, weight), 4561 kwargs=kwargs 4562 ) 4563 # Without any optional args 4564 yield SampleInput(make_arg((1, 2)), args=((2,),)) 4565 4566def error_inputs_group_norm(opinfo, device, **kwargs): 4567 make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) 4568 4569 # check that input has minimum number of dimensions 4570 err_msg1 = "Expected at least 2 dimensions for input tensor but received" 4571 s1 = SampleInput(make_arg(1), args=(1,)) 4572 yield ErrorInput(s1, error_regex=err_msg1) 4573 4574 # check that the channels dimension is compatible with number of groups 4575 err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" 4576 s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) 4577 yield ErrorInput(s2, error_regex=err_msg2) 4578 4579def error_inputs_native_layer_norm(opinfo, device, **kwargs): 4580 make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) 4581 input_shape = (1, 2, 3) 4582 4583 err_msg1 = "Expected normalized_shape to be at least 1-dimensional" 4584 s1 = SampleInput( 4585 make_arg(input_shape), args=((), None, None, 1e-5) 4586 ) 4587 yield ErrorInput(s1, error_regex=err_msg1) 4588 4589 normalized_shape = (1, 2, 3) 4590 weight = make_arg((1, 2)) 4591 err_msg2 = "Expected weight to be of same shape as normalized_shape" 4592 s2 = SampleInput( 4593 make_arg(input_shape), args=(normalized_shape, weight, None, 1e-5) 4594 ) 4595 yield ErrorInput(s2, error_regex=err_msg2) 4596 4597 bias = make_arg((1, 2)) 4598 
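    # bias deliberately has shape (1, 2), which does not match normalized_shape (1, 2, 3)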
err_msg3 = "Expected bias to be of same shape as normalized_shape"
    s3 = SampleInput(
        make_arg(input_shape), args=(normalized_shape, None, bias, 1e-5)
    )
    yield ErrorInput(s3, error_regex=err_msg3)

    err_msg4 = "Given normalized_shape="
    s4 = SampleInput(
        make_arg((2, 2, 3)), args=((2, 2), None, None, 1e-5)
    )
    yield ErrorInput(s4, error_regex=err_msg4)

def error_inputs_rms_norm(opinfo, device, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False)
    input_shape = (1, 2, 3)

    err_msg1 = "Expected normalized_shape to be at least 1-dimensional"
    s1 = SampleInput(
        make_arg(input_shape), args=((), None, 1e-5)
    )
    yield ErrorInput(s1, error_regex=err_msg1)

    normalized_shape = (1, 2, 3)
    weight = make_arg((1, 2))
    err_msg2 = "Expected weight to be of same shape as normalized_shape"
    s2 = SampleInput(
        make_arg(input_shape), args=(normalized_shape, weight, 1e-5)
    )
    yield ErrorInput(s2, error_regex=err_msg2)

    err_msg4 = "Given normalized_shape="
    s4 = SampleInput(
        make_arg((2, 2, 3)), args=((2, 2), None, 1e-5)
    )
    yield ErrorInput(s4, error_regex=err_msg4)


def sample_inputs_local_response_norm(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Ordered as input shape, size and a kwarg dict for alpha, beta, and k
    cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
        ((1, 6, 3), 2, {'beta': 0.5, 'k': 1.25}),
        ((1, 6, 3), 2, {'alpha': 3e-05, 'k': 1.25}),
        ((1, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5}),
        ((1, 6, 3), 2, {'alpha': 3e-05}),
        ((1, 6, 3), 2, {'beta': 0.5}),
        ((1, 6, 3), 2, {'k': 1.25}),
        ((1, 6, 3), 2, {}),
        ((2, 6, 3), 2, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
        ((1, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
        ((0, 1, 2), 1, {'alpha': 3e-05, 'beta': 0.5, 'k': 1.25}),
    )

    for input_shape, size, kwargs in cases:
        yield SampleInput(make_arg(input_shape), args=(size,), kwargs=kwargs)

def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
    N = 5
    # Make sure the interesting -3 -> 3 range is well covered; the default
    # sampling range of roughly -10 -> 10 would mostly hit the saturated regions.
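    # For reference: hardswish(x) = x * relu6(x + 3) / 6, i.e. it is 0 for
    # x <= -3, equals x for x >= 3, and follows x * (x + 3) / 6 in between,
    # so sampling in [-5, 5] covers the curved region plus both saturated tails.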
4660 make_arg = partial(make_tensor, device=device, dtype=dtype, 4661 requires_grad=requires_grad, low=-5, high=5) 4662 return (SampleInput(make_arg((N * 2, N * 2))) for _ in range(1, N)) 4663 4664def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs): 4665 features_options = [[3, 4], [8, 8]] 4666 batch_options: List[List[int]] = [ 4667 [], # no batch 4668 [0], 4669 [8], 4670 [2, 3], 4671 ] 4672 create_tensor = partial(make_tensor, device=device, dtype=dtype, 4673 requires_grad=requires_grad, low=-2, high=2) 4674 4675 for has_bias, (in_feat, out_feat), batch_shape in \ 4676 itertools.product([True, False], features_options, batch_options): 4677 input_tensor = create_tensor(batch_shape + [in_feat]) 4678 weight = create_tensor([out_feat, in_feat]) 4679 if not has_bias: 4680 yield SampleInput(input_tensor, weight) 4681 continue 4682 4683 bias = create_tensor([out_feat]) 4684 yield SampleInput(input_tensor, weight, bias) 4685 4686 # 5D tensor, used to crash on MPS, see https://github.com/pytorch/pytorch/issues/114942 4687 yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2)) 4688 yield SampleInput(create_tensor(2, 1, 2, 1, 2), create_tensor(4, 2), create_tensor(4)) 4689 4690def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs): 4691 features_options = [[3, 4, 5], [8, 8, 8]] 4692 batch_options: List[List[int]] = [ 4693 [], # no batch 4694 [0], 4695 [8], 4696 [2, 3], 4697 ] 4698 create_tensor = partial(make_tensor, device=device, dtype=dtype, 4699 requires_grad=requires_grad, low=-2, high=2) 4700 4701 for has_bias, (in_feat1, in_feat2, out_feat), batch_shape in \ 4702 itertools.product([True, False], features_options, batch_options): 4703 input_tensor1 = create_tensor(batch_shape + [in_feat1]) 4704 input_tensor2 = create_tensor(batch_shape + [in_feat2]) 4705 weight = create_tensor([out_feat, in_feat1, in_feat2]) 4706 if not has_bias: 4707 yield SampleInput(input_tensor1, input_tensor2, weight) 4708 continue 4709 bias = create_tensor([out_feat]) 4710 yield SampleInput(input_tensor1, input_tensor2, weight, bias) 4711 4712def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs): 4713 features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]] 4714 batch_options: List[List[int]] = [ 4715 [], # no batch 4716 [0], 4717 [8], 4718 [2, 3], 4719 ] 4720 create_tensor = partial(make_tensor, device=device, dtype=dtype, 4721 requires_grad=requires_grad, low=-2, high=2) 4722 4723 for features, batch_shape in itertools.product(features_options, batch_options): 4724 ndim = len(features) + len(batch_shape) 4725 for dim in range(ndim): 4726 input_tensor = create_tensor(batch_shape + features) 4727 dim_size = input_tensor.size(dim) 4728 if dim_size > 0 and dim_size % 2 == 0: 4729 yield SampleInput(input_tensor, dim) 4730 4731def sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): 4732 N, C = 2, 3 4733 D = 4 4734 S = 3 4735 L = 5 4736 4737 align_corners_options: Tuple[Any, ...] 
= (None,) 4738 if mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): 4739 align_corners_options = (True, False, None) 4740 ranks_for_mode = { 4741 'nearest': [1, 2, 3], 4742 'nearest-exact': [1, 2, 3], 4743 'linear': [1], 4744 'bilinear': [2], 4745 'bicubic': [2], 4746 'trilinear': [3], 4747 'area': [1, 2, 3] 4748 } 4749 4750 def shape(size, rank, with_batch_channel=True): 4751 if with_batch_channel: 4752 return tuple([N, C] + ([size] * rank)) 4753 return tuple([size] * rank) 4754 4755 if mode in ('bilinear', 'bicubic') and dtype == torch.uint8: 4756 make_arg = partial( 4757 make_tensor, 4758 device=device, 4759 dtype=dtype, 4760 requires_grad=requires_grad, 4761 # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype 4762 high=256 if dtype == torch.uint8 else None, 4763 ) 4764 # provide few samples for a more close to typical image processing usage 4765 rank = 2 4766 for memory_format in [torch.contiguous_format, torch.channels_last]: 4767 yield SampleInput( 4768 make_arg(shape(270, rank), memory_format=memory_format), 4769 shape(130, rank, False), 4770 scale_factor=None, 4771 mode=mode, 4772 align_corners=False, 4773 ) 4774 4775 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4776 4777 for align_corners in align_corners_options: 4778 for rank in ranks_for_mode[mode]: 4779 yield SampleInput( 4780 make_arg(shape(D, rank)), 4781 shape(S, rank, False), 4782 scale_factor=None, 4783 mode=mode, 4784 align_corners=align_corners, 4785 ) 4786 yield SampleInput( 4787 make_arg(shape(D, rank)), 4788 shape(L, rank, False), 4789 scale_factor=None, 4790 mode=mode, 4791 align_corners=align_corners, 4792 ) 4793 for recompute_scale_factor in [False, True]: 4794 for scale_factor in [1.7, 0.6]: 4795 yield SampleInput( 4796 make_arg(shape(D, rank)), 4797 size=None, 4798 scale_factor=scale_factor, 4799 mode=mode, 4800 align_corners=align_corners, 4801 recompute_scale_factor=recompute_scale_factor, 4802 ) 4803 4804def reference_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs): 4805 yield from sample_inputs_interpolate(mode, self, device, dtype, requires_grad, **kwargs) 4806 4807 if mode in ('bilinear', 'bicubic'): 4808 make_arg = partial( 4809 make_tensor, 4810 device=device, 4811 dtype=dtype, 4812 requires_grad=requires_grad, 4813 # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype 4814 high=256 if dtype == torch.uint8 else None, 4815 ) 4816 # provide few samples for more typical image processing usage 4817 for memory_format in [torch.contiguous_format, torch.channels_last]: 4818 for aa in [True, False]: 4819 yield SampleInput( 4820 make_arg((2, 3, 345, 456), memory_format=memory_format), 4821 (270, 270), 4822 scale_factor=None, 4823 mode=mode, 4824 align_corners=False, 4825 antialias=aa, 4826 ) 4827 4828def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): 4829 N, C = 2, 3 4830 D = 4 4831 S = 3 4832 L = 5 4833 4834 ranks_for_mode = { 4835 'nearest': [1, 2, 3], 4836 'bilinear': [2], 4837 } 4838 4839 def shape(size, rank, with_batch_channel=True): 4840 if with_batch_channel: 4841 return torch.Size([N, C] + ([size] * rank)) 4842 return torch.Size([size] * rank) 4843 4844 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4845 4846 for rank in ranks_for_mode[mode]: 4847 yield SampleInput(make_arg(shape(D, rank)), size=shape(S, rank, False)) 4848 yield SampleInput(make_arg(shape(D, rank)), size=shape(L, rank, False)) 4849 yield 
SampleInput(make_arg(shape(D, rank)), scale_factor=1.7) 4850 yield SampleInput(make_arg(shape(D, rank)), scale_factor=0.6) 4851 4852def reference_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): 4853 yield from sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs) 4854 4855 if mode in ('bilinear', ): 4856 make_arg = partial( 4857 make_tensor, 4858 device=device, 4859 dtype=dtype, 4860 requires_grad=requires_grad, 4861 # we pick more realistic upper bound 256 instead of default 10 for uint8 dtype 4862 high=256 if dtype == torch.uint8 else None, 4863 ) 4864 # provide a single sample for more typical image processing usage 4865 for memory_format in [torch.contiguous_format, torch.channels_last]: 4866 yield SampleInput( 4867 make_arg((2, 3, 345, 456), memory_format=memory_format), 4868 (270, 270), 4869 ) 4870 4871def sample_inputs_upsample_aa(mode, self, device, dtype, requires_grad, **kwargs): 4872 N = 6 4873 C = 3 4874 H = 10 4875 W = 20 4876 S = 3 4877 L = 5 4878 4879 input_tensor = make_tensor(torch.Size([N, C, H, W]), device=device, dtype=dtype, requires_grad=requires_grad) 4880 4881 yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scale_factors=None) 4882 yield SampleInput(input_tensor, output_size=torch.Size([L, L]), align_corners=False, scale_factors=None) 4883 yield SampleInput(input_tensor, output_size=None, align_corners=False, scale_factors=[1.7, 0.9]) 4884 yield SampleInput(input_tensor, output_size=None, align_corners=True, scale_factors=[0.8, 1.0]) 4885 4886 yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=None, scales_w=None) 4887 yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=False, scales_h=1.7, scales_w=0.9) 4888 yield SampleInput(input_tensor, output_size=torch.Size([S, S]), align_corners=True, scales_h=1.7, scales_w=0.9) 4889 4890def sample_inputs_gelu(self, device, dtype, requires_grad, **kwargs): 4891 N = 5 4892 for _ in range(1, N): 4893 for approximate in ['none', 'tanh']: 4894 yield SampleInput( 4895 make_tensor((N * 2, N * 2), device=device, dtype=dtype, 4896 requires_grad=requires_grad, low=-3, high=3), 4897 approximate=approximate) 4898 4899 4900def error_inputs_gelu(op, device, **kwargs): 4901 # Tests that gelu errors out when passed an approximation we don't know. 
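    # F.gelu only accepts approximate='none' (the exact, erf-based formulation) or
    # approximate='tanh' (the tanh approximation); any other string, such as the
    # "asdf" used below, should be rejected with the error message matched here.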
4902 yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device), kwargs={"approximate": "asdf"}), 4903 error_regex="approximate argument must be either") 4904 4905 4906def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs): 4907 inputs = [] 4908 args_for_reduction_with_dim = ( 4909 ((S, S, S), (1,),), 4910 ((S, S, S), (1, True, ),), 4911 ((), (0,),), 4912 ((), (0, True,),), 4913 ) 4914 return ((SampleInput(make_tensor(input_tensor, dtype=dtype, device=device, 4915 low=None, high=None, 4916 requires_grad=requires_grad), 4917 *args)) 4918 for input_tensor, args in args_for_reduction_with_dim) 4919 4920def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad, **kwargs): 4921 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) 4922 yield SampleInput(make_arg((S, S, S))) 4923 yield SampleInput(make_arg(())) 4924 4925def _generate_nan_reduction_inputs(device, dtype, requires_grad, **kwargs): 4926 yield from _generate_reduction_inputs(device, dtype, requires_grad) 4927 # NaN only exists for floating point numbers 4928 if dtype.is_complex or dtype.is_floating_point: 4929 yield torch.tensor([2, torch.nan, -1], device=device, dtype=dtype, requires_grad=requires_grad) 4930 yield torch.tensor([[torch.nan, 2], [0, 1]], device=device, dtype=dtype, requires_grad=requires_grad) 4931 4932def sample_inputs_nan_reduction(supports_multiple_dims): 4933 # Generates sample inputs for reduction ops that contain the input tensor 4934 # and dim and keepdim kwargs. If a reduction op needs to test additional 4935 # args/kwargs then create a separate sample_inputs function 4936 def fn(op_info, device, dtype, requires_grad, **kwargs): 4937 for t in _generate_nan_reduction_inputs(device, dtype, requires_grad): 4938 # Add case without dim and keepdim kwargs 4939 yield SampleInput(t.clone().requires_grad_(requires_grad)) 4940 for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims): 4941 yield SampleInput(t.clone().requires_grad_(requires_grad), **kwargs) 4942 4943 return fn 4944 4945def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad, **kwargs): 4946 test_quantiles = (0.5, make_tensor((2,), dtype=dtype, device=device, low=0, high=1, requires_grad=requires_grad)) 4947 test_interpolations = ['linear', 'midpoint'] 4948 4949 for quantiles in test_quantiles: 4950 for t in _generate_reduction_inputs(device, dtype, requires_grad): 4951 # Add case without dim and keepdim kwargs 4952 input = t.clone().requires_grad_(requires_grad) 4953 yield SampleInput(input, quantiles) 4954 for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False): 4955 # Interpolation kwarg for now is only supported when providing both dim and keepdim 4956 kwargs.setdefault('dim', 0) 4957 kwargs.setdefault('keepdim', False) 4958 for interpolation in test_interpolations: 4959 kwargs['interpolation'] = interpolation 4960 input = t.clone().requires_grad_(requires_grad) 4961 yield SampleInput(input, quantiles, **kwargs) 4962 4963def sample_inputs_reduction_count_nonzero(*args, **kwargs): 4964 """Sample inputs for count_nonzero""" 4965 # count_nonzero does not support keepdim yet 4966 for sample in sample_inputs_reduction(*args, **kwargs): 4967 sample.kwargs.pop('keepdim', None) 4968 yield sample 4969 4970def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad, **kwargs): 4971 N = 10 4972 make_arg = partial(make_tensor, device=device, dtype=dtype, 
requires_grad=requires_grad) 4973 return (SampleInput(make_arg((N, N))) for _ in range(1, N)) 4974 4975def sample_inputs_fractional_max_pool2d(op_info, device, dtype, requires_grad, **kwargs): 4976 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 4977 4978 # Order: input_shape, kernel_size 4979 cases = (((1, 3, 9, 9), 3), 4980 ((1, 3, 9, 9), (4, 4)), 4981 ((1, 3, 9, 9), (6, 6)), 4982 ((2, 3, 9, 9), (3, 3)), 4983 ((1, 1, 4, 4), (2, 2)), 4984 ((1, 2, 6, 6), (4, 4))) 4985 4986 for input_shape, kernel_size in cases: 4987 for return_indices in [False, True]: 4988 # test case passing a single output size 4989 yield SampleInput( 4990 make_arg(input_shape), 4991 kernel_size, 4992 output_size=2, 4993 return_indices=return_indices, 4994 ) 4995 4996 # test case passing a tuple output size 4997 yield SampleInput( 4998 make_arg(input_shape), 4999 kernel_size, 5000 output_size=(2, 3), 5001 return_indices=return_indices, 5002 ) 5003 5004 # test case passing an output ratio 5005 yield SampleInput( 5006 make_arg(input_shape), 5007 kernel_size, 5008 output_ratio=(0.5, 0.5), 5009 return_indices=return_indices, 5010 ) 5011 5012def sample_inputs_fractional_max_pool3d(op_info, device, dtype, requires_grad, **kwargs): 5013 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5014 5015 # Order: input_shape, kernel_size 5016 cases = (((2, 3, 5, 5, 5), (2, 2, 2)), 5017 ((1, 2, 6, 5, 4), 2), 5018 ((1, 2, 5, 6, 5), (2, 3, 2)), 5019 ((1, 2, 6, 6, 6), (2, 3, 2)), 5020 ((1, 1, 7, 6, 7), (2, 3, 4)), 5021 ((1, 1, 4, 5, 4), (2, 2, 1)), 5022 ((1, 1, 8, 7, 6), (4, 3, 2)), 5023 ((0, 1, 4, 5, 4), (2, 2, 1))) 5024 5025 for input_shape, kernel_size in cases: 5026 for return_indices in [False, True]: 5027 # test case passing a single output size 5028 yield SampleInput( 5029 make_arg(input_shape), 5030 kernel_size, 5031 output_size=2, 5032 return_indices=return_indices, 5033 ) 5034 5035 # test case passing a tuple output size 5036 yield SampleInput( 5037 make_arg(input_shape), 5038 kernel_size, 5039 output_size=(2, 3, 2), 5040 return_indices=return_indices, 5041 ) 5042 5043 # test case passing an output ratio 5044 yield SampleInput( 5045 make_arg(input_shape), 5046 kernel_size, 5047 output_ratio=(0.5, 0.5, 0.5), 5048 return_indices=return_indices, 5049 ) 5050 5051def sample_inputs_avgpool2d(op_info, device, dtype, requires_grad, **kwargs): 5052 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5053 5054 # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override 5055 cases = (((1, 3, 9, 9), 3, 1, 1, True, False, 2), 5056 ((1, 3, 9, 9), (4, 4), (2, 3), 1, True, False, 2), 5057 ((1, 3, 9, 9), (6, 6), (3, 3), (2, 3), True, True, 2), 5058 ((2, 3, 9, 9), (3, 3), (1, 1), (1, ), True, False, 2), 5059 ((1, 1, 4, 4), (2, 2), (), (0, ), False, True, -2), 5060 ((1, 2, 6, 6), (4, 4), (2, 2), (2, ), True, True, None)) 5061 5062 for input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override in cases: 5063 yield SampleInput(make_arg(input_shape), 5064 args=(kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)) 5065 # Case with just input_shape and kernel_size 5066 yield SampleInput(make_arg((1, 3, 9, 9)), args=((3, 3))) 5067 5068def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs): 5069 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5070 5071 # Order: input_shape, 
kernel_size, kwargs 5072 cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [ 5073 ((2, 3, 9), (3,), {}), 5074 ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)), 5075 ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)), 5076 ((2, 3, 9), (3,), dict(stride=(1,), padding=(1,), ceil_mode=False, count_include_pad=True)), 5077 ((0, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=False, count_include_pad=True)), 5078 ((1, 2, 9), (7,), dict(stride=(3,), padding=(2,), ceil_mode=False)), 5079 ((1, 2, 9), (7,), dict(stride=(3,), padding=(3,), ceil_mode=True)), 5080 ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=False)), 5081 ((1, 2, 9), (7,), dict(stride=(3,), ceil_mode=True)), 5082 ] 5083 5084 for input_shape, kernel_size, kwargs in cases: 5085 yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) 5086 5087def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs): 5088 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5089 5090 # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override 5091 cases: List[Tuple[Tuple[int, ...], Union[int, Tuple[int, ...]], Dict]] = [ 5092 ((2, 3, 3, 4, 4), (2, 2, 2), {}), 5093 ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True, 5094 count_include_pad=False, divisor_override=2)), 5095 ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=True, 5096 count_include_pad=True, divisor_override=2)), 5097 ((1, 2, 5, 5, 5), (2, 3, 4), dict(stride=(1, 2, 2), padding=(0, 1, 2), ceil_mode=False)), 5098 ((1, 1, 7, 5, 7), (6, 3, 4), dict(stride=(2, 3, 2), padding=(3, 1, 0), ceil_mode=False, 5099 count_include_pad=False, divisor_override=2)), 5100 ((1, 1, 4, 5, 4), (2, 2, 3), dict(stride=(2, 2, 1), padding=0, ceil_mode=False, 5101 count_include_pad=True, divisor_override=-2)), 5102 ((1, 1, 6, 5, 6), (4, 5, 6), dict(stride=(2, 3, 2), padding=2, ceil_mode=True, 5103 count_include_pad=True, divisor_override=None)), 5104 ((0, 1, 4, 5, 4), (2, 3, 1), dict(stride=(2, 1, 2), padding=0, ceil_mode=False, 5105 count_include_pad=True, divisor_override=None)), 5106 ] 5107 5108 for input_shape, kernel_size, kwargs in cases: 5109 yield SampleInput(make_arg(input_shape), args=(kernel_size,), kwargs=kwargs) 5110 5111def error_inputs_avg_pool1d(op_info, device, **kwargs): 5112 # error inputs when pad is negative 5113 x = torch.rand([0, 1, 49], dtype=torch.float32) 5114 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), 5115 error_regex='pad must be non-negative') 5116 5117 # error inputs when pad > kernel_size / 2 5118 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), 5119 error_regex='pad should be at most half of effective kernel size') 5120 5121def error_inputs_avg_pool2d(op_info, device, **kwargs): 5122 # error inputs when pad is negative 5123 x = torch.rand([0, 1, 49], dtype=torch.float32) 5124 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), 5125 error_regex='pad must be non-negative') 5126 # 2-dimensional kernel 5127 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': -1}), 5128 error_regex='pad must be non-negative') 5129 5130 # error inputs when pad > kernel_size / 2 5131 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), 5132 error_regex='pad should 
be at most half of effective kernel size') 5133 # 2-dimensional kernel 5134 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2), 'stride': 50, 'padding': 4}), 5135 error_regex='pad should be at most half of effective kernel size') 5136 5137 # error inputs for zero divisor 5138 x = torch.zeros(3, 3, 3) 5139 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2), 'divisor_override': 0}), 5140 error_regex='divisor must be not zero') 5141 5142def error_inputs_avg_pool3d(op_info, device, **kwargs): 5143 # error inputs when pad is negative 5144 x = torch.rand([0, 1, 49, 50], dtype=torch.float32) 5145 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': -1}), 5146 error_regex='pad must be non-negative') 5147 # 3-dimensional kernel 5148 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': -1}), 5149 error_regex='pad must be non-negative') 5150 5151 # error inputs when pad > kernel_size / 2 5152 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 4}), 5153 error_regex='pad should be at most half of effective kernel size') 5154 # 3-dimensional kernel 5155 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (3, 2, 2), 'stride': 50, 'padding': 4}), 5156 error_regex='pad should be at most half of effective kernel size') 5157 5158 # error inputs for zero divisor 5159 x = torch.zeros(3, 3, 3, 3) 5160 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': (2, 2, 2), 'divisor_override': 0}), 5161 error_regex='divisor must be not zero') 5162 5163 # error inputs for invalid input dimension 5164 x = torch.rand([0, 1, 49], dtype=torch.float32) 5165 yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 50, 'padding': 0}), 5166 error_regex='non-empty 4D or 5D') 5167 5168 5169def sample_inputs_to(op_info, device, dtype, requires_grad, **kwargs): 5170 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5171 # test_multiple_devices_to_cuda would fail if we use a different device than given 5172 devices = [device] 5173 if torch.device(device).type == 'cpu': 5174 devices = [torch.device('cpu'), torch.device('cuda:0')] if torch.cuda.is_available() else devices 5175 memory_formats = [torch.preserve_format, torch.channels_last] 5176 5177 # TODO: can't switch `to.device` overload to use positional arguments 5178 # https://github.com/pytorch/pytorch/issues/84265 5179 # to.device overload 5180 for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): 5181 kwargs = { 5182 "memory_format": mem_f, 5183 } 5184 yield SampleInput(make_arg((S, S, S, S)), args=(device, torch.float64, nb, cp), kwargs=kwargs) 5185 5186 # to.dtype overload 5187 for nb, cp, mem_f in product([True, False], [True, False], memory_formats): 5188 kwargs = { 5189 "memory_format": mem_f, 5190 } 5191 yield SampleInput(make_arg((S, S, S, S)), args=(torch.float64, nb, cp), kwargs=kwargs) 5192 5193 # to.other overload 5194 for device, nb, cp, mem_f in product(devices, [True, False], [True, False], memory_formats): 5195 kwargs = { 5196 "memory_format": mem_f, 5197 } 5198 other = make_arg((S, S, S, S), dtype=torch.float64, device=device) 5199 yield SampleInput(make_arg((S, S, S, S)), args=(other, nb, cp), kwargs=kwargs) 5200 5201 5202def sample_inputs_topk(op_info, device, dtype, requires_grad, **kwargs): 5203 def get_tensor_input(size): 5204 return make_tensor(size, dtype=dtype, device=device, requires_grad=requires_grad) 5205 5206 yield 
SampleInput(get_tensor_input((S, M, S)), 3) 5207 yield SampleInput(get_tensor_input((S, M, S)), 3, 1) 5208 yield SampleInput(get_tensor_input((S, M, S)), 3, -2) 5209 yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True) 5210 yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True) 5211 yield SampleInput(get_tensor_input((S, M, S)), 3, 1, True, True) 5212 yield SampleInput(get_tensor_input((S, M, S)), 3, -2, True, True) 5213 5214 yield SampleInput(get_tensor_input(()), 1) 5215 yield SampleInput(get_tensor_input(()), 1, 0) 5216 yield SampleInput(get_tensor_input(()), 1, -1) 5217 yield SampleInput(get_tensor_input(()), 1, 0, True) 5218 yield SampleInput(get_tensor_input(()), 1, -1, True) 5219 yield SampleInput(get_tensor_input(()), 1, 0, True, True) 5220 yield SampleInput(get_tensor_input(()), 1, -1, True, True) 5221 5222def sample_inputs_outer(op_info, device, dtype, requires_grad, **kwargs): 5223 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5224 yield SampleInput(make_arg(S), make_arg(M)) 5225 5226def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs): 5227 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5228 sizes = ((S, S, S), (S,), (S, 1, S), (), (S, S)) 5229 ps = (2, 4) 5230 5231 for size_x, size_y, p in product(sizes, sizes, ps): 5232 yield SampleInput(make_arg(size_x), args=(make_arg(size_y), p)) 5233 5234# Missing to test the nondeterminism of the operation 5235# https://github.com/pytorch/pytorch/issues/53352 5236def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs): 5237 # target.index_select(dim, idx) 5238 select = "index_select" in op_info.name 5239 # target.index_add(dim, idx, source, *, alpha=1) 5240 add = "index_add" in op_info.name 5241 # target.index_copy(dim, idx, source) 5242 copy = "index_copy" in op_info.name 5243 # target.index_fill(dim, idx, value) 5244 fill = "index_fill" in op_info.name 5245 5246 # Extended reference inputs. We generate that exercise atomic adds / writing 5247 # several times to one location 5248 if reference: 5249 make_arg = partial(torch.ones, device=device, dtype=dtype, requires_grad=requires_grad) 5250 make_idx = partial(torch.zeros, device=device, dtype=torch.int64) 5251 else: 5252 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5253 # idx They need to be different for copy and add to be deterministic 5254 if copy or add: 5255 make_idx = partial(torch.randperm, device=device, dtype=torch.int64) 5256 else: 5257 def make_idx(n): 5258 return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=n) 5259 5260 shapes = [(), (1,), (S, S)] 5261 # extra parameter for add 5262 if add: 5263 if dtype == torch.bool: 5264 alphas = (True, False) 5265 else: 5266 alphas = (-1, 0, 2) 5267 else: 5268 alphas = (None,) 5269 5270 if fill: 5271 # A weird number to catch errors. 5272 # The former one tests `index_fill.int_Scalar`, and the latter one tests `index_fill.int_Tensor`. 5273 values = (make_arg((1,)).item(), make_arg(())) 5274 else: 5275 values = (None,) 5276 5277 for shape, alpha, value in product(shapes, alphas, values): 5278 t = make_arg(shape) 5279 args = [] 5280 5281 # dim. 
We handle the scalar case 5282 dim = -1 if t.ndim == 2 else 0 5283 args.append(dim) 5284 5285 idx = make_idx(t.shape[dim] if t.ndim != 0 else 1) 5286 args.append(idx) 5287 5288 # source 5289 if copy or add: 5290 args.append(make_arg(shape)) 5291 elif fill: 5292 args.append(value) 5293 5294 args = tuple(args) 5295 kwargs = {} if alpha is None else {"alpha": alpha} 5296 5297 yield SampleInput(t, args=args, kwargs=kwargs) 5298 5299def sample_inputs_index_reduce(op_info, device, dtype, requires_grad, **kwargs): 5300 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5301 5302 def make_idx(n, m): 5303 return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m) 5304 5305 shapes = [((), ()), ((1,), (1,)), ((S, S), (S, M)), ((S, S, S), (S, M, S))] 5306 include_selfs = (True, False) 5307 reduce = op_info.variant_test_name 5308 assert reduce in ('prod', 'mean', 'amin', 'amax') 5309 5310 for shape, include_self in product(shapes, include_selfs): 5311 self_shape, src_shape = shape 5312 # dim. We handle the scalar case 5313 dim = 1 if len(self_shape) >= 2 else 0 5314 idx = make_idx(src_shape[dim] if len(src_shape) != 0 else 1, 5315 self_shape[dim] if len(self_shape) != 0 else 1) 5316 args = (dim, idx, make_arg(src_shape), reduce) 5317 yield SampleInput(make_arg(self_shape), 5318 args=args, 5319 kwargs={'include_self' : include_self}) 5320 5321 # Sample inputs to test edge cases for backward 5322 if requires_grad and reduce == 'prod': 5323 # Check that gradients are propagated correctly for prod when zeros in self/src are reduced 5324 # This sample tests gradients for the following cases 5325 # (a) 1 zero reduced (from source (self[0, 1]), from self (self[0, 0])) 5326 # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0], self[1, 1]) 5327 # (c) no zeros reduced (self[2, 1], self[2, 2]) 5328 # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py 5329 # test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad 5330 input = torch.tensor([[0, 13], [0, 0], [15, 19]], dtype=dtype, device=device, requires_grad=requires_grad) 5331 src = torch.tensor([[2, 0], [0, 0], [2, 3], [2, 2]], dtype=dtype, device=device, requires_grad=requires_grad) 5332 idx = torch.tensor([0, 1, 2, 0], dtype=torch.long, device=device) 5333 5334 yield SampleInput(input, 5335 args=(0, idx, src, reduce), 5336 kwargs={'include_self': True}) 5337 5338def sample_inputs__unsafe_masked_index(op_info, device, dtype, requires_grad, **kwargs): 5339 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5340 5341 def make_idx(n, m, dim, d): 5342 view_shape = [1] * dim 5343 view_shape[d] = n 5344 return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) 5345 5346 cases = [ 5347 ((S, S), S, M), 5348 ((S, S), M, S), 5349 ((S, S, S), S, M), 5350 ] 5351 5352 fill_value = make_tensor([], dtype=dtype, device="cpu").item() 5353 5354 for c in cases: 5355 self_shape, high, idx_size = c 5356 dim = len(self_shape) 5357 indices = [make_idx(idx_size, high, dim, d) for d in range(dim)] 5358 masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] 5359 mask = functools.reduce(torch.logical_and, masks) 5360 yield SampleInput(make_arg(self_shape), mask, indices, fill_value) 5361 5362 masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] 5363 mask = 
functools.reduce(torch.logical_and, masks) 5364 yield SampleInput(make_arg(self_shape), mask, indices, fill_value) 5365 5366def sample_inputs__unsafe_masked_index_put_accumulate(op_info, device, dtype, requires_grad, **kwargs): 5367 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5368 5369 def make_idx(n, m, dim, d): 5370 view_shape = [1] * dim 5371 view_shape[d] = n 5372 return make_tensor((n,), device=device, dtype=torch.int64, low=0, high=m).view(view_shape) 5373 5374 cases = [ 5375 ((S, S), S, (M, M)), 5376 ((S, S), M, (S, S + 1)), 5377 ((S, S, S), S, (M, M - 1, M + 1)), 5378 ] 5379 5380 fill_value = make_tensor([], dtype=dtype, device="cpu").item() 5381 5382 for c in cases: 5383 self_shape, high, idx_sizes = c 5384 dim = len(self_shape) 5385 indices = [make_idx(idx_sizes[d], high, dim, d) for d in range(dim)] 5386 masks = [torch.logical_and(idx >= 0, idx < self_shape[i]) for i, idx in enumerate(indices) if idx is not None] 5387 mask = functools.reduce(torch.logical_and, masks) 5388 values = make_arg(idx_sizes) 5389 yield SampleInput(make_arg(self_shape), mask, indices, values) 5390 5391 masks = [torch.logical_and(idx >= 1, idx < self_shape[i] - 1) for i, idx in enumerate(indices) if idx is not None] 5392 mask = functools.reduce(torch.logical_and, masks) 5393 yield SampleInput(make_arg(self_shape), mask, indices, values) 5394 5395 5396def sample_inputs_mode(op_info, device, dtype, requires_grad, **kwargs): 5397 args = ( 5398 ((S, S, S), (),), 5399 ((S, S, S), (1, ),), 5400 ((S, S, S), (1, True, ),), 5401 ((), (),), 5402 ((), (0,),), 5403 ((), (0, True,),), 5404 # Non-fused mode kernel on CUDA 5405 ((3000,), ()), 5406 ) 5407 make_arg = partial(make_tensor, dtype=dtype, device=device, 5408 requires_grad=requires_grad, low=None, high=None) 5409 return (SampleInput(make_arg(input_tensor), *args) 5410 for input_tensor, args in args) 5411 5412# Missing to test the nondeterminism of the operation 5413# https://github.com/pytorch/pytorch/issues/53352 5414def sample_inputs_put(op_info, device, dtype, requires_grad, **kwargs): 5415 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 5416 make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) 5417 5418 S = 3 5419 5420 # Generic inputs 5421 idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S] 5422 idx_list = [idx, -idx - 1] 5423 for idx, acc in product(idx_list, (True, False)): 5424 yield SampleInput(input=make_arg((S, S)), 5425 args=(idx.clone(), 5426 make_arg((S,)), 5427 acc)) 5428 5429 # Scalar cases 5430 scalar_sizes = [(), (1,)] 5431 tgt_gen = (make_arg(size) for size in scalar_sizes) 5432 idx_gen = (make_idx(size, high=1) for size in scalar_sizes) 5433 src_gen = (make_arg(size) for size in scalar_sizes) 5434 for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)): 5435 yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), 5436 args=(idx.clone(), 5437 src.clone().requires_grad_(requires_grad), 5438 acc)) 5439 5440 # Empty cases 5441 tgt_sizes = [(0,), (), (1,), (3, 2)] 5442 tgt_gen = (make_arg(size) for size in tgt_sizes) 5443 idx = make_idx((0,), high=1) 5444 src = make_arg((0,)) 5445 for tgt, acc in product(tgt_gen, (True, False)): 5446 yield SampleInput(input=tgt.clone().requires_grad_(requires_grad), 5447 args=(idx.clone(), 5448 src.clone().requires_grad_(requires_grad), 5449 acc)) 5450 5451def sample_inputs_take(op_info, device, dtype, requires_grad, **kwargs): 5452 make_arg = 
partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 5453 make_idx = partial(make_tensor, low=0, dtype=torch.int64, device=device, requires_grad=False) 5454 5455 S = 3 5456 5457 # Generic inputs: take S elements out of S * S 5458 index = make_idx((S,), high=(S * S)) 5459 for idx in (index, -index - 1): 5460 yield SampleInput(input=make_arg((S, S)), args=(idx,)) 5461 5462 # Scalar cases 5463 scalar_sizes = [(), (1,)] 5464 src_gen = (make_arg(size) for size in scalar_sizes) 5465 idx_gen = (make_idx(size, high=1) for size in scalar_sizes) 5466 for src, idx in product(src_gen, idx_gen): 5467 yield SampleInput(input=src.clone().requires_grad_(requires_grad), 5468 args=(idx.clone(),)) 5469 5470 # Empty cases 5471 src_sizes = [(0,), (), (1,), (3, 2)] 5472 src_gen = (make_arg(size) for size in src_sizes) 5473 5474 idx = make_idx((0,), high=1) 5475 for src in src_gen: 5476 yield SampleInput(input=src.clone().requires_grad_(requires_grad), 5477 args=(idx.clone(),)) 5478 5479def sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): 5480 make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) 5481 yield SampleInput(make_arg((4, 3, 2, 1)), [0, 1, 2, 3], [3, 2, 1, 0]) 5482 yield SampleInput(make_arg((4, 3, 2, 1)), [0, -1, -2, -3], [-3, -2, -1, -0]) 5483 5484def reference_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs): 5485 yield from sample_movedim_moveaxis(op_info, device, dtype, requires_grad, **kwargs) 5486 5487 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5488 5489 # shape, source, destination 5490 args = ( 5491 # empty inputs 5492 ((), (), ()), 5493 # int inputs, negative 5494 ((3, 5, 7, 2), -2, 1), 5495 # swap bounds 5496 ((3, 5, 7, 2), (-1, 0), (0, -1)), 5497 # non-sequential, negative 5498 ((2, 3, 4, 5, 6), (3, -3, 4), (1, 0, -1)), 5499 # idempotence, negative 5500 ((2, 3, 4, 5, 6), (-3, 4, 3, 1), (-3, 4, 3, 1)), 5501 # reverse, sequential, positive 5502 ((6, 2, 3, 5, 4), (4, 3, 2, 1, 0), (0, 1, 2, 3, 4)), 5503 # reverse, non-sequential 5504 ((6, 2, 3, 5, 4), (-3, -2, -4, -5, -1), (2, 1, 3, 4, 0)), 5505 # reverse, sequential, negative 5506 ((6, 2, 3, 5, 4), (4, -2, 2, -4, -5), (-5, 1, 2, -2, -1)), 5507 ) 5508 5509 for shape, source, destination in args: 5510 yield SampleInput(make_arg(shape), args=(source, destination)) 5511 5512def error_movedim_moveaxis(op_info, device, **kwargs): 5513 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 5514 5515 # source length < destination length 5516 yield ErrorInput( 5517 SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3), (1, 0, -1))), 5518 error_regex=(r"movedim: Invalid source or destination dims: source " 5519 r"\(\[3, -3\] dims\) should contain the same number of " 5520 r"dims as destination \(\[1, 0, -1\] dims\)"), 5521 ) 5522 5523 # source length > destination length 5524 yield ErrorInput( 5525 SampleInput(make_arg(2, 3, 4, 5, 6), args=((3, -3, 4), (1, 0))), 5526 error_regex=(r"movedim: Invalid source or destination dims: source " 5527 r"\(\[3, -3, 4\] dims\) should contain the same number of " 5528 r"dims as destination \(\[1, 0\] dims\)"), 5529 ) 5530 5531 # repeated source dim, with negative indices 5532 yield ErrorInput( 5533 SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 4, -5), (1, 0, 2))), 5534 error_regex=r"movedim: repeated dim in `source` \(\[0, 4, -5\]\)", 5535 ) 5536 5537 # repeated destination dim, with negative indices 5538 yield ErrorInput( 5539 
SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, 2), (0, 4, -5))), 5540 error_regex=r"movedim: repeated dim in `destination` \(\[0, 4, -5\]\)", 5541 ) 5542 5543 # repeated dim (both), with negative indices 5544 yield ErrorInput( 5545 SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 0, -4), (0, 4, -5))), 5546 error_regex=r"movedim: repeated dim in `source` \(\[1, 0, -4\]\)", 5547 ) 5548 5549 # out of bounds source inputs, with negative indices 5550 yield ErrorInput( 5551 SampleInput(make_arg(2, 3, 4, 5, 6), args=((0, 1, -6), (1, 4, 2))), 5552 error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", 5553 error_type=IndexError, 5554 ) 5555 5556 # out of bounds destination inputs, with negative indices 5557 yield ErrorInput( 5558 SampleInput(make_arg(2, 3, 4, 5, 6), args=((1, 4, 2), (0, 1, -6))), 5559 error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", 5560 error_type=IndexError, 5561 ) 5562 5563 # out of bounds source input, int 5564 yield ErrorInput( 5565 SampleInput(make_arg(2, 3, 4, 5, 6), args=(-6, 1)), 5566 error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", 5567 error_type=IndexError, 5568 ) 5569 5570 # out of bounds destination input, int 5571 yield ErrorInput( 5572 SampleInput(make_arg(2, 3, 4, 5, 6), args=(3, -6)), 5573 error_regex=r"Dimension out of range \(expected to be in range of \[-5, 4\], but got -6\)", 5574 error_type=IndexError, 5575 ) 5576 5577def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): 5578 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 5579 rep_dims = ((), (0, ), (1, ), (0, 2), (1, 1), (2, 3), (2, 3, 2), (0, 2, 3), (2, 1, 1, 1),) 5580 shapes = ((), (0,), (2,), (3, 0), (3, 2), (3, 0, 1)) 5581 5582 if requires_grad: 5583 # Tests for variant_consistency_jit, grad, gradgrad 5584 # are slower. Use smaller bags of `rep_dims` and `shapes` 5585 # in this case. 5586 rep_dims = ((), (0, ), (0, 2), (1, 1), (2, 3), (1, 3, 2), (3, 1, 1)) # type: ignore[assignment] 5587 shapes = ((), (0,), (2,), (3, 2)) # type: ignore[assignment] 5588 5589 is_repeat_op = op_info.name in ['repeat', '_refs.repeat'] 5590 for rep_dim, shape in product(rep_dims, shapes): 5591 # `torch.repeat` errors for `len(rep_dims) < t.dim()`, 5592 # so we filter such combinations. 
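        # e.g. torch.ones(3, 2).repeat(4) raises because there are fewer repeat
        # dims than tensor dims, while torch.ones(3, 2).tile(4) is fine since
        # tile implicitly prepends ones to the repeat spec.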
5593 if is_repeat_op and len(rep_dim) < len(shape): 5594 continue 5595 yield SampleInput(make_arg(shape), rep_dim) 5596 5597 5598def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): 5599 shapes_and_args = ( 5600 ((S, S, S), 1, 2, 2), 5601 ((S, S, S), -1, 2, 2), 5602 ((S, S, S), 1, 0, 0), 5603 ((S, S, S), -1, 0, 0), 5604 ((S, S, S), 2, 1, 2), 5605 ) 5606 5607 for shape, dim, start, length in shapes_and_args: 5608 tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, 5609 requires_grad=requires_grad) 5610 yield SampleInput(tensor, dim, start, length) 5611 # narrow also accepts the start argument being a Tensor 5612 if is_narrow: 5613 yield SampleInput(tensor, dim, torch.tensor(start), length) 5614 5615def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): 5616 yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) 5617 5618 shapes_and_args = ( 5619 # 1-dim 5620 ((M,), 0, 0, 0), # 0 elems from the left 5621 ((M,), -1, -1, 0), # 0 elems from the right 5622 ((M,), 0, 5, 3), # 3 elems from the left 5623 ((M,), 0, -5, 2), # 2 elems from the right 5624 ((M,), -1, 0, M), # M elems from the left 5625 ((M,), 0, -M, M), # M elems from the right 5626 5627 # 2-dim 5628 ((M, S), 1, 0, 0), # dim 1, 0 elems from the left 5629 ((S, M), -2, -1, 0), # dim 0, 0 elems from the right 5630 ((L, S), 1, 2, 3), # dim 1, 3 elems from the left 5631 ((L, S), -1, 3, 2), # dim 1, 2 elems from the left 5632 ((M, L), 0, 0, M), # dim 0, M elems from the left 5633 ((M, L), -1, -L, L), # dim 1, L elems from the right 5634 5635 # 3-dim 5636 ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left 5637 ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right 5638 ((S, L, M), 2, 0, M), # dim 2, M elems from the left 5639 ((L, S, M), -1, -M, M), # dim 2, M elems from the right 5640 ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left 5641 ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left 5642 ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right 5643 ) 5644 5645 for shape, dim, start, length in shapes_and_args: 5646 tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, 5647 requires_grad=requires_grad) 5648 yield SampleInput(tensor, dim, start, length) 5649 # narrow also accepts the start argument being a Tensor 5650 if is_narrow: 5651 yield SampleInput(tensor, dim, torch.tensor(start), length) 5652 5653def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): 5654 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 5655 5656 # 0-dim 5657 yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), 5658 error_type=RuntimeError, 5659 error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") 5660 5661 # out of bounds dim 5662 if not is_narrow and not is_ref and torch.device(device).type == 'cpu': 5663 # narrow_copy_dense_cpu_out 5664 yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), 5665 error_type=RuntimeError, 5666 error_regex=r"Expected dim < static_cast<int64_t>\(self_sizes.size\(\)\) to be true, but got false\.") 5667 else: 5668 yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), 5669 error_type=IndexError, 5670 error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") 5671 # out of bounds dim (negative) 5672 yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), 5673 error_type=IndexError, 5674 error_regex=r"Dimension out of range 
\(expected to be in range of \[-3, 2\], but got -4\)") 5675 5676 # out of bounds start 5677 yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), 5678 error_type=IndexError, 5679 error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got 11\)") 5680 # out of bounds start (negative) 5681 yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), 5682 error_type=IndexError, 5683 error_regex=r"start out of range \(expected to be in range of \[-10, 10\], but got -11\)") 5684 5685 # out of bounds length 5686 yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), 5687 error_type=RuntimeError, 5688 error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") 5689 # out of bounds length (negative) 5690 if not is_narrow and not is_ref and torch.device(device).type == 'cpu': 5691 # narrow_copy_dense_cpu_out 5692 yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), 5693 error_type=RuntimeError, 5694 error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") 5695 else: 5696 yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), 5697 error_type=RuntimeError, 5698 error_regex=r"narrow\(\): length must be non-negative\.") 5699 5700 # Test Tensor overload that was added for XLA. Start must be an 0-dim 5701 # integral Tensor. narrow_copy doesn't have this overload. 5702 # https://github.com/pytorch/pytorch/issues/31558 5703 if is_narrow: 5704 # *1-dim* integral Tensor 5705 yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), 5706 error_type=RuntimeError, 5707 error_regex=r"start must be an 0-dim integral Tensor\.") 5708 5709 # 0-dim *bool* Tensor (bools are not allowed) 5710 yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), 5711 error_type=RuntimeError, 5712 error_regex=r"start must be an 0-dim integral Tensor\.") 5713 5714 5715def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): 5716 y_shape_x_shape_and_kwargs = [ 5717 ((2, 3), (2, 3), {}), 5718 ((2, 3), (2, 3), {'dim': 1}), 5719 ((6,), (6,), {}), 5720 ((6,), None, {}), 5721 # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad 5722 # See Issue #{61619} 5723 # ((6,0), (6,0), {}), 5724 ((2, 3), (1, 3), {}), 5725 ((3, 3), (3, 3), {}), 5726 ((3, 3), (3, 3), {'dim': -2}), 5727 ((5,), None, {'dx': 2.0}), 5728 ((2, 2), None, {'dx': 3.0}) 5729 ] 5730 make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, 5731 requires_grad=requires_grad) 5732 for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs: 5733 y_tensor = make_arg(y_shape) 5734 if x_shape is not None: 5735 x_tensor = make_arg(x_shape) 5736 yield SampleInput(y_tensor, x_tensor, **kwarg) 5737 else: 5738 yield SampleInput(y_tensor, **kwarg) 5739 5740def sample_cumulative_trapezoid(op_info, device, dtype, requires_grad, **kwargs): 5741 5742 y_shape_x_shape_and_kwargs = [ 5743 ((2, 3), (2, 3), {}), 5744 ((2, 3), (2, 3), {'dim': 1}), 5745 ((6,), (6,), {}), 5746 ((6,), None, {}), 5747 # When 'cumulative_trapezoid' is called with an empty input, it does not produce an output with requires_grad 5748 # See Issue #{61619} 5749 # ((6,0), (6,0), {}), 5750 ((2, 3), (1, 3), {}), 5751 ((3, 3), (3, 3), {}), 5752 ((3, 3), (3, 3), {'dim': -2}), 5753 ((5,), None, {'dx': 2.0}), 5754 ((2, 2), None, {'dx': 3.0}) 5755 ] 5756 make_arg = partial(make_tensor, device=device, dtype=dtype, 5757 requires_grad=requires_grad, low=None, high=None) 5758 for y_shape, x_shape, kwarg in 
y_shape_x_shape_and_kwargs: 5759 y_tensor = make_arg(y_shape) 5760 if x_shape is not None: 5761 x_tensor = make_arg(x_shape) 5762 yield SampleInput(y_tensor, x_tensor, **kwarg) 5763 else: 5764 yield SampleInput(y_tensor, **kwarg) 5765 5766def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs): 5767 shapes_and_axes = [ 5768 ((3, 4, 5), 0), 5769 ((3, 4, 5), 1), 5770 ((3, 4, 5), 3), 5771 ((3, 4, 5), -1), 5772 ((3, 4, 5), -3), 5773 ((), 0), 5774 ((), -1), 5775 ((1,), 0), 5776 ((1,), -1), 5777 ] 5778 5779 for shape, axis in shapes_and_axes: 5780 tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, 5781 requires_grad=requires_grad) 5782 yield SampleInput(tensor, axis) 5783 5784 5785def sample_inputs_nn_unfold(op_info, device, dtype, requires_grad, **kwargs): 5786 shapes = ((0, 1, 5, 5), (2, 3, 5, 5)) 5787 kernel_sizes = (2, (2, 2), (2, 3)) 5788 dilations = (1, 2, (1, 2)) 5789 paddings = (0, 1, (1, 2)) 5790 strides = (1, 2, (1, 2)) 5791 5792 cases = product(shapes, kernel_sizes, dilations, paddings, strides) 5793 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 5794 for shape, kernel_size, dilation, padding, stride in cases: 5795 tensor = make_arg(shape) 5796 yield SampleInput(tensor, kernel_size, dilation, padding, stride) 5797 5798 # With default args 5799 yield SampleInput(make_arg((1, 1, 5, 5)), (3, 3)) 5800 5801 5802def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs): 5803 shapes_and_args = ( 5804 ((S, 1, S, 1), ()), 5805 ((1, 1, 1, 1), ()), 5806 ((1, 1, 1, 1), (0,)), 5807 ((S, 1, S, 1), (1,)), 5808 ((S, 1, S, 1), (-1,)), 5809 ((S, 1, S, 1), (2,)), 5810 ((S, 1, S, 1), (-2,)), 5811 ((), (0, )), 5812 ) 5813 5814 for shape, args in shapes_and_args: 5815 tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, 5816 requires_grad=requires_grad) 5817 5818 yield SampleInput(tensor, args=args) 5819 5820 5821def sample_inputs_squeeze_multiple(op_info, device, dtype, requires_grad, **kwargs): 5822 shapes_and_args = ( 5823 ((1, 1, 1, 1), ()), 5824 ((S, 1, S, 1), (1,)), 5825 ((S, 1, S, 1), (-1,)), 5826 ((S, 1, S, 1), (1, 3)), 5827 ((S, 1, S, 1), (1, 2,)), 5828 ((), (0,)), 5829 ) 5830 5831 for shape, dims in shapes_and_args: 5832 tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, 5833 requires_grad=requires_grad) 5834 5835 yield SampleInput(tensor, dims) 5836 5837 5838def _squeeze_ref(x, axis=None): 5839 # NumPy doesn't allow squeezing scalars 5840 if x.ndim == 0: 5841 return x 5842 5843 if isinstance(axis, Sequence): 5844 # Numpy doesn't allow specifying non-singular dimensions 5845 axis = tuple(a for a in axis if x.shape[a] == 1) 5846 5847 if isinstance(axis, int) and x.shape[axis] != 1: 5848 return x 5849 5850 return np.squeeze(x, axis) 5851 5852def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs): 5853 assert mode in ('constant', 'reflect', 'replicate', 'circular') 5854 if mode in ['reflect', 'replicate']: 5855 cases: tuple = ( # ignore 5856 ((1, 3), (1, 2)), 5857 ((1, 3), (0, 1)), 5858 ((0, 3, 3), (1, 2)), 5859 ((0, 3, 3), (0, 1)), 5860 ((1, 3, 3), (1, 2)), 5861 ((1, 3, 3), (0, 1)), 5862 ((1, 3, 3), (0, 2, 0, 1)), 5863 ((0, 3, 3, 3), (0, 2, 0, 1)), 5864 ((3, 3, 5, 5), (0, 2, 0, 1)), 5865 ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)), 5866 ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)), 5867 ((1, 3, 4, 4), (-1, 1, -2, 1)), 5868 ) 5869 elif mode == 'constant': 5870 cases = ( 5871 ((1, 3), (1, 2)), 5872 ((1, 3), (0, 1)), 5873 ((1, 3), (0, 2, 0, 1)), 
            ((0, 3, 3), (1, 2)),
            ((0, 3, 3), (0, 1)),
            ((0, 3, 3), (0, 2, 0, 1)),
            ((0, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 3), (1, 2)),
            ((1, 3, 3), (0, 1)),
            ((1, 3, 3), (0, 2, 0, 1)),
            ((1, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((0, 3, 3, 3), (1, 2)),
            ((0, 3, 3, 3), (0, 1)),
            ((0, 3, 3, 3), (0, 2, 0, 1)),
            ((0, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((3, 3, 5, 5), (1, 2)),
            ((3, 3, 5, 5), (0, 1)),
            ((3, 3, 5, 5), (0, 2, 0, 1)),
            ((3, 3, 5, 5), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 3, 3, 3), (1, 2)),
            ((1, 3, 3, 3, 3), (0, 1)),
            ((1, 3, 3, 3, 3), (0, 2, 0, 1)),
            ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
            ((1, 3, 4, 4), (-1, 1, -2, 1)),
        )
    else: # mode == 'circular'
        if dtype == torch.bool:
            # test_dtypes fails on ASAN for this case with the error:
            # runtime error: load of value 190, which is not a valid value for type 'bool'
            # Reference: https://github.com/pytorch/pytorch/pull/62814#issuecomment-894156562
            # Reference Issue: https://github.com/pytorch/pytorch/issues/63034
            cases = (
                ((2, 3, 3), (1, 2)),
                ((1, 3, 3), (1, 2)),
            )
        else:
            cases = (
                ((0, 3, 3), (1, 2)),
                ((0, 3, 3), (0, 1)),
                ((1, 3, 3), (1, 2)),
                ((1, 3, 3), (0, 1)),
                ((0, 3, 3, 3), (0, 2, 0, 1)),
                ((3, 3, 5, 5), (0, 2, 0, 1)),
                ((1, 3, 3, 3, 3), (1, 1, 1, 1, 1, 1)),
                ((1, 3, 4, 4), (-1, 1, -2, 1)),
            )

    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if mode == 'constant':
        # Default args
        yield SampleInput(make_inp((1, 3, 3)), args=((2, 2),))

    if mode in ['reflect', 'replicate', 'circular']:
        for shape, pad in cases:
            yield SampleInput(make_inp(shape), args=(pad, mode))
    else: # mode == 'constant'
        for pad_value in (1., 2.):
            for shape, pad in cases:
                yield SampleInput(make_inp(shape), args=(pad, mode, pad_value))

def sample_inputs_nn_pad_replicate_negative(op_info, device, dtype, requires_grad, **kwargs):
    cases: tuple = (
        ((5, 3, 4, 4), (-4, 5, 0, 0)),
        ((6, 2, 4, 4), (0, 0, 2, -4)),
        ((5, 6, 4, 4), (5, -4, -4, 3)),
        ((4, 2, 5, 5), (-2, -1, 4, 6)),
        ((2, 6, 5, 5), (8, -1, -1, -3)),
        ((8, 1, 5, 5), (-2, -1, -1, -3)),
    )
    make_inp = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    for shape, pad in cases:
        yield SampleInput(make_inp(shape), args=(pad, 'replicate'))

def sample_inputs_constant_pad_nd(op_info, device, dtype, *args, **kwargs):
    # Inherit sample inputs from nn.pad, but transform them to fit
    # constant_pad_nd's interface
    nn_samples = sample_inputs_nn_pad(op_info, device, dtype, *args,
                                      mode='constant', **kwargs)

    # NOTE: primTorch is more strict about the type of the fill value argument
    # So we must cast it to the correct dtype
    from torch._prims_common import dtype_to_type
    scalar_type = dtype_to_type(dtype)

    def drop_mode_argument(input, pad, mode=None, value=None):
        if value is None:
            return SampleInput(input, args=(pad,))
        else:
            return SampleInput(input, args=(pad, scalar_type(value)))

    for sample in nn_samples:
        yield drop_mode_argument(sample.input, *sample.args, **sample.kwargs)

def sample_inputs_repeat_interleave(op_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    yield
SampleInput(make_input(()), repeats=2) 5970 yield SampleInput(make_input((2, 3, 4)), repeats=2) 5971 yield SampleInput(make_input((2, 3, 4)), repeats=2, dim=1) 5972 yield SampleInput(make_input((2, 3, 4)), repeats=torch.arange(3, device=device), dim=1) 5973 5974 5975def sample_inputs_stft(op_info, device, dtype, requires_grad, **kwargs): 5976 def mt(shape, **kwargs): 5977 return make_tensor(shape, device=device, dtype=dtype, 5978 requires_grad=requires_grad, **kwargs) 5979 5980 yield SampleInput(mt(100), n_fft=10, return_complex=True) 5981 yield SampleInput(mt(100), n_fft=10, return_complex=False) 5982 if dtype.is_complex: 5983 yield SampleInput(mt(100), n_fft=10) 5984 5985 for center in [False, True]: 5986 yield SampleInput(mt(10), n_fft=7, center=center, return_complex=True) 5987 yield SampleInput(mt((10, 100)), n_fft=16, hop_length=4, 5988 center=center, return_complex=True) 5989 5990 window = mt(16, low=.5, high=2.0) 5991 yield SampleInput( 5992 mt((2, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) 5993 yield SampleInput( 5994 mt((3, 100)), kwargs=dict(n_fft=16, window=window, return_complex=True, center=center)) 5995 if not dtype.is_complex: 5996 yield SampleInput( 5997 mt((10, 100)), n_fft=16, window=window, onesided=False, 5998 return_complex=True) 5999 6000 6001def sample_inputs_istft(op_info, device, dtype, requires_grad, **kwargs): 6002 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6003 6004 def mt(shape, **kwargs): 6005 real_shape = shape if dtype.is_complex else shape + (2,) 6006 return make_arg(real_shape, **kwargs) 6007 6008 yield SampleInput(mt((10, 2)), kwargs=dict(n_fft=10)) 6009 yield SampleInput(mt((6, 3)), kwargs=dict(n_fft=6, onesided=False)) 6010 yield SampleInput(mt((6, 4)), kwargs=dict(n_fft=10, onesided=True)) 6011 6012 for center in [False, True]: 6013 yield SampleInput(mt((10, 10, 6)), kwargs=dict(n_fft=10, center=center)) 6014 yield SampleInput(mt((1, 9, 10)), kwargs=dict(n_fft=16, hop_length=4, center=center)) 6015 6016 window = make_arg(10, low=.5, high=2.0) 6017 yield SampleInput(mt((10, 10, 6)), kwargs=dict( 6018 n_fft=10, window=window, center=center, return_complex=dtype.is_complex)) 6019 yield SampleInput(mt((10, 10, 10)), kwargs=dict( 6020 n_fft=10, window=window[:8], win_length=8, center=center, return_complex=True)) 6021 6022 real_window = window if not dtype.is_complex else window.real 6023 yield SampleInput(mt((10, 5, 6)), kwargs=dict(n_fft=8, window=real_window[:8], center=center)) 6024 6025def sample_inputs_ormqr(op_info, device, dtype, requires_grad, **kwargs): 6026 # create a helper function wrapping `make_tensor` 6027 make_input = partial(make_tensor, dtype=dtype, device=device, low=-1, high=1) 6028 6029 batches = [(), (0, ), (2, ), (2, 1)] 6030 ns = [5, 2, 0] 6031 tf = [True, False] 6032 for batch, (m, n), left, transpose in product(batches, product(ns, ns), tf, tf): 6033 input = make_input((*batch, m, n)) 6034 reflectors, tau = torch.geqrf(input) 6035 reflectors.requires_grad_(requires_grad) 6036 tau.requires_grad_(requires_grad) 6037 other_matrix_shape = (m, n) if left else (n, m) 6038 other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad) 6039 yield SampleInput(reflectors, tau, other, left=left, transpose=transpose) 6040 6041 6042def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs): 6043 cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse( 6044 op_info, device, dtype, requires_grad=False 6045 ) 
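    # Each sample produced above is repurposed below: its positive-definite matrix is
    # moved into the argument position (the second input of cholesky_solve), and a fresh
    # random tensor of the same shape takes its place as the right-hand-side input.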
6046 6047 for sample in cholesky_inverse_samples: 6048 psd_matrix = sample.input 6049 sample.input = make_tensor(psd_matrix.shape, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) 6050 sample.args = (psd_matrix.requires_grad_(requires_grad),) 6051 yield sample 6052 6053 6054def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs): 6055 make_arg = partial(make_fullrank_matrices_with_distinct_singular_values, 6056 dtype=dtype, device=device, requires_grad=requires_grad) 6057 6058 # not needed once OpInfo tests support Iterables 6059 batch_shapes = ((), (3,), (3, 3)) 6060 for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)): 6061 shape = batch_shape + (S + size_delta, S) 6062 input = make_arg(*shape) 6063 yield SampleInput(input, args=(True, get_infos)) 6064 6065 6066def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs): 6067 def out_fn(output): 6068 return output[1], output[2] 6069 6070 for lu_sample in sample_inputs_linalg_lu(op_info, device, dtype, requires_grad, **kwargs): 6071 lu_data, pivots = torch.linalg.lu_factor(lu_sample.input) 6072 lu_data.requires_grad_(requires_grad) 6073 yield SampleInput(lu_data, pivots).with_metadata(output_process_fn_grad=out_fn) 6074 6075 6076def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs): 6077 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6078 6079 args = ((0, 0), (1, 2), (0, 2), (2, 0), (-1, 0), (10000, 1), (2,), ((1, 2, -1), (0, 1, 2))) 6080 6081 for arg in args: 6082 yield SampleInput(make_arg((0, 0, 0)), args=arg) 6083 yield SampleInput(make_arg((S, S, S)), args=arg) 6084 6085 # Scalar tensor 6086 yield SampleInput(make_arg(()), args=(10, )) 6087 6088def error_inputs_roll(op_info, device, **kwargs): 6089 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 6090 err_msg1 = "`shifts` required" 6091 s1 = SampleInput(make_arg((S,)), ()) 6092 yield ErrorInput(s1, error_regex=err_msg1) 6093 6094 err_msg2 = ("shifts and dimensions must align") 6095 s2 = SampleInput(make_arg((S, S)), (2, 1), 0) 6096 yield ErrorInput(s2, error_regex=err_msg2) 6097 6098 err_msg3 = ("out of range") 6099 s3 = SampleInput(make_arg((S, )), 0, 2) 6100 yield ErrorInput(s3, error_regex=err_msg3, error_type=IndexError) 6101 6102 err_msg4 = ("Dimension specified as 0") 6103 s4 = SampleInput(make_arg(()), 0, 0) 6104 yield ErrorInput(s4, error_regex=err_msg4, error_type=IndexError) 6105 6106def sample_inputs_rot90(op_info, device, dtype, requires_grad=False, **kwargs): 6107 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6108 6109 args = itertools.product(range(-5, 6), [(0, 1), (1, 2), (1, -1)]) 6110 6111 yield SampleInput(make_arg((S, S, S))) 6112 for arg in args: 6113 yield SampleInput(make_arg((S, S, S)), args=arg) 6114 6115 6116def error_inputs_rot90(op_info, device, **kwargs): 6117 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 6118 err_msg1 = "expected total rotation dims" 6119 s1 = SampleInput(make_arg((S, S)), dims=(0,)) 6120 yield ErrorInput(s1, error_regex=err_msg1) 6121 6122 err_msg2 = "expected total dims >= 2" 6123 s2 = SampleInput(make_arg((S,))) 6124 yield ErrorInput(s2, error_regex=err_msg2) 6125 6126 err_msg3 = "expected rotation dims to be different" 6127 s3 = SampleInput(make_arg((S, S)), dims=(1, 1)) 6128 yield ErrorInput(s3, error_regex=err_msg3) 6129 6130 6131def sample_inputs_std_var(op_info, 
device, dtype, requires_grad, **kwargs): 6132 tensor_nd = partial(make_tensor, (S, S, S), device=device, dtype=dtype, 6133 requires_grad=requires_grad) 6134 tensor_1d = partial(make_tensor, (S,), device=device, dtype=dtype, 6135 requires_grad=requires_grad) 6136 6137 yield SampleInput(tensor_nd()) 6138 yield SampleInput(tensor_nd(), dim=1) 6139 yield SampleInput(tensor_nd(), dim=1, unbiased=True, keepdim=True) 6140 yield SampleInput(tensor_1d(), dim=0, unbiased=True, keepdim=True) 6141 yield SampleInput(tensor_1d(), dim=0, unbiased=False, keepdim=False) 6142 6143 yield SampleInput(tensor_nd(), dim=(1,), correction=1.3) 6144 yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) 6145 yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) 6146 yield SampleInput(tensor_nd(), dim=None, correction=None) 6147 yield SampleInput(tensor_nd(), correction=0, keepdim=True) 6148 yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) 6149 6150 6151def sample_inputs_std_var_unbiased(op_info, device, dtype, requires_grad, **kwargs): 6152 make_arg = partial(make_tensor, device=device, dtype=dtype, 6153 requires_grad=requires_grad) 6154 6155 # Test var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) 6156 yield SampleInput(make_arg((S, S)), True) 6157 yield SampleInput(make_arg((S,)), False) 6158 6159 6160def _generate_correlation_inputs(device, dtype, requires_grad, **kwargs): 6161 shapes = [(2,), (1, 2), (3, 2), (2, 3)] 6162 for shape in shapes: 6163 yield make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) 6164 6165 6166def sample_inputs_corrcoef(op_info, device, dtype, requires_grad, **kwargs): 6167 return (SampleInput(t) for t in _generate_correlation_inputs(device, dtype, requires_grad)) 6168 6169 6170def sample_inputs_cov(op_info, device, dtype, requires_grad, **kwargs): 6171 for t in _generate_correlation_inputs(device, dtype, requires_grad): 6172 yield SampleInput(t) 6173 num_observations = t.numel() if t.ndimension() < 2 else t.size(1) 6174 fweights = make_tensor((num_observations,), dtype=torch.int, device=device, low=1, high=10) 6175 aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=0, high=1, requires_grad=requires_grad) 6176 for correction, fw, aw in product(range(num_observations), [None, fweights], [None, aweights]): 6177 yield SampleInput(t.clone().requires_grad_(requires_grad), 6178 correction=correction, fweights=fw, aweights=aw) 6179 6180 6181def error_inputs_cov(op_info, device, **kwargs): 6182 a = torch.rand(S, device=device) 6183 yield ErrorInput( 6184 SampleInput(torch.rand(S, S, S, device=device)), 6185 error_regex="expected input to have two or fewer dimensions") 6186 yield ErrorInput( 6187 SampleInput(a, fweights=torch.rand(S, S, device=device)), 6188 error_regex="expected fweights to have one or fewer dimensions") 6189 yield ErrorInput( 6190 SampleInput(a, aweights=torch.rand(S, S, device=device)), 6191 error_regex="expected aweights to have one or fewer dimensions") 6192 yield ErrorInput( 6193 SampleInput(a, fweights=torch.rand(S, device=device)), 6194 error_regex="expected fweights to have integral dtype") 6195 yield ErrorInput( 6196 SampleInput(a, aweights=torch.tensor([1, 1], device=device)), 6197 error_regex="expected aweights to have floating point dtype") 6198 yield ErrorInput( 6199 SampleInput(a, fweights=torch.tensor([1], device=device)), 6200 error_regex="expected fweights to have the same numel") 6201 yield ErrorInput( 6202 
SampleInput(a, aweights=torch.rand(1, device=device)), 6203 error_regex="expected aweights to have the same numel") 6204 yield ErrorInput( 6205 SampleInput(a, fweights=torch.tensor([-1, -2, -3, -4 , -5], device=device)), 6206 error_regex="fweights cannot be negative") 6207 yield ErrorInput( 6208 SampleInput(a, aweights=torch.tensor([-1., -2., -3., -4., -5.], device=device)), 6209 error_regex="aweights cannot be negative") 6210 6211 6212def sample_inputs_permute(op_info, device, dtype, requires_grad, **kwargs): 6213 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6214 6215 cases = [((1, 2, 3, 4), (0, 2, 3, 1)), 6216 ((1, 2, 3, 4), (0, -2, -1, 1)), 6217 ((), ()), 6218 ((1, 2, 3, 4), (2, 1, 3, 0))] 6219 6220 for shape, args in cases: 6221 yield SampleInput(make_arg(shape), args=(args,)) 6222 6223def reference_inputs_permute(op, device, dtype, requires_grad, **kwargs): 6224 yield from sample_inputs_permute(op, device, dtype, requires_grad, **kwargs) 6225 6226 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6227 6228 cases = ( 6229 ((), ()), 6230 ((1,), (0,)), 6231 ((2, 2), (1, 0)), 6232 ((2, 2), (0, 1)), 6233 ((2, 0, 1), (0, 2, 1)), 6234 ((3, 4, 2), (2, 1, 0)), 6235 ((3, 4, 2), (1, 0, 2)), 6236 ((3, 4, 2), (0, 1, 2)), 6237 ) 6238 6239 # Adds tricky permutations and permutations with noncontiguity 6240 for shape, permutation in cases: 6241 for p in itertools.permutations(permutation): 6242 a = make_arg(shape).permute(p) 6243 yield SampleInput(a, args=(permutation,)) 6244 6245 a = make_arg(shape, noncontiguous=True).permute(p) 6246 yield SampleInput(a, args=(permutation,)) 6247 6248def error_inputs_softshrink(op, device, **kwargs): 6249 yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"lambd": -0.5}), 6250 error_regex="lambda must be greater or equal to 0, but found to be -0.5") 6251 6252def sample_inputs_softshrink(op_info, device, dtype, requires_grad=False, **kwargs): 6253 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6254 6255 # The additional sample is to check additional values of lambd beyond the default 6256 # value (what is already checked by sample_inputs_elementwise_unary) 6257 for lbda in (0., 0.5): 6258 yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) 6259 6260 yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) 6261 6262def sample_inputs_hardshrink(op_info, device, dtype, requires_grad=False, **kwargs): 6263 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6264 6265 # The additional sample is to check additional values of lambd beyond the default 6266 # value (what is already checked by sample_inputs_elementwise_unary) 6267 # Note that unlike softshrink, lambd is allowed to be negative for hardshrink 6268 for lbda in (-0.5, 0., 0.5): 6269 yield SampleInput(make_arg(S, S), kwargs={"lambd": lbda}) 6270 6271 yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) 6272 6273 6274def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs): 6275 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6276 6277 # The additional sample is to check additional values of min_val and max_val beyond the default 6278 # value (what is already checked by sample_inputs_elementwise_unary) 6279 for max_val, min_val in ((0.5, -0.5), (0., 0.)): 6280 yield SampleInput(make_arg(S, S), 
kwargs={"min_val": min_val, "max_val": max_val}) 6281 6282 yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) 6283 6284def error_inputs_hardtanh(op_info, device, **kwargs): 6285 # Tests that hardtanh errors out when passed min_val > max_val. 6286 yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device), kwargs={"min_val": 0.5, "max_val": -0.5}), 6287 error_type=ValueError, error_regex="min_val cannot be greater than max_val") 6288 6289def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs): 6290 def c(t): 6291 return t.clone().requires_grad_(requires_grad) 6292 6293 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6294 x = make_arg((3,)) 6295 y = make_arg((4,)) 6296 A = make_arg((2, 3,)) 6297 B = make_arg((1, 3,)) 6298 C = make_arg((1, 2, 3,)) 6299 D = make_arg((1, 3, 4,)) 6300 E = make_arg((4, 4,)) 6301 H = make_arg((3, 3,)) 6302 I = make_arg((1, 3, 1,)) 6303 6304 # Vector operations 6305 yield SampleInput([c(x)], 'i->') # sum 6306 yield SampleInput([c(x), c(y)], 'i,j->ij') # outer 6307 6308 # Matrix operations 6309 yield SampleInput([c(A)], "ij->i") # col sum 6310 yield SampleInput([c(A), c(B)], "ij,kj->ik") # matmul 6311 yield SampleInput([c(A), c(E)], "ij,Ab->ijAb") # matrix outer product 6312 6313 # Tensor operations 6314 yield SampleInput([c(C), c(D)], "aij,ajk->aik") # batch matmul 6315 yield SampleInput([c(D), c(E)], "aij,jk->aik") # tensor matrix contraction 6316 yield SampleInput([c(C), c(B)], "ijk,ik->j") # non contiguous 6317 6318 # Test diagonals 6319 yield SampleInput([c(I)], 'iji->j') # non-contiguous trace 6320 6321 # Test ellipsis 6322 yield SampleInput([c(H)], "i...->...") 6323 yield SampleInput([c(C), c(x)], '...ik, ...j -> ij') 6324 6325 6326def sample_inputs_flip(op_info, device, dtype, requires_grad, **kwargs): 6327 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 6328 sizes = ((S, M, S), (S, 0, M)) 6329 all_dims = ((0, 1, 2), (0,), (0, 2), (-1,), ()) 6330 6331 for size, dims in product(sizes, all_dims): 6332 yield SampleInput(make_arg(size), kwargs={"dims": dims}) 6333 6334def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad, **kwargs): 6335 shapes = [ 6336 (S, M, S), 6337 (S, 0, M), 6338 ] 6339 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 6340 return (SampleInput(make_arg(shape, low=None, high=None)) for shape in shapes) 6341 6342def error_inputs_fliplr(op, device, **kwargs): 6343 yield ErrorInput(SampleInput(make_tensor((1,), dtype=torch.float, device=device)), 6344 error_regex="Input must be >= 2-d.") 6345 6346def error_inputs_flipud(op, device, **kwargs): 6347 yield ErrorInput(SampleInput(make_tensor((), dtype=torch.float, device=device)), 6348 error_regex="Input must be >= 1-d.") 6349 6350def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs): 6351 make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) 6352 make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False) 6353 shape = (S, M, S) 6354 6355 yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape))) 6356 yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:]))) 6357 yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),)) 6358 yield SampleInput(make_arg(shape), args=(None, make_arg(shape))) 6359 yield 
SampleInput(make_arg(shape), args=(make_arg(shape), None)) 6360 # test type promotion 6361 yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None)) 6362 yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape))) 6363 6364def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs): 6365 yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs) 6366 6367 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6368 make_scalar_tensor = partial(make_tensor, (), device='cpu', dtype=dtype, requires_grad=requires_grad) 6369 supported_dtypes = op.supported_dtypes(device) 6370 6371 # broadcasting and oncontiguous cases 6372 cases = ( 6373 ((4, 4), (4, 4), (4, 4)), 6374 ((4, 4), (1, 4, 4), (4, 4)), 6375 ((4, 4), (1, 4, 4), (4, 1, 4)), 6376 ((4, 4, 1), (1, 4, 4), (4, 4)), 6377 ((4, 1), (1, 4, 4), (1, 4)), 6378 ((4, 4), (), (4, 4)), 6379 ((4, 4), (), ()), 6380 ((), (4, 4), (1, 4, 4)), 6381 ) 6382 6383 for a, b, c in cases: 6384 yield SampleInput(make_arg(a), args=(make_arg(b), make_arg(c))) 6385 yield SampleInput(make_arg(a, noncontiguous=True), 6386 args=(make_arg(b).transpose(0, -1), make_arg(c, noncontiguous=True).transpose(0, -1))) 6387 6388 # scalar cases 6389 if supports_scalars: 6390 cases = [ 6391 ((), 1, 2,), 6392 ((), 1., 2), 6393 ((4, 4), 1., 2,), 6394 ((3, 4), make_scalar_tensor(), make_scalar_tensor()), 6395 ] 6396 6397 if torch.complex64 in supported_dtypes: 6398 cases.extend([ 6399 ((3, 1, 4), complex(1, 2), 3.), 6400 ]) 6401 6402 for a, b, c in cases: 6403 yield SampleInput(make_arg(a), args=(b, c)) 6404 6405 # type promotion cases 6406 # int x float 6407 if torch.float in supported_dtypes and torch.long in supported_dtypes: 6408 a = make_arg((), dtype=torch.long) 6409 b = make_arg((1, 4), dtype=torch.float) 6410 c = make_arg((3, 4)) 6411 6412 cases = ( 6413 (a, b, c), 6414 (c, a, b), 6415 ) 6416 6417 for a, b, c in cases: 6418 yield SampleInput(a, args=(b, c)) 6419 6420 # NaN propagation 6421 if dtype.is_floating_point or dtype.is_complex: 6422 nan = float('nan') if dtype.is_floating_point else complex(float('nan'), float('nan')) 6423 6424 a = make_arg((12,)) 6425 a[4] = nan 6426 a[7] = nan 6427 b = make_arg((12,)) 6428 b[1] = nan 6429 b[7] = nan 6430 c = make_arg((12,)) 6431 c[9] = nan 6432 6433 yield SampleInput(a, args=(b, c)) 6434 6435 6436def _clamp_min_numpy(a, min=None): 6437 return np.maximum(a, min) 6438 6439 6440def _clamp_max_numpy(a, max=None): 6441 return np.minimum(a, max) 6442 6443 6444def _clamp_numpy(a, min=None, max=None): 6445 if min is None: 6446 return np.minimum(a, max) 6447 if max is None: 6448 return np.maximum(a, min) 6449 6450 return np.minimum(max, np.maximum(a, min)) 6451 6452 6453def sample_inputs_cumprod(op_info, device, dtype, requires_grad, **kwargs): 6454 def make_arg(shape): 6455 # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck 6456 return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) 6457 6458 def prod_zeros(dim_select): 6459 assert len(dim_select) == 2 6460 result = make_arg(3 * (S,)) 6461 result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() 6462 result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() 6463 result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() 6464 return result 6465 6466 for dim in range(3): 6467 yield SampleInput(make_arg((S, S, S)), 
args=(dim,)) 6468 # Scalar tensors and empty tensor 6469 for size in [(), (1,), (0,)]: 6470 yield SampleInput(make_arg(size), args=(0,)) 6471 6472 yield SampleInput(prod_zeros([0, 1]), args=(1,)) 6473 yield SampleInput(prod_zeros([0, 2]), args=(1,)) 6474 yield SampleInput(prod_zeros([1, 2]), args=(1,)) 6475 6476 # test dtype kwarg 6477 yield SampleInput(prod_zeros([1, 2]), args=(1,), kwargs={'dtype': dtype}) 6478 6479def sample_inputs_view_as_complex(op_info, device, dtype, requires_grad, **kwargs): 6480 yield SampleInput(make_tensor((S, 2), dtype=dtype, device=device, requires_grad=requires_grad)) 6481 6482def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): 6483 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6484 sizes = ((S, S), ()) 6485 return (SampleInput(make_arg(size)) for size in sizes) 6486 6487def error_inputs_complex(op_info, device, is_ref=False, **kwargs): 6488 make_arg = partial(make_tensor, dtype=torch.float32, device=device) 6489 6490 if is_ref: 6491 error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" 6492 error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" 6493 error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" 6494 else: 6495 error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" 6496 error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument" 6497 error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'" 6498 6499 yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), 6500 error_type=RuntimeError, error_regex=error_float) 6501 6502 yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)), 6503 error_type=RuntimeError, error_regex=error_dtype) 6504 6505 yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64), 6506 out=make_arg(M, S, dtype=torch.complex64)), 6507 error_type=RuntimeError, error_regex=error_out) 6508 6509def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs): 6510 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6511 shape = (S, S) 6512 yield SampleInput(make_arg(shape), make_arg(shape)) 6513 6514def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): 6515 def make_arg(shape): 6516 # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck 6517 return make_tensor(shape, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) 6518 6519 def prod_single_zero(): 6520 result = make_arg(2 * (S,)) 6521 result[0, 1] = 0 6522 return result 6523 6524 for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): 6525 # only Tensor, ignore other inputs 6526 yield SampleInput(sample.input.clone().requires_grad_(requires_grad)) 6527 yield sample 6528 6529 # Generates samples with keepdim = True 6530 for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad): 6531 sample.kwargs['keepdim'] = True 6532 yield sample 6533 6534 yield SampleInput(prod_single_zero()) 6535 yield SampleInput(make_arg((3, 3, 3)), args=(1,)) 6536 yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True}) 6537 6538 yield SampleInput(make_arg((3, 0)), args=(1,)) 6539 yield 
SampleInput(make_arg((3, 0)), args=(1,), kwargs={'keepdim': True}) 6540 yield SampleInput(torch.tensor([2., 3, 0, 0], dtype=dtype, device=device, requires_grad=requires_grad)) 6541 6542 # test zero scalar tensor 6543 zero = make_arg(()) 6544 zero.zero_() 6545 yield SampleInput(zero.clone().requires_grad_(requires_grad)) 6546 yield SampleInput(zero.clone().requires_grad_(requires_grad), args=(0,)) 6547 yield SampleInput(zero.clone().requires_grad_(requires_grad), 6548 args=(0,), 6549 kwargs={'keepdim': True}) 6550 6551def error_inputs_neg(op_info, device, **kwargs): 6552 si = SampleInput(torch.tensor((False, True), device=device)) 6553 msg = ("Negation, the `\\-` operator, on a bool tensor is not supported." 6554 " If you are trying to invert a mask, use the `\\~` or" 6555 " `logical_not\\(\\)` operator instead.") 6556 yield ErrorInput(si, error_regex=msg) 6557 6558def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): 6559 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) 6560 yield SampleInput(make_arg(M)) 6561 6562 tensors = ( 6563 make_arg((M, M)), 6564 make_arg((3, 5)), 6565 make_arg((5, 3)), 6566 ) 6567 6568 args = ((), (2,), (-2,), (1,), (2,)) 6569 6570 for tensor, arg in product(tensors, args): 6571 yield SampleInput(tensor.clone().requires_grad_(requires_grad), *arg) 6572 6573def reference_inputs_diagonal_diag_embed(op_info, device, dtype, requires_grad, **kwargs): 6574 yield from sample_inputs_diagonal_diag_embed( 6575 op_info, device, dtype, requires_grad, **kwargs) 6576 6577 make_arg = partial( 6578 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6579 6580 shapes1d = ((0,), (1,)) 6581 shapes2d = ((L, M),) 6582 shapes3d = ((L, M, S),) 6583 6584 kwargs1d = {} 6585 6586 kwargs2d = ( 6587 # dim1 > dim2 is allowed 6588 dict(dim1=1, dim2=0), 6589 # negative dims are allowed 6590 dict(dim1=-2, dim2=-1), 6591 # one dim negative and the other nonnegative is allowed 6592 dict(dim1=-1, dim2=0), 6593 # out of bounds offset should return an empty tensor in diagonal and 6594 # offset the diagonal in diag_embed 6595 dict(offset=100), 6596 ) 6597 6598 kwargs3d = kwargs2d + ( 6599 # make sure we can use non-sequential dims 6600 dict(offset=-1, dim1=0, dim2=2), 6601 ) 6602 6603 samples1d = product(shapes1d, kwargs1d) 6604 samples2d = product(shapes2d, kwargs2d) 6605 samples3d = product(shapes3d, kwargs3d) 6606 6607 for shape, kwargs in chain(samples1d, samples2d, samples3d): 6608 if 'diagonal' in op_info.name: 6609 # these are error inputs for diagonal 6610 if shape in ((0,), (1,)): 6611 continue 6612 yield SampleInput(input=make_arg(shape), kwargs=kwargs) 6613 6614 6615def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs): 6616 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 6617 6618 # Shapes for 2D Tensors 6619 shapes_2d = ((M, M), (3, 5), (5, 3)) 6620 6621 # Shapes for 3D Tensors 6622 shapes_3d = ((M, M, M),) 6623 6624 args_2d = ((), (2,), (-2,), (1,)) 6625 args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1)) 6626 6627 for input_shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)): 6628 input_ = make_arg(input_shape) 6629 # We can programmatically figure out the right shape for src: 6630 # It should be the same size as input.diagonal(other_args...) 
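        # For example, with input_shape == (M, M, M) and arg == (1, 1, 2),
        # input.diagonal(1, 1, 2) has shape (M, M - 1), so that is the shape used for src.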
6631 if not isinstance(arg, tuple): 6632 arg_tuple = (arg,) 6633 else: 6634 arg_tuple = arg 6635 src_shape = input_.diagonal(*arg_tuple).size() 6636 src = make_arg(src_shape) 6637 yield SampleInput(input_, args=(src, *arg_tuple)) 6638 6639 6640def sample_inputs_to_sparse(op_info, device, dtype, requires_grad, **kwargs): 6641 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6642 6643 yield SampleInput(make_arg((S, S))).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) 6644 yield SampleInput(make_arg((S, S)), 1).with_metadata(output_process_fn_grad=lambda x: x.to_dense()) 6645 6646def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs): 6647 batch_size, num_classes = shape = (2, 3) 6648 reductions = ("mean", "sum", "none") 6649 6650 input_shape_and_kwargs: List[Tuple[Tuple[int, ...], Dict[str, Any]]] = [ 6651 (shape, {}), 6652 ((*shape, 1), {}), 6653 ((*shape, 1, 2), {}), 6654 ((*shape, 1, 2, 3), {}), 6655 *[(shape, dict(reduction=reduction)) for reduction in reductions], 6656 *[ 6657 ( 6658 shape, 6659 dict( 6660 weight=make_tensor((num_classes,), device=device, dtype=dtype), 6661 reduction=reduction, 6662 ), 6663 ) 6664 for reduction in reductions 6665 ], 6666 (shape, dict(ignore_index=1)), 6667 ] 6668 6669 for (input_shape, kwargs), probabilities_target in itertools.product(input_shape_and_kwargs, (False, True)): 6670 input = make_tensor(input_shape, device=device, dtype=dtype, requires_grad=requires_grad) 6671 6672 if probabilities_target: 6673 # ignore_index is not supported for probabilities target 6674 if "ignore_index" in kwargs: 6675 continue 6676 6677 target = make_tensor( 6678 input_shape, 6679 low=0, 6680 high=1, 6681 device=device, 6682 dtype=dtype, 6683 requires_grad=requires_grad, 6684 ) 6685 else: 6686 target = make_tensor( 6687 (batch_size, *input_shape[2:]), 6688 low=0, 6689 high=num_classes, 6690 device=device, 6691 dtype=torch.long, 6692 ) 6693 6694 if "ignore_index" in kwargs and torch.all(target == kwargs["ignore_index"]): 6695 # make sure at least one item in target is not ignored 6696 target[0] = random.sample(sorted(set(range(num_classes)) - {kwargs["ignore_index"]}), 1)[0] 6697 6698 yield SampleInput(input, target, **kwargs) 6699 6700 6701def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs): 6702 low, high = op_info.domain 6703 6704 # Note: Operator is very sensitive at points near the 6705 # start and end of domain and leads to NaN for float16 6706 # if domain_eps is 1e-5. 6707 if dtype.is_floating_point or dtype.is_complex: 6708 domain_eps = op_info._domain_eps if dtype != torch.float16 else 3e-2 6709 6710 low = low + domain_eps 6711 high = high - domain_eps 6712 6713 make_arg = partial(make_tensor, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) 6714 6715 yield SampleInput(make_arg((S, S, S))) 6716 yield SampleInput(make_arg((S, S, S)), 0.2) 6717 yield SampleInput(make_arg(())) 6718 yield SampleInput(make_arg(()), 0.2) 6719 6720def sample_inputs_isin(op_info, device, dtype, requires_grad, **kwargs): 6721 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6722 # isin has two paths based on the size of elements and test_elements. 
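    # (Both branches are exercised below: the first sample makes `elements` the larger
    # tensor and `test_elements` the smaller one, and the second sample reverses that.)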
6723 # if elements.numel() < 10 * pow(test_elements.numel(), 0.145): 6724 yield SampleInput(make_arg((L,)), args=(make_arg((S,)),)) 6725 # else: 6726 yield SampleInput(make_arg((S,)), args=(make_arg((L,)),)) 6727 6728def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs): 6729 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6730 6731 yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg((S, S)))) 6732 yield SampleInput(make_arg((S, S)), args=(torch.randn((S,), device=device) > 0, make_arg((S, S)))) 6733 yield SampleInput(make_arg((S, S)), args=(bernoulli_scalar().to(device), make_arg((S, S)))) 6734 yield SampleInput(make_arg((S,)), 6735 args=(torch.randn(S, S, device=device) > 0, make_arg((S, S))), 6736 broadcasts_input=True) 6737 6738def error_inputs_masked_scatter(op_info, device, **kwargs): 6739 make_arg = partial(make_tensor, device=device, dtype=torch.float) 6740 for mask_dtype in [torch.float, torch.uint8]: 6741 yield ErrorInput(SampleInput(make_arg(1, 3), args=(torch.ones(1, 3, device=device, dtype=mask_dtype), 6742 make_arg(3, 4))), 6743 error_regex=r"masked_scatter_ only supports boolean masks") 6744 6745def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs): 6746 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6747 6748 yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, 10)) 6749 yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, make_arg(()))) 6750 yield SampleInput(make_arg((S, S)), args=(torch.randn(S, device=device) > 0, 10)) 6751 yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, 10)) 6752 yield SampleInput(make_arg(()), args=(torch.randn((), device=device) > 0, make_arg(()))) 6753 yield SampleInput(make_arg((S, S)), args=(torch.randn((), device=device) > 0, 10)) 6754 6755 yield SampleInput(make_arg((S,)), 6756 args=(torch.randn(S, S, device=device) > 0, make_arg(())), 6757 broadcasts_input=True) 6758 yield SampleInput(make_arg((S,)), 6759 args=(torch.randn(S, S, device=device) > 0, 10), 6760 broadcasts_input=True) 6761 6762 if torch.device(device).type == 'cuda': 6763 # `self` and `mask` on CUDA but `value` is a CPU scalar tensor. 6764 yield SampleInput(make_arg((S, S)), 6765 args=(torch.randn(S, S, device=device) > 0, 6766 make_tensor((), device="cpu", dtype=dtype))) 6767 6768def error_inputs_masked_fill(op_info, device, **kwargs): 6769 make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) 6770 # `value` is not a 0-D tensor. 6771 yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, make_arg((1,)))), 6772 error_regex="only supports a 0-dimensional value tensor, but got tensor with 1 dimension") 6773 # downcasting complex value (scalar overload) 6774 yield ErrorInput(SampleInput(make_arg((2, 2)), args=(make_arg(()) > 0, 1j)), 6775 error_regex=r"value cannot be converted to type .* without overflow") 6776 # downcasting complex value (tensor overload) 6777 yield ErrorInput(SampleInput(torch.ones(2, dtype=torch.long, device=device), 6778 args=(make_arg(()) > 0, torch.tensor(1j, device=device))), 6779 error_regex=r"value cannot be converted to type .* without overflow") 6780 6781 if torch.device(device).type == 'cuda': 6782 # `self` and `mask` on CPU but `value` is a CUDA scalar tensor. 
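        # Note the asymmetry with sample_inputs_masked_fill above: a CPU scalar `value`
        # paired with CUDA `self` is a valid sample, whereas a CUDA `value` paired with
        # CPU `self` is expected to fail with the device-mismatch error below.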
6783 yield ErrorInput(SampleInput(torch.randn((S, S), device='cpu'), 6784 args=(torch.randn(S, S, device='cpu') > 0, 6785 torch.randn((), device='cuda'))), 6786 error_regex=r"to be on same device") 6787 6788 6789def sample_inputs_masked_select(op_info, device, dtype, requires_grad, **kwargs): 6790 make_arg = partial( 6791 make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) 6792 6793 yield SampleInput(make_arg((M, M)), torch.randn(M, M, device=device) > 0) 6794 6795 yield SampleInput(make_arg((M, M)), torch.randn((M,), device=device) > 0) 6796 yield SampleInput(make_arg((M,)), torch.randn((M, M), device=device) > 0) 6797 6798 yield SampleInput(make_arg((M, 1, M)), torch.randn((M, M), device=device) > 0) 6799 6800 yield SampleInput(make_arg(()), torch.tensor(1, device=device, dtype=torch.bool)) 6801 6802 yield SampleInput(make_arg((M, M)), torch.tensor(1, device=device, dtype=torch.bool)) 6803 6804 yield SampleInput(make_arg(()), torch.randn((M, M), device=device) > 0) 6805 6806def sample_inputs_matrix_exp(op_info, device, dtype, requires_grad, **kwargs): 6807 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6808 yield SampleInput(make_arg((S, S))) 6809 yield SampleInput(make_arg((S, S, S))) 6810 6811def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False, **kwargs): 6812 make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, 6813 high=None, requires_grad=requires_grad) 6814 test_cases = (((L,), (L,)), 6815 ((S, M), (M,)), 6816 ((M,), (M, S)), 6817 ((S, M), (M, S)), 6818 ((S, 0), (0, M)), 6819 ((S, S, M), (M,)), 6820 ((S, S, M), (M, S)), 6821 ((S, S, 0), (0, S)), 6822 ((M,), (S, M, S)), 6823 ((S, M), (S, M, S)), 6824 ((0, 0), (S, 0, 0)), 6825 ((S, S, M, M), (S, S, M, S)), 6826 ((S, S, M, M), (M,)), 6827 ((M,), (S, S, M, S)), 6828 ((S, S, S), (1, S, S)) 6829 ) 6830 for lhs_shape, rhs_shape in test_cases: 6831 lhs = make_arg(lhs_shape) 6832 rhs = make_arg(rhs_shape) 6833 if not is_rmatmul: 6834 yield SampleInput(lhs, rhs) 6835 else: 6836 yield SampleInput(rhs, lhs) 6837 6838 6839def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype, 6840 requires_grad: bool, 6841 *, variant: str, **kwargs) -> List[SampleInput]: 6842 if variant == 'variadic': 6843 def make_inputs( 6844 tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor, 6845 List[torch.Tensor]], 6846 Tuple[torch.Tensor, ...]]: 6847 return tensors 6848 elif variant == 'list': 6849 def make_inputs( 6850 tensors: List[torch.Tensor]) -> Tuple[Union[torch.Tensor, 6851 List[torch.Tensor]], 6852 Tuple[torch.Tensor, ...]]: 6853 return [tensors] 6854 else: 6855 raise ValueError( 6856 'Unsupported variant, must be one of {"variadic", "list"}. 
' 6857 f'Got "{variant}".') 6858 6859 SCALAR = torch.Size([]) 6860 VECTOR = torch.Size([3]) 6861 test_cases: List[List[torch.Size]] = [ 6862 [SCALAR], 6863 [VECTOR], 6864 [VECTOR, SCALAR], 6865 [VECTOR, SCALAR, VECTOR], 6866 [VECTOR, SCALAR, VECTOR, SCALAR], 6867 ] 6868 6869 for shapes, indexing in itertools.product(test_cases, {'xy', 'ij'}): 6870 args = make_inputs( 6871 [make_tensor(shape, dtype=dtype, device=device, requires_grad=requires_grad) 6872 for shape in shapes]) 6873 yield SampleInput(*args, indexing=indexing) 6874 6875 6876def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs): 6877 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6878 tensor_shapes = ((S, S), ()) 6879 ns = (1, 2, 3, 4, 5) 6880 6881 # Since the accepted lower bound for input 6882 # to mvlgamma depends on `p` argument, 6883 # the following function computes the lower bound 6884 # which we pass to `make_tensor`. 6885 def compute_min_val(p): 6886 return (p - 1.) / 2 6887 6888 for shape, n in product(tensor_shapes, ns): 6889 min_val = compute_min_val(n) 6890 if not dtype.is_floating_point: 6891 # Round-up minimum value for integral dtypes 6892 min_val += 1 6893 else: 6894 min_val += 2 * torch.finfo(dtype).eps 6895 yield SampleInput(make_arg(shape, low=min_val), args=(n,)) 6896 6897 6898# Since `mvlgamma` has multiple entries, 6899# there are multiple common skips for the additional 6900# entries. Following function is a helper to that end. 6901def skips_mvlgamma(skip_redundant=False): 6902 skips = ( 6903 # outside domain values are hard error for mvlgamma op. 6904 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_float_domains'), 6905 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 6906 'test_reference_numerics_extremal'), 6907 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 6908 'test_reference_numerics_large', 6909 dtypes=(torch.float16, torch.int8)), 6910 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 6911 'test_reference_numerics_small', 6912 dtypes=(torch.int8,)), 6913 ) 6914 if skip_redundant: 6915 # Redundant tests 6916 skips = skips + ( # type: ignore[assignment] 6917 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), 6918 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), 6919 DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), 6920 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), 6921 ) 6922 return skips 6923 6924 6925# To test reference numerics against multiple values of argument `p`, 6926# we make multiple OpInfo entries with each entry corresponding to different value of p. 6927# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing. 
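# A rough sketch of how the helper below is typically invoked when building the op
# database (illustrative only; the exact variant names, domains, and skip sets used in
# this file may differ):
#
#     make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
#                          domain=(1, None),
#                          skips=skips_mvlgamma(skip_redundant=False),
#                          sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1}))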
6928def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs): 6929 return UnaryUfuncInfo('mvlgamma', 6930 ref=reference_mvlgamma if TEST_SCIPY else None, 6931 aliases=('special.multigammaln',), 6932 variant_test_name=variant_test_name, 6933 domain=domain, 6934 decorators=(precisionOverride({torch.float16: 5e-2}),), 6935 dtypes=all_types_and(torch.half, torch.bfloat16), 6936 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 6937 sample_inputs_func=sample_inputs_mvlgamma, 6938 supports_forward_ad=True, 6939 supports_fwgrad_bwgrad=True, 6940 promotes_int_to_float=True, 6941 skips=skips, 6942 sample_kwargs=sample_kwargs) 6943 6944 6945def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs): 6946 def _make_tensor_helper(shape, low=None, high=None): 6947 return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) 6948 6949 yield SampleInput(_make_tensor_helper((S, S, S)), 0) 6950 yield SampleInput(_make_tensor_helper((S, S, S)), 1) 6951 yield SampleInput(_make_tensor_helper(()), 0) 6952 6953 if supports_dtype_kwargs: 6954 # NOTE: if `dtype` is not same as input, then inplace variants fail with 6955 # `provided dtype must match the dtype of self tensor in cumsum` 6956 yield SampleInput(_make_tensor_helper((S, S, S)), 1, dtype=dtype) 6957 6958 6959def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs): 6960 test_cases = ( 6961 ((), (0, 1, 1)), 6962 ((S, S, S, S), (0, 3, 1)), 6963 ((S, S, S, S), (1, 3, 1)), 6964 ((S, S, S, S), (2, 3, 1)), 6965 ((S, S, S, S), (3, 3, 1)), 6966 ((S, S, S, S), (0, 3, 2)), 6967 ((S, S, S, S), (1, 3, 2)), 6968 ((S, S, S, S), (2, 3, 2)), 6969 ((S, S, S, S), (3, 3, 2)), 6970 ((S, S, S, S), (0, 4, 1)), 6971 ((S, S, S, S), (1, 4, 1)), 6972 ((S, S, S, S), (2, 4, 1)), 6973 ((S, S, S, S), (3, 4, 1)), 6974 ((M,), (0, 3, 1)), 6975 ((M,), (0, 3, 2)), 6976 ((M,), (0, 3, 3)), 6977 ((1000,), (0, 3, 11)), 6978 ((1000,), (0, 2, 27)), 6979 ((10, 10), (0, 1, 2)), 6980 ((10, 10), (1, 2, 3)), 6981 ((10, 10), (1, 2, 2)), 6982 ((S, S, S), (2, 3, 2)), 6983 ) 6984 6985 for shape, arguments in test_cases: 6986 yield SampleInput(make_tensor(shape, dtype=dtype, device=device, 6987 low=None, high=None, 6988 requires_grad=requires_grad), 6989 *arguments) 6990 6991def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=False, **kwargs): 6992 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 6993 6994 if list_args: 6995 cases = ( 6996 ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), 6997 ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), 6998 ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) 6999 ) 7000 else: 7001 cases = ( # type: ignore[assignment] 7002 ((S, S, S), (2,)), 7003 ((S, S, S), (S, 1)), 7004 ) 7005 7006 for shape, args in cases: 7007 yield SampleInput(make_arg(shape), args=args) 7008 7009 7010def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): 7011 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 7012 7013 cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), 7014 ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), 7015 ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), 7016 ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), 7017 ) 7018 7019 for shape, 
args in cases: 7020 yield SampleInput(make_arg(shape), args=args) 7021 7022 7023def sample_inputs_msort(op_info, device, dtype, requires_grad, **kwargs): 7024 def apply_grad(t): 7025 if dtype in floating_types_and(torch.float16, torch.bfloat16): 7026 t.requires_grad_(requires_grad) 7027 7028 def large_1d_unique(dtype, device): 7029 res = torch.randperm(L * L * L, dtype=torch.int64, device=device) 7030 res = res.to(dtype) 7031 apply_grad(res) 7032 return res 7033 7034 # Test case for large tensor. 7035 yield SampleInput(large_1d_unique(dtype, device)) 7036 7037 yield SampleInput(make_tensor((S, M, S), dtype=dtype, device=device, 7038 low=None, high=None, 7039 requires_grad=requires_grad)) 7040 7041def sample_inputs_lerp(op_info, device, dtype, requires_grad, **kwargs): 7042 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7043 7044 # no broadcast 7045 yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4) 7046 # broadcast rhs 7047 yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4) 7048 # scalar tensor 7049 yield SampleInput(make_arg(()), make_arg(()), 0.4) 7050 # broadcast rhs scalar-tensor 7051 yield SampleInput(make_arg((S, S)), make_arg(()), 0.4) 7052 # broadcast rhs with weight tensor 7053 yield SampleInput(make_arg((S, S)), make_arg((S,)), make_arg((S, S))) 7054 # broadcast rhs and weight tensor 7055 yield SampleInput(make_arg((S, S)), make_arg((S, 1)), make_arg((S,))) 7056 # broadcast lhs 7057 yield SampleInput(make_arg((S,)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) 7058 # scalar broadcast_lhs 7059 yield SampleInput(make_arg(()), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) 7060 # broadcast all 7061 yield SampleInput(make_arg((S, 1)), make_arg((S, S)), 0.4).with_metadata(broadcasts_input=True) 7062 # tensor broadcast all 7063 yield SampleInput(make_arg((S, 1)), make_arg((S, S)), make_arg((S, 1))).with_metadata( 7064 broadcasts_input=True) 7065 # no broadcast with weight tensor 7066 yield SampleInput(make_arg((S, S)), make_arg((S, S)), make_arg((S, S))) 7067 # broadcast lhs with weight tensor 7068 yield SampleInput(make_arg((S,)), make_arg((S, S)), make_arg((S, S))).with_metadata( 7069 broadcasts_input=True) 7070 # broadcast lhs and weight tensor 7071 yield SampleInput(make_arg((S,)), make_arg((S, S, S)), make_arg((S, S))).with_metadata( 7072 broadcasts_input=True) 7073 # broadcast lhs and weight tensor variant 7074 yield SampleInput(make_arg((S, S)), make_arg((S, S, S)), make_arg((S,))).with_metadata( 7075 broadcasts_input=True) 7076 7077 if dtype.is_complex: 7078 # no broadcast 7079 yield SampleInput(make_arg((S, S)), make_arg((S, S)), 0.4j) 7080 yield SampleInput(make_arg((S, S)), make_arg((S, S)), 1.2 + 0.1j) 7081 # broadcast rhs 7082 yield SampleInput(make_arg((S, S)), make_arg((S,)), 0.4j) 7083 yield SampleInput(make_arg((S, S)), make_arg((S, S)), 5.4 + 9j) 7084 # scalar tensor 7085 yield SampleInput(make_arg(()), make_arg(()), 0.4j) 7086 yield SampleInput(make_arg(()), make_arg(()), 6.1 + 0.004j) 7087 # broadcast rhs scalar-tensor 7088 yield SampleInput(make_arg((S, S)), make_arg(()), 0.4j) 7089 yield SampleInput(make_arg((S, S)), make_arg(()), 1 + 2j) 7090 7091def sample_inputs_tensordot(self, device, dtype, requires_grad, **kwargs): 7092 cases = ( 7093 ((2, 2, 2), (2, 2, 2), (2)), 7094 ((2, 2, 1), (2, 1, 2), ([0, 1], [2, 0])), 7095 ) 7096 for first_shape, second_shape, dims in cases: 7097 yield SampleInput(make_tensor(first_shape, dtype=dtype, device=device, 7098 requires_grad=requires_grad), 
7099 make_tensor(second_shape, dtype=dtype, device=device, 7100 requires_grad=requires_grad), 7101 dims=dims) 7102 7103def sample_inputs_kron(op_info, device, dtype, requires_grad, **kwargs): 7104 make_arg = partial( 7105 make_tensor, dtype=dtype, device=device, requires_grad=requires_grad, low=None, high=None) 7106 test_cases = ( 7107 ((S, S), (M, L)), 7108 ) 7109 7110 for input_shape, other_shape in test_cases: 7111 input = make_arg(input_shape) 7112 other = make_arg(other_shape) 7113 yield SampleInput(input, other) 7114 7115def sample_inputs_inner(self, device, dtype, requires_grad, **kwargs): 7116 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7117 yield SampleInput(make_arg(S), make_arg(S)) 7118 yield SampleInput(make_arg(), make_arg(S, S)) 7119 7120def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs): 7121 def _tensor(shape, dtype=dtype, low=None, high=None): 7122 return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) 7123 7124 def _gather(shape, index_dim, max_indices): 7125 return gather_variable(shape, index_dim, max_indices, device=device) 7126 7127 zero = torch.tensor(0, dtype=torch.long, device=device) 7128 test_cases = ( 7129 (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))), 7130 (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))), 7131 (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))), 7132 (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))), 7133 (_tensor((M, S)), (1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), 7134 (_tensor((M, S)), (-1, _gather((M, S // 2), 0, S), _tensor((M, S // 2)))), 7135 (_tensor(()), (0, zero.clone().detach(), _tensor(()))), 7136 (_tensor(()), (0, zero.clone().detach(), 2.5)), 7137 ) 7138 7139 for tensor, args in test_cases: 7140 yield SampleInput(tensor, *args) 7141 7142 if not requires_grad: 7143 yield SampleInput(tensor.clone().detach(), *args, reduce='add') 7144 7145 if dtype.is_floating_point: 7146 yield SampleInput(tensor.clone().detach(), *args, reduce='multiply') 7147 7148def sample_inputs_scatter_add(op_info, device, dtype, requires_grad, **kwargs): 7149 def _tensor(shape, dtype=dtype, low=None, high=None): 7150 return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) 7151 7152 def _gather(shape, index_dim, max_indices): 7153 return gather_variable(shape, index_dim, max_indices, device=device) 7154 7155 zero = torch.tensor(0, dtype=torch.long, device=device) 7156 yield SampleInput(_tensor((M, S)), 0, _gather((S, S), 1, M), _tensor((S, S))) 7157 yield SampleInput(_tensor((M, S)), 1, _gather((S, S), 0, S), _tensor((S, S))) 7158 yield SampleInput(_tensor((M, S)), -1, _gather((S, S), 0, S), _tensor((S, S))) 7159 yield SampleInput(_tensor((M, S)), 0, _gather((M, S // 2), 1, M), _tensor((M, S // 2))) 7160 yield SampleInput(_tensor((M, S)), 1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) 7161 yield SampleInput(_tensor((M, S)), -1, _gather((M, S // 2), 0, S), _tensor((M, S // 2))) 7162 yield SampleInput(_tensor(()), 0, zero.clone().detach(), _tensor(())) 7163 7164def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs): 7165 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7166 gather = partial(gather_variable, device=device) 7167 7168 zero = torch.tensor(0, dtype=torch.long, device=device) 7169 test_cases = ( 7170 ((M, S), 0, gather((S, S), 1, M), (S, S)), 
        ((M, S), 1, gather((S, S), 0, S), (S, S)),
        ((M, S), -1, gather((S, S), 0, S), (S, S)),
        ((M, S), 0, gather((M, S // 2), 1, M), (M, S // 2)),
        ((M, S), 1, gather((M, S // 2), 0, S), (M, S // 2)),
        ((M, S), -1, gather((M, S // 2), 0, S), (M, S // 2)),
        ((), 0, zero.clone().detach(), ()),
    )

    reduce = op_info.variant_test_name
    for (inp_shape, dim, index, src_shape), include_self in product(test_cases, [False, True, False]):
        yield SampleInput(make_arg(inp_shape),
                          args=(dim, index, make_arg(src_shape), reduce),
                          kwargs={'include_self': include_self})


    # Sample inputs to test edge cases for backward
    # Check that gradients are propagated correctly for prod when zeros in self/src are reduced
    if requires_grad and reduce == 'prod':
        # This sample tests gradients for the following cases
        # (a) 1 zero reduced (from src (self[0, 1], self[1, 1]), from self (self[0, 0], self[2, 0]))
        # (b) 2 zeros reduced (1 from src and 1 from self (self[1, 0]))
        # (c) no zeros reduced (self[2, 1])
        # (d) 2 zeros reduced (both from src) is tested in test/test_autograd.py
        #     test_scatter_index_reduce_prod_gradgrad_error as this case is not supported for gradgrad
        input = torch.tensor([[0, 13], [0, 17], [0, 19]], dtype=dtype, device=device, requires_grad=requires_grad)
        src = torch.tensor([[0, 1, 2, 3], [0, 4, 0, 1], [2, 3, 5, 6]], dtype=dtype, device=device, requires_grad=requires_grad)
        idx = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.long, device=device)

        yield SampleInput(input,
                          args=(1, idx, src, reduce),
                          kwargs={'include_self': True})

def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
    def _tensor(shape, dtype=dtype, low=None, high=None):
        return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)

    zero = torch.tensor(0, dtype=torch.long, device=device)
    test_cases = (
        # inp_shape, dim, lengths, unsafe
        ((S,), 0, [0, 1, 2, 2], False),
        ((S,), 0, [0, 1, 2, 2], True),
        ((S,), 0, [2, 0, 3, 0], False),
        ((S, S), 0, [0, 1, 2, 2], False),
        # test when lengths do not sum to dim size
        ((M, S, S), 0, [1, 2, 0, 6, 0], True),
        # test for higher dimensions
        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
    )

    reductions = ["max", "mean", "min", "sum", "prod"]
    for args, reduce, initial in product(test_cases, reductions, [1, 2]):
        inp_shape, dim, lengths, unsafe = args
        lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
        sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
        if mode == 'lengths':
            sample_input_kwargs['lengths'] = lengths_t
        elif mode == 'offsets':
            zeros_shape = list(lengths_t.shape)
            zeros_shape[dim] = 1
            offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
            sample_input_kwargs['offsets'] = offsets_t
        else:
            raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.")
        yield SampleInput(_tensor(inp_shape),
                          args=(reduce,),
                          kwargs=sample_input_kwargs)


def
sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): 7243 make_arg = partial(make_tensor, dtype=dtype, device=device, 7244 low=None, high=None, requires_grad=requires_grad) 7245 yield SampleInput(make_arg((S, S, S))) 7246 yield SampleInput(make_arg(())) 7247 yield SampleInput(make_arg((S, S, S), noncontiguous=True)) 7248 7249def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs): 7250 make_arg = partial(make_tensor, dtype=dtype, device=device, 7251 low=None, high=None, requires_grad=requires_grad) 7252 yield SampleInput( 7253 torch.tensor( 7254 [[3, 8, 13], [0, 5, 10]], 7255 device=device, 7256 dtype=dtype), 7257 (4, 5)) 7258 yield SampleInput( 7259 torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), 7260 (4, 2**30)) 7261 yield SampleInput( 7262 torch.tensor([[3, 8, 13], [0, 5, 10]], device=device, dtype=dtype), 7263 (2**30, 4)) 7264 yield SampleInput( 7265 torch.tensor(2, device=device, dtype=dtype), 7266 (2, 2)) 7267 max_val = 2**(8 * dtype.itemsize - (1 if dtype.is_signed else 0)) - 1 7268 yield SampleInput( 7269 torch.tensor(max_val - 1, device=device, dtype=dtype), 7270 (1, max_val)) 7271 yield SampleInput( 7272 torch.tensor([22, 41, 37], device=device, dtype=dtype), 7273 (7, 6)) 7274 yield SampleInput( 7275 torch.tensor(min(1621, max_val), device=device, dtype=dtype), 7276 (6, 7, 8, 9)) 7277 yield SampleInput( 7278 torch.tensor([], device=device, dtype=dtype), 7279 (10, 3, 5)) 7280 yield SampleInput( 7281 torch.tensor( 7282 [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], 7283 device=device, 7284 dtype=dtype), 7285 (5, 8)) 7286 yield SampleInput( 7287 torch.tensor( 7288 [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]], 7289 device=device, 7290 dtype=dtype), 7291 (5, 8, 10)) 7292 yield SampleInput( 7293 torch.tensor(0, device=device, dtype=dtype), 7294 ()) 7295 7296 a = np.array([[2, 4, 5, 6], [7, 8, 1, 15]]) 7297 b = np.array([[3, 2, 7, 6], [10, 12, 8, 9]]) 7298 _, i1, i2 = np.intersect1d(a, b, assume_unique=True, return_indices=True) 7299 yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) 7300 yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) 7301 7302 a = np.array([[2, 4, 5, 6, 6], [4, 7, 8, 7, 2]]) 7303 b = np.array([[3, 2, 7, 7], [10, 12, 8, 7]]) 7304 _, i1, i2 = np.intersect1d(a, b, return_indices=True) 7305 yield SampleInput(torch.tensor(i1, device=device, dtype=dtype), a.shape) 7306 yield SampleInput(torch.tensor(i2, device=device, dtype=dtype), b.shape) 7307 7308 7309def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): 7310 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7311 cases = (((M, M), ()), 7312 ((M, M), (2,),), 7313 ((M, S), ()), 7314 ((M, S), (-1,)), 7315 ((M, M), (2,),), 7316 ((S, M, S), ()), 7317 ((S, M, S), (2,)), 7318 ((3, 3, S, S), ()),) 7319 7320 for shape, args in cases: 7321 yield SampleInput(make_arg(shape), args=args) 7322 7323def error_inputs_tril_triu(opinfo, device, **kwargs): 7324 make_arg = partial(make_tensor, device=device, dtype=torch.float32) 7325 7326 # error inputs for input.ndim <= 2 7327 yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions") 7328 7329def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs): 7330 # (row, col, offset) 7331 args_list = ((0, 0), 7332 (20, 0), 7333 (0, 20), 7334 (20, 21, 0), 7335 (20, 21, 7), 7336 (20, 21, -7), 7337 # Large test cases below are deliberately commented 
out to speed up CI 7338 # tests and to avoid OOM error. When modifying implementations of 7339 # tril_indices and triu_indices, please enable these tests and make sure 7340 # they pass. 7341 # (2, 68435455, 3), 7342 # (5000, 5000), 7343 # (5000, 5000, 1234), 7344 # (5000, 5000, -1233), 7345 ) 7346 for args in args_list: 7347 yield SampleInput(args[0], args=args[1:], kwargs={"dtype": dtype, "device": device}) 7348 7349def sample_inputs_clone_contiguous(op_info, device, dtype, requires_grad, **kwargs): 7350 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7351 7352 yield SampleInput(make_arg((S, M, S))) 7353 yield SampleInput(make_arg(())) 7354 7355def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs): 7356 # NOTE: the default memory format for clone is torch.preserve_format, for contiguous it's torch.contiguous_format 7357 # This exploits that default to test torch.preserve_format for clone, without causing an error when testing contiguous 7358 yield from sample_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs) 7359 7360 shapes = ( 7361 (3, 5, 6), 7362 (1, 1, 3, 5, 6), 7363 (1, 1, 3, 5, 6, 1, 1), 7364 (1, 0, 3, 5, 0, 2), 7365 (1, 0, 3, 5, 0, 0, 1, 1, 2), 7366 (), 7367 ) 7368 7369 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7370 for shape in shapes: 7371 yield SampleInput(make_arg(shape)) 7372 yield SampleInput(make_arg(shape).transpose(0, -1)) 7373 yield SampleInput(make_arg(shape, noncontiguous=True)) 7374 yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1)) 7375 7376 yield SampleInput(make_arg(shape), kwargs={'memory_format': torch.contiguous_format}) 7377 yield SampleInput(make_arg(shape).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) 7378 yield SampleInput(make_arg(shape, noncontiguous=True), kwargs={'memory_format': torch.contiguous_format}) 7379 yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), kwargs={'memory_format': torch.contiguous_format}) 7380 7381 # shape, strides, offset 7382 strided_cases = ( 7383 ((5, 6, 2), (1, 1, 7), 2), 7384 ((5, 5, 4), (1, 1, 7), 2), 7385 ((5, 5, 2), (4, 5, 7), 3), 7386 ((5, 5, 2), (5, 5, 7), 3), 7387 ((5, 5, 2), (5, 5, 5), 3), 7388 ((9, 5, 2), (0, 1, 7), 3), 7389 ) 7390 7391 for shape, strides, offset in strided_cases: 7392 yield SampleInput(make_arg(500,).as_strided(shape, strides, offset)) 7393 yield SampleInput(make_arg(500,).as_strided(shape, strides, offset), kwargs={'memory_format': torch.contiguous_format}) 7394 7395 # channels last 2D 7396 yield SampleInput(make_arg((2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last}) 7397 a = make_arg((2, 2, 2, 2)).permute(0, 3, 1, 2) 7398 yield SampleInput(a, kwargs={'memory_format': torch.channels_last}) 7399 7400 # channels last 3D 7401 yield SampleInput(make_arg((2, 2, 2, 2, 2)), kwargs={'memory_format': torch.channels_last_3d}) 7402 a = make_arg((2, 2, 2, 2, 2)).permute(0, 4, 1, 2, 3) 7403 yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d}) 7404 7405 7406def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs): 7407 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7408 7409 # list of tuples (shape, shape) defining the shapes of the input and output tensors 7410 sample_shapes = [ 7411 ((), ()), 7412 ((S,), (1,)), 7413 ((S, S), (1, 1)), 7414 ((S, S), (1, S)), 7415 ((S, S), (S, S)), 7416 ((S, S, S), (S, 1, S)), 7417 ] 7418 7419 
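    # For reference, `Tensor.sum_to_size(*shape)` reduces the input by summing over
    # dimensions until the result has the requested shape, which must be expandable
    # to the input's shape. A minimal illustrative sketch (shapes chosen for clarity,
    # not taken from `sample_shapes` above):
    #
    #   >>> t = torch.ones(2, 3)
    #   >>> t.sum_to_size(1, 3)        # dim 0 is summed and kept as size 1
    #   tensor([[2., 2., 2.]])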
for input_shape, output_shape in sample_shapes: 7420 yield SampleInput(make_arg(input_shape), args=(output_shape,)) 7421 if output_shape == (): 7422 continue 7423 yield SampleInput(make_arg(input_shape), args=(list(output_shape),)) 7424 yield SampleInput(make_arg(input_shape), args=(*output_shape,)) 7425 7426 7427def error_inputs_sum_to_size(op_info, device, **kwargs): 7428 shape = (M, S, M) 7429 err_msg = "is not expandable to size" 7430 si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M, M)) 7431 yield ErrorInput(si, error_regex=err_msg) 7432 7433 shape = (M + 1, S, S, M) 7434 err_msg = "is not expandable to size" 7435 si = SampleInput(make_tensor(shape, device=device, dtype=torch.float32), args=(M + 1, 1)) 7436 yield ErrorInput(si, error_regex=err_msg) 7437 7438 7439def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): 7440 make_arg = partial(make_tensor, dtype=dtype, device=device) 7441 cases = (((S, S, S), (S * S, S)), 7442 ((), ()), 7443 ((), (1, 1, 1)), 7444 ) 7445 7446 for shape, args_or_shape in cases: 7447 # Update `args` based on operator 7448 if op_info.name == 'resize_': 7449 # resize_ takes shape/tuple of ints, 7450 args = (args_or_shape, ) 7451 elif op_info.name == 'resize_as_': 7452 # resize_as_ takes another tensor 7453 args = (make_arg(shape, requires_grad=False), ) # type:ignore[assignment] 7454 else: 7455 raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") 7456 7457 yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) 7458 7459def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): 7460 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7461 7462 cases = ( 7463 # a, b, is_tensor_supported 7464 ((S, S, S), (S * S, S), True), 7465 ((S * S, S), (S, S, S), True), 7466 ((S * S, S), (S, -1, S), False), # neg index 7467 ((S * S * 2, S), (S, -1), False), # neg index 7468 ((S,), (S,), True), 7469 ((), (), False), # empty 7470 ((), (1,), True), 7471 ) 7472 7473 for a, b, is_tensor_supported in cases: 7474 # skip unsupported cases 7475 if kwargs.get("tensor_arg") and not is_tensor_supported: 7476 continue 7477 7478 # convert to tensor 7479 if kwargs.get("tensor_arg"): 7480 b = make_arg(b, requires_grad=False) 7481 7482 yield SampleInput(make_arg(a), args=(b,)) 7483 7484def reference_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs): 7485 yield from sample_inputs_view_reshape(op, device, dtype, requires_grad, **kwargs) 7486 7487 cases = ( 7488 # a, b, is_tensor_supported 7489 ((125,), (25, 5), True), 7490 ((25, 25), (1, 5, 5, 1, 5, 1, 5, 1), True), 7491 ((16, 32), (2, 4, 1, 4, 4, 1, 4), True), 7492 ((16, 12), (12, 16), True), 7493 ((1, 16, 12), (12, 16), True), 7494 ((1, 5, 1, 5), (25, 1), True), 7495 ((2, 4, 2), (4, 4), True), 7496 ((1, 4), (1, 1, 2, 1, 2), True), 7497 ((3, 5, 7), (7, 5, 3), True), 7498 ((1,), (), False), # empty 7499 ((5, 0, 2, 3), (5, 0, 2, 3), True), 7500 ((2, 1, 0, 3, 1), (5, 0), True), 7501 ((1,), (), False), # empty 7502 ((4, 5, 6), (4, 5, 6, 1, 1, 1), True), 7503 ((), (1, 1, 1, 1), False), # empty 7504 ) 7505 7506 irreversible_cases = ( 7507 ((), (-1,), False), # neg index, empty 7508 ((4, 7, 9, 1, 1), (1, 4, 3, -1, 1), False), # neg index 7509 ) 7510 7511 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7512 for a, b, is_tensor_supported in cases: 7513 # skip unsupported cases 7514 if kwargs.get("tensor_arg") and not 
is_tensor_supported: 7515 continue 7516 7517 if kwargs.get("tensor_arg"): 7518 # convert to tensor 7519 yield SampleInput(make_arg(a), args=(make_arg(b, requires_grad=False),)) 7520 yield SampleInput(make_arg(b), args=(make_arg(a, requires_grad=False),)) 7521 else: 7522 yield SampleInput(make_arg(a), args=(b,)) 7523 yield SampleInput(make_arg(b), args=(a,)) 7524 7525 for a, b, is_tensor_supported in irreversible_cases: 7526 # skip unsupported cases 7527 if kwargs.get("tensor_arg") and not is_tensor_supported: 7528 continue 7529 7530 # convert to tensor 7531 if kwargs.get("tensor_arg"): 7532 b = make_arg(b, requires_grad=False) 7533 7534 yield SampleInput(make_arg(a), args=(b,)) 7535 7536def error_inputs_view_reshape(op, device, **kwargs): 7537 7538 cases = ( 7539 # a, b, is_tensor_supported 7540 # Reshape to different numel 7541 ((2,), (), False), # empty 7542 ((1, 3, 0), (), False), # empty 7543 ((4, 3), (4, 2), True), 7544 ((1, 3, 5), (5, 2, 2), True), 7545 # No valid inference 7546 ((1, 3, 5), (5, -1, 2), False), # neg index 7547 # Two inferred shapes 7548 ((1, 3, 5), (5, -1, -1), False), # neg index 7549 ((1), (0, -1), False), # neg index 7550 ((0, 5), (0, -1), False), # neg index 7551 ) 7552 7553 make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False) 7554 for a, b, is_tensor_supported in cases: 7555 # skip unsupported cases 7556 if kwargs.get("tensor_arg") and not is_tensor_supported: 7557 continue 7558 7559 if b == (5, -1, -1): 7560 error_regex = "only one dimension can be inferred" 7561 elif a == (0, 5): 7562 error_regex = (r"cannot reshape tensor of 0 elements into shape " 7563 r"\[0, -1\] because the unspecified dimension size " 7564 r"-1 can be any value and is ambiguous") 7565 else: 7566 # to avoid having issues with a regex 7567 shape = ', '.join(map(str, b)) 7568 size = a if type(a) is int else functools.reduce(operator.mul, a, 1) 7569 error_regex = rf"shape '\[{shape}\]' is invalid for input of size {size}" 7570 7571 # convert to tensor 7572 if kwargs.get("tensor_arg"): 7573 b = make_arg(b, requires_grad=False) 7574 7575 yield ErrorInput(SampleInput(make_arg(a), args=(b,)), error_type=Exception, 7576 error_regex=error_regex) 7577 7578 7579def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs): 7580 input_list = [] 7581 shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),) 7582 make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7583 for shape in shapes: 7584 yield SampleInput(make_tensor_partial(shape)) 7585 yield SampleInput([make_tensor_partial(shape) for shape in shapes]) 7586 7587def sample_inputs_column_stack(op_info, device, dtype, requires_grad, **kwargs): 7588 cases: Tuple[tuple, tuple] = ( # type: ignore[assignment] 7589 ((S, 2, 1), (S, 3, 1)), 7590 ((S), (S, 5)), ((), (1, S)) 7591 ) 7592 make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7593 for shape1, shape2 in cases: 7594 yield SampleInput([make_tensor_partial(shape1), make_tensor_partial(shape2)]) 7595 7596def sample_inputs_flatten(op_info, device, dtype, requires_grad, **kwargs): 7597 shapes = ((S, S, S), (S, S), (S, ), (),) 7598 make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7599 for shape in shapes: 7600 yield SampleInput(make_tensor_partial(shape)) 7601 if len(shape) > 1: 7602 yield SampleInput(make_tensor_partial(shape), start_dim=1, end_dim=-1) 7603 7604def reference_inputs_flatten(op, device, 
dtype, requires_grad, **kwargs): 7605 yield from sample_inputs_flatten(op, device, dtype, requires_grad, **kwargs) 7606 7607 # shape x start_dim x end_dim 7608 cases = ( 7609 ((5, 4, 0, 1, 3, 7), 1, 3), 7610 ((5, 4, 0, 1, 3, 7), 4, 5), 7611 ((5, 4, 1, 1, 3, 7), 2, 3), 7612 ((), 0, -1), 7613 ((1,), 0, -1), 7614 ((3, 7, 5), 1, 2), 7615 ((4, 5), 1, 1), 7616 ((1, 5, 5, 1, 5, 1, 5, 1), 0, 2), 7617 ((1, 5, 5, 1, 5, 1, 5, 1), 3, -1), 7618 ((1, 5, 5, 1, 5, 7, 5, 1), -2, -1), 7619 ((2, 4, 2), 0, 1), 7620 ((4, 2, 2), 1, 2), 7621 ((0, 3, 4, 5), 1, 3), 7622 ) 7623 7624 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7625 for shape, start, end in cases: 7626 yield SampleInput(make_arg(shape), args=(start, end,)) 7627 yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1), args=(start, end,)) 7628 yield SampleInput(make_arg(shape).transpose(0, -1), args=(start, end,)) 7629 7630def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs): 7631 # in_shape, dim, sizes 7632 args = (((8,), 0, (8,)), 7633 ((8,), 0, (4, 2)), 7634 ((8,), -1, (2, 2, 2)), 7635 ((8,), -1, (-1, 2)), 7636 ((3, 6, 2), 1, (2, 3)), 7637 ((3, 6, 2), -2, (2, 3)), 7638 ((3, 6, 2), -2, (-1, 3)), 7639 ((3, 2, 12), 2, (3, 2, 2)), 7640 ((4, 0), 0, (2, 2)), 7641 ((4, 0), 1, (2, 0, 0, 0)), 7642 ) 7643 make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7644 for in_shape, dim, sizes in args: 7645 yield SampleInput(make_tensor_partial(in_shape), args=(dim, sizes)) 7646 7647 7648def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs): 7649 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7650 7651 cases = (((S, S, S), (1, 2)), 7652 ((S, S, S), (-1, 2)), 7653 ((S, S, S), (-1, -1)), 7654 ((S, S, S), (1, -1)), 7655 ((S,), (0, 2)) 7656 ) 7657 7658 for shape, args in cases: 7659 yield SampleInput(make_arg(shape), args=args) 7660 7661 7662def sample_inputs_select_scatter(op_info, device, dtype, requires_grad, **kwargs): 7663 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7664 7665 cases = (((S, S, S), (S, S), (1, 2)), 7666 ((S, S, S), (S, S), (-1, 2)), 7667 ((S, S, S), (S, S), (-1, -1)), 7668 ((S, S, S), (S, S), (1, -1)), 7669 ((S,), (), (0, 2)) 7670 ) 7671 7672 for input_shape, src_shape, args in cases: 7673 input_ = make_arg(input_shape) 7674 src = make_arg(src_shape) 7675 yield SampleInput(input_, args=(src, *args)) 7676 7677 7678def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs): 7679 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7680 7681 cases = (((L, L, L), (L, L, L,), (0, 0, L, 1)), 7682 ((L, L, L), (L // 2, L, L,), (0, L // 2, L, 1)), 7683 ((L, L, L), (L // 4, L, L,), (0, L // 2, L, 2)), 7684 ((L, L, L), (L, L, L,), (1, 0, L, 1)), 7685 ((L, L, L), (L, L // 2, L,), (1, L // 2, L, 1)), 7686 ((L, L, L), (L, L // 4, L,), (1, L // 2, L, 2)), 7687 ((L, L, L), (L, L, L,), (2, 0, L, 1)), 7688 ((L, L, L), (L, L, L // 2,), (2, L // 2, L, 1)), 7689 ((L, L, L), (L, L, L // 4,), (2, L // 2, L, 2)), 7690 ) 7691 7692 for input_shape, src_shape, args in cases: 7693 input_ = make_arg(input_shape) 7694 src = make_arg(src_shape) 7695 yield SampleInput(input_, args=(src, *args)) 7696 7697def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): 7698 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7699 7700 cases = 
(((S, 1, 1), (S, S, S)), 7701 ((S, 1, S), (S, S, S)), 7702 ((S, 1, S), (-1, S, -1)), 7703 ((S, 1, S), (-1, S, S)), 7704 ((S, 1), (S, S, S)), 7705 ((1,), (S, S, S)), 7706 ((1, S), (1, 1, S)), 7707 ((), ()), 7708 ((), (1, 3, 2)), 7709 ) 7710 7711 for case in cases: 7712 shape, args = case 7713 yield SampleInput(make_arg(shape), args=(args,)) 7714 7715def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): 7716 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7717 7718 shapes = ((), 7719 (2, 3)) 7720 memory_format_options = [None, torch.contiguous_format] 7721 7722 for shape, memory_format in itertools.product(shapes, memory_format_options): 7723 yield SampleInput(make_arg(shape), 7724 kwargs={'memory_format': memory_format} if memory_format else {}) 7725 yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) 7726 7727def sample_inputs_byte(op_info, device, dtype, requires_grad, **kwargs): 7728 make_arg = partial(make_tensor, dtype=dtype, device=device, low=0, high=255, requires_grad=requires_grad) 7729 7730 shapes = ((), 7731 (2, 3)) 7732 memory_format_options = [None, torch.contiguous_format] 7733 7734 for shape, memory_format in itertools.product(shapes, memory_format_options): 7735 yield SampleInput(make_arg(shape), 7736 kwargs={'memory_format': memory_format} if memory_format else {}) 7737 yield SampleInput(make_arg((2, 3, 2, 3)), kwargs={'memory_format': torch.channels_last}) 7738 7739def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): 7740 make_arg = partial(make_tensor, dtype=dtype, device=device) 7741 7742 cases = (((S, 1, 1), (S, S, S)), 7743 ((), ()), 7744 ((), (1, 1)), 7745 ) 7746 7747 for shape, shape_other in cases: 7748 yield SampleInput(make_arg(shape, requires_grad=requires_grad), 7749 args=(make_arg(shape_other, requires_grad=False),)) 7750 7751 7752def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): 7753 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7754 7755 def make_bool_mask(shape): 7756 # Make sure atleast one element is nonzero, 7757 # except for empty tensor 7758 mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) 7759 7760 if mask_t.numel() == 0: 7761 return mask_t 7762 elif mask_t.numel() == 1: 7763 mask_t.fill_(True) 7764 return mask_t 7765 7766 if mask_t.sum() == 0: 7767 def random_index(shape): 7768 return tuple(random.randrange(0, max_idx) for max_idx in shape) 7769 7770 mask_t[random_index(mask_t.shape)] = True 7771 return mask_t 7772 7773 return mask_t 7774 7775 cases = (((M, M), (M, M), (M, M), False), 7776 ((M, 1, M), (M, M), (M, M, 1), True), 7777 ((), (), (), False), 7778 ((M, 1, M), (), (M, M, 1), True), 7779 ((), (M, M), (), True), 7780 ((), (2), (1, 1), True), 7781 ) 7782 7783 for shape, mask_shape, other_shape, broadcasts_input in cases: 7784 yield SampleInput(make_arg(shape), 7785 args=(make_bool_mask(mask_shape), make_arg(other_shape)), 7786 broadcasts_input=broadcasts_input) 7787 7788# TODO: add reference inputs for where(condition) signature 7789def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): 7790 yield from sample_inputs_where(op, device, dtype, requires_grad, **kwargs) 7791 7792 make_cond = partial(make_tensor, dtype=torch.bool, device=device, requires_grad=requires_grad) 7793 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7794 7795 # noncontiguous 7796 c = 
make_cond((10, 3), noncontiguous=True) 7797 a = make_arg((10, 1), noncontiguous=True) 7798 b = make_arg((3, 10, 3)).transpose(0, -1) 7799 7800 # NOTE that the OpInfo for where takes samples of the form a, cond, b 7801 yield SampleInput(a, args=(c, b)) 7802 7803 # type promoting 7804 other_dtype = torch.double if dtype is not torch.double else torch.long 7805 c = make_cond((10, 3), noncontiguous=True) 7806 a = make_arg((10, 1), dtype=torch.long) 7807 b = make_arg((10, 1)) 7808 7809 yield SampleInput(a, args=(c, b)) 7810 7811 # two python scalars 7812 c = make_cond((10, 3), noncontiguous=True) 7813 a = make_arg((1,)).item() 7814 b = make_arg((1,)).item() 7815 7816 yield SampleInput(a, args=(c, b)) 7817 7818 # NaN propagation 7819 if dtype.is_floating_point or dtype.is_complex: 7820 if dtype.is_floating_point: 7821 nan = float('nan') 7822 else: 7823 # dtype.is_complex 7824 nan = complex(float('nan'), float('nan')) 7825 c = make_cond((1, 10, 3)) 7826 a = make_arg((10, 3), noncontiguous=True) 7827 a[2, 1] = nan 7828 b = make_arg((1, 3)) 7829 b[0, 2] = nan 7830 7831 yield SampleInput(a, args=(c, b)) 7832 7833 # Python scalars type promotion 7834 for scalar in (0, 0.0, 2j, False): 7835 yield SampleInput(scalar, args=(c, b)) 7836 yield SampleInput(a, args=(c, scalar)) 7837 7838 7839def error_inputs_where(op_info, device, **kwargs): 7840 shape = (S,) 7841 err_msg = "Expected all tensors to be on the same device" 7842 for devices in product(('cpu', device), repeat=3): 7843 if len(set(devices)) == 2: 7844 si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32), 7845 args=(make_tensor(shape, dtype=torch.bool, device=devices[1]), 7846 make_tensor(shape, device=devices[2], dtype=torch.float32))) 7847 yield ErrorInput(si, error_regex=err_msg) 7848 7849def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): 7850 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7851 7852 sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) 7853 7854 inputs = [] 7855 for shape in sizes: 7856 # construct input without any non-zero elements 7857 zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) 7858 inputs.append(zeros) 7859 7860 # construct input with mixed zero and non-zero elements 7861 mixed = make_arg(shape).requires_grad_(False) 7862 mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) 7863 mixed[mask_t] = 0 7864 inputs.append(mixed) 7865 7866 for input_t, as_tuple in product(inputs, [False, True]): 7867 yield SampleInput(input_t.clone().requires_grad_(requires_grad), 7868 kwargs=dict(as_tuple=as_tuple)) 7869 7870def sample_inputs_nonzero_static(op_info, device, dtype, requires_grad, **kwargs): 7871 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7872 7873 sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) 7874 7875 inputs = [] 7876 for shape in sizes: 7877 # construct input without any non-zero elements 7878 zeros = torch.zeros(shape, dtype=dtype, device=device, requires_grad=requires_grad) 7879 inputs.append(zeros) 7880 7881 # construct input with mixed zero and non-zero elements 7882 mixed = make_arg(shape).requires_grad_(False) 7883 mask_t = make_tensor(shape, dtype=torch.bool, device=device, requires_grad=False) 7884 mixed[mask_t] = 0 7885 inputs.append(mixed) 7886 7887 nonzero_sizes = [0, 1, XS, S, M] 7888 7889 for input_t, nonzero_size in product(inputs, nonzero_sizes): 7890 yield 
SampleInput(input_t.clone().requires_grad_(requires_grad), 7891 kwargs=dict(size=nonzero_size)) 7892 7893def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): 7894 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7895 7896 cases = (((S, S, S), (2,)), 7897 ((S, S, S), (S, 1)), 7898 ((S, S, S), (S, -1))) 7899 7900 for case in cases: 7901 shape, args = case 7902 yield SampleInput(make_arg(shape), args=args) 7903 7904def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): 7905 yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) 7906 7907 make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) 7908 7909 # shape x chunks x dim 7910 cases = ( 7911 ((13, 9, 11), 17, -1), 7912 ((13, 9, 11), 11, -1), 7913 ((13,), 12, -1), 7914 ((15,), 12, -1), 7915 ((15,), 7, 0), 7916 ((15,), 9, 0), 7917 ((3, 7), 9, 1), 7918 ((3, 7), 9, 0), 7919 ((3, 7), 2, 0), 7920 ((3, 7), 3, 0), 7921 ((3, 7), 1, 0), 7922 ((3, 7), 1, 1), 7923 ((4, 4), 2, 0), 7924 ) 7925 7926 for shape, chunks, dim in cases: 7927 yield SampleInput(make_arg(shape), args=(chunks, dim)) 7928 7929def sample_inputs_kthvalue(op_info, device, dtype, requires_grad, **kwargs): 7930 def _tensor(shape, dtype=dtype, low=None, high=None): 7931 return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) 7932 7933 test_cases = [ 7934 ((S, S, S), (2,)), 7935 ((S, S, S), (2, 1,)), 7936 ((S, S, S), (2, -1,)), 7937 ((S, S, S), (2, 1, True,)), 7938 ((S, S, S), (2, -1, True,)), 7939 ((S,), (2, 0,)), 7940 ((S,), (2, 0, True,)), 7941 ((), (1,)), 7942 ((), (1, 0,)), 7943 ((), (1, 0, True)), 7944 ] 7945 7946 yield from (SampleInput(_tensor(tensor), *args) for tensor, args in test_cases) 7947 7948def error_inputs_kthvalue(op_info, device, **kwargs): 7949 # tests overlapping output fails 7950 t = make_tensor(10, dtype=torch.float32, device=device) 7951 indices = torch.empty((), device=device, dtype=torch.long) 7952 yield ErrorInput(SampleInput(t, 5, out=(t, indices)), 7953 error_regex="unsupported operation") 7954 7955 k_out_of_range_err = "selected number k out of range for dimension" 7956 yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3, 0), 7957 error_regex=k_out_of_range_err) 7958 yield ErrorInput(SampleInput(torch.randn(2, 2, device=device), 3), 7959 error_regex=k_out_of_range_err) 7960 yield ErrorInput(SampleInput(torch.tensor(2, device=device), 3), 7961 error_regex=k_out_of_range_err) 7962 7963def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, 7964 train=None, valid_input_dim=None, **kwargs): 7965 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 7966 7967 if valid_input_dim: 7968 cases = ((S,) * i for i in valid_input_dim) 7969 else: 7970 cases = ((S, S), (S,), ()) 7971 p_vals = [0.0, 0.5, 1.0] 7972 # This is to handle special case for feature_alpha_dropout which has different 7973 # supported dtypes depending on `train` parameter 7974 training_vals = [train] if train is not None else [True, False] 7975 7976 for case, p, training in product(cases, p_vals, training_vals): 7977 yield SampleInput(make_arg(case), p=p, training=training) 7978 yield SampleInput(make_arg(case)) 7979 7980def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): 7981 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 7982 make_mask = partial(make_tensor, device=device, dtype=torch.bool, 
requires_grad=False) 7983 7984 cases = ((S, S, S, S), (S,), ()) 7985 scale_vals = [0.0, 1.0, 2.0] 7986 7987 for case, scale in product(cases, scale_vals): 7988 yield SampleInput(make_arg(case), make_mask(case), scale) 7989 7990def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): 7991 def make_input(shape): 7992 return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) 7993 7994 def make_long_input(shape, *, low, high, noncontiguous=False): 7995 return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high, 7996 noncontiguous=noncontiguous) 7997 7998 def make_per_sample_weight(flag, idx): 7999 # a tensor of float / double weights, or None 8000 # to indicate all weights should be taken to be 1 8001 if flag: 8002 return make_input(idx.shape) 8003 return None 8004 8005 offsets = torch.tensor([0, 3], device=device, dtype=torch.long) 8006 for generate_per_sample_weight in (True, False): 8007 for mode in ('sum', 'mean', 'max'): 8008 # per_sample_weights is only supported for mode='sum' (got mode='****') 8009 if generate_per_sample_weight and mode in ('mean', 'max'): 8010 continue 8011 8012 # 1-D index tensor 8013 idx = make_long_input((S,), low=0, high=M) 8014 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8015 yield SampleInput(make_input((M, S)), args=(idx,), 8016 kwargs={'offsets': offsets, 'mode': mode, 8017 'per_sample_weights': per_sample_weights}) 8018 8019 idx = make_long_input((S,), low=0, high=M, noncontiguous=True) 8020 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8021 yield SampleInput(make_input((M, S)), args=(idx,), 8022 kwargs={'offsets': offsets, 'mode': mode, 8023 'per_sample_weights': per_sample_weights}) 8024 8025 # bag with zero length 8026 idx = make_long_input((S,), low=0, high=M, noncontiguous=True) 8027 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8028 yield SampleInput(make_input((M, S)), args=(idx,), 8029 kwargs={'offsets': torch.tensor([0, 0, 3], device=device, dtype=torch.long), 8030 'mode': mode, 8031 'per_sample_weights': per_sample_weights}) 8032 8033 # 2-D index tensor 8034 idx = make_long_input((S, S), low=0, high=M) 8035 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8036 yield SampleInput(make_input((M, S)), args=(idx,), 8037 kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) 8038 8039 idx = make_long_input((S, S), low=0, high=M, noncontiguous=True) 8040 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8041 yield SampleInput(make_input((M, S)), args=(idx,), 8042 kwargs={'mode': mode, 'per_sample_weights': per_sample_weights}) 8043 8044 # The gradient vector at `padding_idx` is not updated. 
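            # More precisely: entries of `idx` that equal `padding_idx` are excluded from
            # the bag reduction, so the corresponding weight row receives no gradient.
            # Illustrative example (not one of the yielded samples): with padding_idx=2,
            # a bag [4, 2, 0] reduces over rows 4 and 0 only, and row 2 of the weight
            # keeps a zero gradient.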
            # Negative padding_idx
            idx = make_long_input((6,), low=0, high=S)
            idx[0] = 4
            idx[4] = 4
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((S, S)), args=(idx,),
                              kwargs={'padding_idx': -1, 'offsets': offsets,
                                      'mode': mode, 'per_sample_weights': per_sample_weights},)

            idx = make_long_input((3, 3), low=0, high=S)
            # Positive padding_idx
            idx[0, 0] = 2
            idx[1, 1] = 2
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(make_input((S, S)), args=(idx,),
                              kwargs={'padding_idx': 2, 'mode': mode,
                                      'per_sample_weights': per_sample_weights},)

            idx = make_long_input((6, ), low=0, high=S)
            weights = make_input((S, S))
            offsets_ = torch.tensor([0, 3, 6], device=device, dtype=torch.long)
            per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
            yield SampleInput(weights, args=(idx,),
                              kwargs={'mode': mode, 'offsets': offsets_, 'include_last_offset': True},)

            if not requires_grad:
                # Following inputs return different gradient from the numerical gradient.
                # This is expected and relevant tests are present in `test_nn.py`.

                # Due to inplace renorming of weight, the numerical gradient doesn't match the
                # analytical gradient.
                idx = make_long_input((2, 2), low=0, high=S)
                weights = make_input((S, S)) * 2
                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                yield SampleInput(weights, args=(idx,),
                                  kwargs={'max_norm': 1., 'mode': mode,
                                          'per_sample_weights': per_sample_weights},)

                idx = make_long_input((6, ), low=0, high=S)
                weights = make_input((S, S)) * 2
                per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                yield SampleInput(weights, args=(idx,),
                                  kwargs={'max_norm': 1., 'norm_type': 1.0,
                                          'mode': mode, 'offsets': offsets,
                                          'per_sample_weights': per_sample_weights},)

                if mode != 'max':
                    # Scale the gradient based on the inverse frequency of a particular index.
                    # Note : max mode does not support sparse weights
                    idx = make_long_input((2, 2), low=0, high=S)
                    idx[0, 0] = 1
                    idx[0, 1] = 1
                    weights = make_input((S, S))
                    per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx)
                    yield SampleInput(weights, args=(idx,),
                                      kwargs={'scale_grad_by_freq': True, 'mode': mode,
                                              'per_sample_weights': per_sample_weights},)

                    # gradcheck not implemented for sparse tensors.
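                    # (With sparse=True the weight gradient comes back as a sparse COO tensor,
                    # which the dense numerical Jacobian used by gradcheck cannot consume;
                    # that is why the sparse samples are only yielded when requires_grad is False.)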
8104 # Note : max mode does not support sparse weights 8105 idx = make_long_input((6, ), low=0, high=S) 8106 weights = make_input((S, S)) 8107 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8108 yield SampleInput(weights, args=(idx,), 8109 kwargs={'sparse': True, 'offsets': offsets, 8110 'mode': mode, 'per_sample_weights': per_sample_weights}) 8111 8112 idx = make_long_input((6, ), low=0, high=S) 8113 idx[0] = 1 # freq more than 1 8114 idx[1] = 1 # freq more than 1 8115 idx[3] = 0 # padding_idx 8116 weights = make_input((S, S)) * 2 8117 per_sample_weights = make_per_sample_weight(generate_per_sample_weight, idx) 8118 yield SampleInput(weights, args=(idx,), 8119 kwargs={'sparse': True, 'scale_grad_by_freq': True, 'padding_idx': 0, 8120 'max_norm': 1., 'offsets': offsets, 8121 'mode': mode, 'per_sample_weights': per_sample_weights}) 8122 8123 8124def sample_inputs_embedding(op_info, device, dtype, requires_grad, **kwargs): 8125 def make_input(shape): 8126 return make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) 8127 8128 def make_long_input(shape, *, low, high): 8129 return make_tensor(shape, device=device, dtype=torch.long, low=low, high=high) 8130 8131 # 0-D index tensor 8132 idx = make_long_input((), low=0, high=M) 8133 yield SampleInput(make_input((M, S)), args=(idx,),) 8134 8135 # 1-D index tensor 8136 idx = make_long_input((S,), low=0, high=M) 8137 yield SampleInput(make_input((M, S)), args=(idx,),) 8138 8139 # 2-D index tensor 8140 idx = make_long_input((S, S), low=0, high=M) 8141 yield SampleInput(make_input((M, S)), args=(idx,),) 8142 8143 if not requires_grad: 8144 # Following inputs return different gradient from the numerical gradient. 8145 # This is expected and relevant tests are present in `test_nn.py`. 8146 8147 # The gradient vector at `padding_idx` is not updated. 8148 idx = make_long_input((2, 2), low=0, high=S) 8149 idx[0, 0] = 2 8150 idx[1, 1] = 2 8151 yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': 2},) 8152 8153 idx = make_long_input((2, 2), low=0, high=S) 8154 idx[0, 0] = 4 8155 idx[1, 1] = 4 8156 yield SampleInput(make_input((S, S)), args=(idx,), kwargs={'padding_idx': -1},) 8157 8158 # Due to inplace renorming of weight, the numerical gradient doesn't match the 8159 # analytical gradient. 8160 idx = make_long_input((2, 2), low=0, high=S) 8161 weights = make_input((S, S)) * 2 8162 yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1.},) 8163 8164 idx = make_long_input((2, 2), low=0, high=S) 8165 weights = make_input((S, S)) * 2 8166 yield SampleInput(weights, args=(idx,), kwargs={'max_norm': 1., 'norm_type': 1.0},) 8167 8168 # Scale the gradient based on the inverse frequency of a particular index. 8169 idx = make_long_input((2, 2), low=0, high=S) 8170 idx[0, 0] = 1 8171 idx[0, 1] = 1 8172 weights = make_input((S, S)) 8173 yield SampleInput(weights, args=(idx,), kwargs={'scale_grad_by_freq': True},) 8174 8175 # gradcheck not implemented for sparse tensors. 
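        # Minimal sketch of the sparse path, with hypothetical shapes chosen only for
        # illustration (the actual sample below uses `make_input`/`make_long_input`):
        #
        #   >>> w = torch.randn(10, 3, requires_grad=True)
        #   >>> torch.nn.functional.embedding(torch.tensor([1, 4]), w, sparse=True).sum().backward()
        #   >>> w.grad.is_sparse
        #   True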
8176 idx = make_long_input((2, 2), low=0, high=S) 8177 weights = make_input((S, S)) 8178 yield SampleInput(weights, args=(idx,), kwargs={'sparse': True}) 8179 8180 idx = make_long_input((3, 3), low=0, high=S) 8181 idx[0, 0] = 1 # freq more than 1 8182 idx[0, 1] = 1 # freq more than 1 8183 idx[1, 0] = 0 # padding_idx 8184 weights = make_input((S, S)) * 2 8185 yield SampleInput(weights, args=(idx,), 8186 kwargs={'sparse': True, 'scale_grad_by_freq': True, 8187 'padding_idx': 0, 'max_norm': 1.}) 8188 8189 8190def sample_inputs_one_hot(op_info, device, dtype, requires_grad, **kwargs): 8191 def make_input(shape, *, low, high): 8192 return make_tensor(shape, device=device, dtype=dtype, low=low, high=high, requires_grad=requires_grad) 8193 8194 shapes = ((), (S,), (L, M, S)) 8195 num_classess = (-1, 10) 8196 8197 return ( 8198 SampleInput( 8199 make_input( 8200 shape, 8201 low=0, 8202 high=10 if num_classes == -1 else num_classes // 2, 8203 ), 8204 kwargs=dict(num_classes=num_classes), 8205 ) 8206 for shape, num_classes in itertools.product(shapes, num_classess) 8207 ) 8208 8209 8210def sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs): 8211 rhs_requires_grad = kwargs.get('rhs_requires_grad', requires_grad) 8212 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8213 8214 # Although most losses also support the reduce and size_average combination instead of reduce, the former is 8215 # deprecated since 0.4.1 and thus is not tested 8216 shapes_and_kwargs = ( 8217 ((), None), 8218 ((S,), dict(reduction="mean")), 8219 ((S,), dict(reduction="sum")), 8220 ((S,), dict(reduction="none")), 8221 ((S, S), None), 8222 ((S, S, S), None), 8223 ) 8224 8225 for shape, kwargs in shapes_and_kwargs: 8226 yield SampleInput(_make_tensor(shape), 8227 args=(_make_tensor(shape, requires_grad=rhs_requires_grad),), 8228 kwargs=kwargs) 8229 8230def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): 8231 # We get better tests if we change the range of the values to something like [-2,2] 8232 # because for grid (second tensor argument) the "useful" range is [-1,1] and this way 8233 # you get a better combination of out-of-range and in-range test cases 8234 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, 8235 low=-2, high=2) 8236 8237 batch_size = 2 8238 num_channels = 3 8239 modes = ("bilinear", "nearest") 8240 align_cornerss = (False, True) 8241 padding_modes = ("zeros", "border", "reflection") 8242 8243 for dim in (2, 3): 8244 8245 modes_ = (*modes, "bicubic") if dim == 2 else modes 8246 8247 for mode, padding_mode, align_corners in itertools.product(modes_, padding_modes, align_cornerss): 8248 yield SampleInput( 8249 _make_tensor((batch_size, num_channels, *[S] * dim)), 8250 _make_tensor((batch_size, *[S] * dim, dim)), 8251 mode=mode, 8252 padding_mode=padding_mode, 8253 align_corners=align_corners, 8254 ) 8255 8256def reference_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): 8257 8258 batch_size = 2 8259 num_channels = 3 8260 height = 345 8261 width = 456 8262 modes = ("bilinear", "nearest", "bicubic") 8263 align_cornerss = (False, True) 8264 padding_modes = ('zeros', 'border', 'reflection') 8265 8266 # Create an affine transformation matrix 8267 a = torch.deg2rad(torch.tensor(45.0)) 8268 ca, sa = torch.cos(a), torch.sin(a) # rotation angles 8269 s1, s2 = 1.23, 1.34 # scales 8270 8271 theta = torch.tensor([[ 8272 [ca / s1, sa, 0.0], 8273 [-sa, ca / s2, 0.0], 8274 ]], 
dtype=dtype, device=device) 8275 theta = theta.expand(batch_size, 2, 3).contiguous() 8276 8277 x = torch.arange(batch_size * num_channels * height * width, device=device) 8278 x = x.reshape(batch_size, num_channels, height, width).to(torch.uint8) 8279 x = x.to(dtype=dtype) 8280 x.requires_grad_(requires_grad) 8281 8282 for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): 8283 grid = torch.nn.functional.affine_grid( 8284 theta, size=(batch_size, num_channels, height, width), align_corners=align_corners 8285 ) 8286 yield SampleInput( 8287 x, 8288 grid, 8289 mode, 8290 padding_mode, 8291 align_corners, 8292 ) 8293 8294def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs): 8295 # We get better tests if we change the range of the values to something like [-2,2] 8296 # because for grid (second tensor argument) the "useful" range is [-1,1] and this way 8297 # you get a better combination of out-of-range and in-range test cases 8298 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, 8299 low=-2, high=2) 8300 8301 batch_size = 2 8302 num_channels = 3 8303 modes = (0, 1, 2) 8304 align_cornerss = (False, True) 8305 padding_modes = (0, 1, 2) 8306 8307 for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): 8308 yield SampleInput( 8309 _make_tensor((batch_size, num_channels, S, L)), 8310 _make_tensor((batch_size, M + 3, M, 2)), 8311 mode, 8312 padding_mode, 8313 align_corners, 8314 ) 8315 8316def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): 8317 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8318 8319 def make_target(shape): 8320 shape = () if len(shape) == 1 else (shape[0], ) 8321 t = torch.randint(0, 2, shape, device=device, dtype=torch.long) 8322 # Label with -1 or 1 8323 t = t * 2 - 1 8324 target = t.to(dtype=dtype).detach_().requires_grad_(requires_grad) 8325 return target 8326 8327 shapes = ((S, S), (S,)) 8328 reductions = ('none', 'mean', 'sum') 8329 for s, r in product(shapes, reductions): 8330 yield SampleInput( 8331 make_input(s), 8332 args=(make_input(s), make_target(s)), 8333 kwargs=dict(reduction=r, margin=random.uniform(-1, 1)) 8334 ) 8335 8336def sample_inputs_ctc_loss(op_info, device, dtype, requires_grad, **kwargs): 8337 input_length = 50 8338 batch = 16 8339 num_char = 20 8340 target_length = 30 8341 8342 def make_log_probs(s): 8343 t = make_tensor(s, device=device, dtype=dtype) 8344 log_probs = t.log_softmax(2).to(device=device, dtype=dtype).detach().requires_grad_(requires_grad=requires_grad) 8345 return log_probs 8346 8347 reductions = ('none', 'mean', 'sum') 8348 zero_inf = (True, False) 8349 lengths_type = (list, torch.Tensor) 8350 for r, z, lt in product(reductions, zero_inf, lengths_type): 8351 log_probs = make_log_probs((input_length, batch, num_char)) 8352 targets = torch.randint(1, num_char, (batch, target_length), dtype=torch.long, device=device) 8353 input_lengths = torch.full((batch, ), input_length, dtype=torch.long, device=device) 8354 target_lengths = torch.randint(10, target_length, (batch, ), dtype=torch.long, device=device) 8355 8356 # Dont generate int[] types if reduction = "Mean" since this results in non composite compliant calls 8357 # to ctc_loss.IntList since a tensor needs to be created from the target lengths. 8358 # Creating such a tensor requires the use of pointers to copy data from int[] -> torch.Tensor 8359 # e.g. 
via std::copy. Similarly symbolic/real tracing with fx will also not work 8360 if lt is list and r in ["none", "sum"]: 8361 input_lengths = input_lengths.tolist() 8362 target_lengths = target_lengths.tolist() 8363 8364 yield SampleInput(log_probs, args=(targets, input_lengths, target_lengths,), kwargs=dict(reduction=r, zero_infinity=z)) 8365 8366def sample_inputs_nll_loss(op_info, device, dtype, requires_grad, **kwargs): 8367 shape = (2, 3) 8368 num_classes = shape[1] 8369 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8370 # FIXME: Derivative wrt. weight not implemented 8371 make_weight = partial(make_tensor, num_classes, device=device, dtype=dtype, requires_grad=False) 8372 8373 def make_target(shape, zeros=False): 8374 s = (shape[0], *shape[2:]) if len(shape) > 1 else () 8375 if zeros: 8376 return torch.zeros(s, device=device, dtype=torch.long) 8377 else: 8378 return make_tensor(s, 8379 low=0, 8380 high=shape[1] if len(shape) > 1 else shape[0], 8381 device=device, 8382 dtype=torch.long) 8383 8384 8385 def gen_shape_kwargs(): 8386 # Batched, non-batched and 2d 8387 shapes = (shape, (num_classes,), shape + (2, 2)) 8388 reductions = ('none', 'mean', 'sum') 8389 for reduction, s in product(reductions, shapes): 8390 yield make_input(s), make_target(s), dict(reduction=reduction) 8391 yield make_input(s), make_target(s), dict(weight=make_weight(), reduction=reduction) 8392 yield make_input(s), make_target(s), dict(weight=make_weight(low=0), reduction=reduction) 8393 yield make_input(s), make_target(s), dict(weight=make_weight(high=0), reduction=reduction) 8394 t = make_target(s) 8395 ignore = num_classes // 2 8396 # If "mean", nll returns NaN, so it's not differentiable at those points 8397 if t.eq(ignore).all() and reduction == "mean": 8398 t.fill_(0) 8399 yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction) 8400 yield make_input(s), t, dict(ignore_index=num_classes // 2, reduction=reduction, weight=make_weight()) 8401 # Test ignoring all the targets 8402 # If "mean", nll returns NaN, so it's not differentiable at those points 8403 if reduction != "mean": 8404 yield make_input(s), make_target(s, zeros=True), dict(ignore_index=0, reduction=reduction) 8405 8406 for input, target, kwargs in gen_shape_kwargs(): 8407 yield SampleInput(input, args=(target,), kwargs=kwargs) 8408 8409 target = torch.tensor([-1, 2], device=device, dtype=torch.long) 8410 yield SampleInput(make_input(shape), args=(target,), kwargs={'ignore_index': -1}) 8411 8412 8413def sample_inputs_binary_cross_entropy_with_logits( 8414 op_info, device, dtype, requires_grad, **kwargs 8415): 8416 make = partial(make_tensor, device=device, dtype=dtype) 8417 make_prob = partial(make, low=0, high=1) 8418 reductions = ("mean", "sum", "none") 8419 8420 def make_weight_shape_kwargs(): 8421 kwargs = [] 8422 for shape in ((1,), (1, S), (S), (S, S)): 8423 kwargs.extend([((S, S), dict(reduction=reduction, weight=make(shape))) for reduction in reductions]) 8424 return kwargs 8425 8426 shapes_and_kwargs = [ 8427 *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], 8428 *[((S, S), dict(reduction=reduction)) for reduction in reductions], 8429 *make_weight_shape_kwargs(), 8430 *[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions], 8431 *[((S, S), dict(reduction=reduction, weight=make((S, S)), pos_weight=make((S,), low=0))) for reduction in reductions], 8432 ] 8433 8434 for shape, kwargs in shapes_and_kwargs: 8435 yield 
SampleInput( 8436 make(shape, requires_grad=requires_grad), 8437 args=(make_prob(shape, requires_grad=requires_grad),), 8438 kwargs=kwargs, 8439 ) 8440 8441def sample_inputs_argwhere(op_info, device, dtype, requires_grad, **kwargs): 8442 yield SampleInput(torch.tensor([1, 0, 2, 0], dtype=dtype, device=device, requires_grad=requires_grad)) 8443 mask = torch.tensor([[0, 1, 0, 1, 0], 8444 [1, 1, 1, 1, 0], 8445 [0, 0, 0, 1, 0], 8446 [1, 0, 1, 1, 0], 8447 [1, 0, 0, 1, 0]], dtype=torch.bool, device=device) 8448 t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad) 8449 t[mask] = 0 8450 yield SampleInput(t) 8451 8452 t = make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad, noncontiguous=True) 8453 t[mask] = 0 8454 yield SampleInput(t) 8455 8456 t = make_tensor((S, 0), dtype=dtype, device=device, requires_grad=requires_grad) 8457 yield SampleInput(t) 8458 8459 yield SampleInput(torch.zeros((S,), dtype=dtype, device=device, requires_grad=requires_grad)) 8460 yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) 8461 8462def _generate_sample_shape_reduction(): 8463 shapes = ((S,), (S, S), (S, S, S)) 8464 reductions = ('none', 'mean', 'sum') 8465 yield from product(shapes, reductions) 8466 8467def sample_inputs_gaussian_nll_loss(op_info, device, dtype, requires_grad, **kwargs): 8468 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8469 # Set low slightly above 0 so gradcheck doesn't accidentally dip below 0 8470 make_var = partial(make_tensor, low=0.1, device=device, dtype=dtype, requires_grad=requires_grad) 8471 8472 def gen_shape(shape): 8473 yield shape 8474 # Broadcast 8475 yield (*shape[:-1], 1) 8476 yield shape[:-1] 8477 8478 def gen_shape_kwargs(): 8479 for s, r in _generate_sample_shape_reduction(): 8480 for t_s, v_s in product(gen_shape(s), gen_shape(s)): 8481 yield _make_tensor(s), _make_tensor(t_s), make_var(v_s), dict(reduction=r) 8482 yield ( 8483 _make_tensor(s), _make_tensor(t_s), make_var(v_s), 8484 dict(full=True, reduction=r) 8485 ) 8486 yield ( 8487 _make_tensor(s), _make_tensor(t_s), make_var(v_s), 8488 dict(eps=random.uniform(1e-6, 1e-3), reduction=r) 8489 ) 8490 yield ( 8491 _make_tensor(s), _make_tensor(t_s), make_var(v_s), 8492 dict(full=True, eps=random.uniform(1e-6, 1e-3), reduction=r) 8493 ) 8494 8495 for input, target, var, kwargs in gen_shape_kwargs(): 8496 yield SampleInput(input, args=(target, var, ), kwargs=kwargs) 8497 8498def error_inputs_gaussian_nll_loss(op_info, device, **kwargs): 8499 _make = partial(make_tensor, device=device, dtype=torch.float32) 8500 8501 # invalid reduction value 8502 yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 3), low=0), reduction="abc"), 8503 error_type=ValueError, error_regex="abc is not valid") 8504 8505 # var is of incorrect shape 8506 yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 3), _make((10, 2, 2), low=0)), 8507 error_type=ValueError, error_regex="var is of incorrect size") 8508 8509 # target is of incorrect shape 8510 yield ErrorInput(SampleInput(_make(10, 2, 3), _make(10, 2, 2), _make((10, 2, 3), low=0)), 8511 error_type=RuntimeError, 8512 error_regex=(r"The size of tensor a \(3\) must match the size of tensor b \(2\) " 8513 r"at non-singleton dimension 2")) 8514 8515def _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): 8516 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8517 
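    # This helper yields (input, target, kwargs) triples over the grid produced by
    # _generate_sample_shape_reduction(), i.e. the nine combinations of
    # {(S,), (S, S), (S, S, S)} x {'none', 'mean', 'sum'}; the loss-specific samplers
    # below (e.g. hinge embedding and Huber) then adjust target/kwargs as needed.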
8518 for s, r in _generate_sample_shape_reduction(): 8519 yield _make_tensor(s), _make_tensor(s), dict(reduction=r) 8520 8521def sample_inputs_hinge_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): 8522 for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): 8523 # target should contain either 1 or -1 as per docs 8524 mask = torch.rand_like(target) > 0.5 8525 target[mask] = 1 8526 target[~mask] = -1 8527 d['margin'] = random.uniform(-9, 9) 8528 yield SampleInput(input, args=(target, ), kwargs=d) 8529 8530 # scalar input and target. 8531 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8532 yield SampleInput(_make_tensor(()), args=(_make_tensor(()), )) 8533 8534def error_inputs_hinge_embedding_loss(op, device, **kwargs): 8535 make_input = partial(make_tensor, device=device, dtype=torch.float32) 8536 # invalid reduction value 8537 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), 8538 error_type=ValueError, error_regex='is not a valid value') 8539 8540def reference_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs): 8541 yield from sample_inputs_hinge_embedding_loss(op, device, dtype, requires_grad, **kwargs) 8542 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8543 8544 for reduction in ('sum', 'mean', 'none'): 8545 if dtype.is_floating_point: # only supports ints and floats 8546 # NaN propagation 8547 inp = make_input((10, )) 8548 inp[2] = float('nan') 8549 target = make_input((10, )) 8550 # target should contain either 1 or -1 as per docs 8551 mask = torch.rand_like(target) > 0.5 8552 target[mask] = -1 8553 target[~mask] = 1 8554 yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) 8555 8556 # Inf Handling 8557 inp = make_input((10, )) 8558 inp[4] = float('inf') 8559 target = make_input((10, )) 8560 mask = torch.rand_like(target) > 0.5 8561 target[mask] = -1 8562 target[~mask] = 1 8563 yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) 8564 8565 # Broadcasting 8566 inp = make_input((5, 5)) 8567 target = make_input((1, 5)) 8568 mask = torch.rand_like(target) > 0.5 8569 target[mask] = -1 8570 target[~mask] = 1 8571 yield SampleInput(inp, args=(target,), kwargs={'reduction': reduction}) 8572 8573def sample_inputs_huber_loss(op_info, device, dtype, requires_grad, **kwargs): 8574 for input, target, d in _generate_sample_inputs_nn_loss(op_info, device, dtype, requires_grad, **kwargs): 8575 d['delta'] = random.uniform(1e-3, 9) 8576 yield SampleInput(input, args=(target, ), kwargs=d) 8577 8578def error_inputs_huber_loss(op, device, **kwargs): 8579 make_input = partial(make_tensor, device=device, dtype=torch.float32) 8580 # invalid reduction value 8581 err = 'is not a valid value for reduction' 8582 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}), 8583 error_type=ValueError, error_regex=err) 8584 # delta <= 0 8585 for delta in (0, -1): 8586 err = 'huber_loss does not support non-positive values for delta.' 
8587 yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'delta': delta}), 8588 error_type=RuntimeError, error_regex=err) 8589 8590def sample_inputs_poisson_nll_loss(op_info, device, dtype, requires_grad, **kwargs): 8591 _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8592 8593 def gen_shape_kwargs(): 8594 for s, r in _generate_sample_shape_reduction(): 8595 for li in (True, False): 8596 for f in (True, False): 8597 i1 = _make_tensor(s) 8598 i2 = _make_tensor(s) 8599 # For Poisson NLL Loss, 8600 # target is assumed to be from 8601 # Poisson Distribution which 8602 # always has positive samples 8603 t1 = _make_tensor(s, low=0) 8604 t2 = _make_tensor(s, low=0) 8605 8606 if not li: 8607 i1.abs_() 8608 i2.abs_() 8609 t1.abs_() 8610 t2.abs_() 8611 8612 yield ( 8613 i1, t1, 8614 dict(log_input=li, full=f, reduction=r) 8615 ) 8616 yield ( 8617 i2, t2, 8618 dict(log_input=li, full=f, 8619 eps=random.uniform(1e-8, 1e-3), 8620 reduction=r) 8621 ) 8622 8623 for input, target, kwargs in gen_shape_kwargs(): 8624 yield SampleInput(input, args=(target, ), kwargs=kwargs) 8625 8626 # test INT_TO_FLOAT promotion 8627 if dtype.is_complex: 8628 for d in (torch.bool, torch.int64): 8629 yield SampleInput(_make_tensor(dtype=dtype), args=(_make_tensor(dtype=d),)) 8630 yield SampleInput(_make_tensor(dtype=d), args=(_make_tensor(dtype=dtype),)) 8631 8632def error_inputs_poisson_nll_loss(op_info, device, **kwargs): 8633 make = partial(make_tensor, device=device, dtype=torch.float32) 8634 8635 # invalid reduction value 8636 yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), 8637 kwargs={'reduction': 'abc'}), 8638 error_type=ValueError, 8639 error_regex='abc is not a valid value for reduction') 8640 # invalid input shapes 8641 yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), 8642 error_regex=(r'(Attempting to broadcast a dimension of length|' 8643 r'The size of tensor a \(5\) must match the ' 8644 r'size of tensor b \(4\) at non-singleton ' 8645 r'dimension 1)')) 8646 8647def error_inputs_soft_margin_loss(op_info, device, **kwargs): 8648 make = partial(make_tensor, device=device, dtype=torch.float32) 8649 8650 # invalid reduction value 8651 yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), 8652 kwargs={'reduction': 'abc'}), 8653 error_type=ValueError, 8654 error_regex='abc is not a valid value for reduction') 8655 # invalid input shapes 8656 yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)), 8657 error_regex=(r'(Attempting to broadcast a dimension of length|' 8658 r'The size of tensor a \(4\) must match the ' 8659 r'size of tensor b \(5\) at non-singleton ' 8660 r'dimension 1)')) 8661 8662def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs): 8663 make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad) 8664 8665 kwargss = ( 8666 *[dict(margin=margin) for margin in (1e-6, 1.0, 10.0)], 8667 dict(swap=True), 8668 *[dict(reduction=reduction) for reduction in ("mean", "sum", "none")], 8669 ) 8670 8671 for kwargs in kwargss: 8672 input = make() 8673 args = (make(), make()) 8674 if with_distance: 8675 kwargs["distance_function"] = torch.nn.PairwiseDistance() 8676 yield SampleInput(input, args=args, kwargs=kwargs) 8677 8678def error_inputs_triplet_margin_loss(op_info, device, **kwargs): 8679 make_input = partial(make_tensor, device=device, dtype=torch.float32) 8680 8681 samples = ( 8682 # input, args, kwargs, error_type, 
error_regex 8683 # invalid reduction 8684 (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), 8685 dict(reduction="abc"), 8686 ValueError, "abc is not a valid value for reduction"), 8687 8688 # invalid margin 8689 (make_input(3, 4), (make_input(3, 4), make_input(3, 4)), 8690 dict(margin=-1.0), 8691 ValueError, "margin must be greater than 0, got -1.0"), 8692 8693 # shape mismatch 8694 (make_input(3, 5), (make_input(3, 4), make_input(3, 4)), 8695 {}, 8696 RuntimeError, 8697 (r'(Attempting to broadcast a dimension of length|' 8698 r"The size of tensor a \(5\) must match the size of tensor b \(4\) " 8699 r"at non-singleton dimension 1)")), 8700 (make_input(3, 4), (make_input(3, 5), make_input(3, 4)), 8701 {}, 8702 RuntimeError, 8703 (r'(Attempting to broadcast a dimension of length|' 8704 r"The size of tensor a \(4\) must match the size of tensor b \(5\) " 8705 r"at non-singleton dimension 1)")), 8706 (make_input(3, 4), (make_input(3, 4), make_input(3, 5)), 8707 {}, 8708 RuntimeError, 8709 (r'(Attempting to broadcast a dimension of length|' 8710 r"The size of tensor a \(4\) must match the size of tensor b \(5\) " 8711 r"at non-singleton dimension 1)")), 8712 8713 # different dimensions 8714 (make_input(3,), (make_input(3, 4), make_input(3, 4)), 8715 {}, 8716 RuntimeError, 8717 (r"The anchor, positive, and negative tensors are expected to have " 8718 r"the same number of dimensions, but got: anchor 1D, positive 2D, " 8719 r"and negative 2D inputs")), 8720 (make_input(3, 4), (make_input(3,), make_input(3, 4)), 8721 {}, 8722 RuntimeError, 8723 (r"The anchor, positive, and negative tensors are expected to have " 8724 r"the same number of dimensions, but got: anchor 2D, positive 1D, " 8725 r"and negative 2D inputs")), 8726 (make_input(3, 4), (make_input(3, 4), make_input(3,)), 8727 {}, 8728 RuntimeError, 8729 (r"The anchor, positive, and negative tensors are expected to have " 8730 r"the same number of dimensions, but got: anchor 2D, positive 2D, " 8731 r"and negative 1D inputs")), 8732 ) 8733 8734 for input, args, kwargs, error_type, error_regex in samples: 8735 yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), 8736 error_type=error_type, error_regex=error_regex) 8737 8738def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): 8739 make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) 8740 make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad) 8741 make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) 8742 M, N, K = 15, 32, 16 8743 samples = [] 8744 # two e4m3 8745 mat1 = make_mat_e4m3((M, K)) 8746 mat2 = make_mat_e4m3((K, N)).t().contiguous().t() 8747 scale1 = make_scale((1,)) 8748 scale2 = make_scale((1,)) 8749 samples.append(SampleInput(mat1, mat2, scale1, scale2)) 8750 # mat1 e4m3 mat2 e5m2 8751 mat1 = make_mat_e4m3((M, K)) 8752 mat2 = make_mat_e5m2((K, N)).t().contiguous().t() 8753 scale1 = make_scale((1,)) 8754 scale2 = make_scale((1,)) 8755 samples.append(SampleInput(mat1, mat2, scale1, scale2)) 8756 # mat1 e5m2 mat2 e4m3 8757 mat1 = make_mat_e5m2((M, K)) 8758 mat2 = make_mat_e4m3((K, N)).t().contiguous().t() 8759 scale1 = make_scale((1,)) 8760 scale2 = make_scale((1,)) 8761 samples.append(SampleInput(mat1, mat2, scale1, scale2)) 8762 8763 yield from samples 8764 8765def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): 8766 make = partial(make_tensor, device=device, 
dtype=dtype, requires_grad=requires_grad) 8767 batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 8768 num_heads_q_gqa, num_heads_kv_gqa = 32, 8 8769 8770 dim_3_q_shape = (batch, seq_q, head_dim) 8771 dim_3_kv_shape = (batch, seq_kv, head_dim) 8772 dim_4_q_shape = (batch, num_heads, seq_q, head_dim) 8773 dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) 8774 8775 broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim)) 8776 8777 qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] 8778 samples = [] 8779 gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support 8780 if TEST_WITH_ROCM and dtype == torch.float32: 8781 causal_options = [False] # FIXME: Large errors with causal+fp32 8782 else: 8783 causal_options = [True, False] 8784 for qkv_shape, is_causal, dropout_p, enable_gqa in product( 8785 qkv_shapes, causal_options, [0.0, 0.5], gqa_options): 8786 shape_q, shape_kv = qkv_shape 8787 samples.append(SampleInput( 8788 make(shape_q), 8789 make(shape_kv), 8790 make(shape_kv), 8791 is_causal=is_causal, 8792 dropout_p=dropout_p 8793 )) 8794 8795 # Add non standard shapes 8796 diff_v_head_dim = SampleInput( 8797 make((batch, num_heads, seq_q, head_dim)), 8798 make((batch, num_heads, seq_kv, head_dim)), 8799 make((batch, num_heads, seq_kv, head_dim + 8)), 8800 is_causal=is_causal, 8801 dropout_p=dropout_p 8802 ) 8803 8804 # Add an attn_mask 8805 samples.append( 8806 SampleInput( 8807 make((batch, num_heads, seq_q, head_dim)), 8808 make((batch, num_heads, seq_kv, head_dim)), 8809 make((batch, num_heads, seq_kv, head_dim)), 8810 attn_mask=make((seq_q, seq_kv)), 8811 is_causal=False, 8812 dropout_p=0.0) 8813 ) 8814 8815 if not TEST_WITH_ROCM: 8816 samples.append( 8817 SampleInput( 8818 make((batch, num_heads_q_gqa, seq_q, head_dim)), 8819 make((batch, num_heads_kv_gqa, seq_kv, head_dim)), 8820 make((batch, num_heads_kv_gqa, seq_kv, head_dim)), 8821 enable_gqa=True 8822 ) 8823 ) 8824 8825 yield from samples 8826 8827 8828def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_grad, **kwargs): 8829 make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8830 batch, num_heads, head_dim = 4, 4, 8 8831 seq_q = 11 8832 seq_kv = 32 8833 8834 dim_4_q_shape = (batch, num_heads, seq_q, head_dim) 8835 dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) 8836 8837 qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] 8838 samples = [] 8839 mask_types = [1, 2] # UpperLeft, LowerRight 8840 scales = [None, 1.0] 8841 8842 for qkv_shape, is_causal, dropout_p, mask_type, scale in product( 8843 qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales): 8844 shape_q, shape_kv = qkv_shape 8845 samples.append(SampleInput( 8846 make(shape_q).transpose(1, 2), 8847 make(shape_kv).transpose(1, 2), 8848 make(shape_kv).transpose(1, 2), 8849 bias=None, 8850 cu_seqlens_q=None, 8851 cu_seqlens_k=None, 8852 max_seqlen_q=None, 8853 max_seqlen_k=None, 8854 dropout_p=dropout_p, 8855 custom_mask_type=mask_type, 8856 compute_log_sumexp=requires_grad, 8857 scale=scale, 8858 seqlen_k=None 8859 )) 8860 8861 # Add non standard shapes 8862 diff_v_head_dim = SampleInput( 8863 make((batch, seq_q, num_heads, head_dim)), 8864 make((batch, seq_kv, num_heads, head_dim)), 8865 make((batch, seq_kv, num_heads, head_dim + 8)), 8866 bias=None, 8867 cu_seqlens_q=None, 8868 cu_seqlens_k=None, 8869 max_seqlen_q=None, 8870 max_seqlen_k=None, 8871 dropout_p=dropout_p, 8872 custom_mask_type=0, # No Mask 
8873 compute_log_sumexp=requires_grad, 8874 scale=None, 8875 seqlen_k=None 8876 ) 8877 8878 # Add an attn_mask 8879 samples.append( 8880 SampleInput( 8881 make((batch, seq_q, num_heads, head_dim)), 8882 make((batch, seq_kv, num_heads, head_dim)), 8883 make((batch, seq_kv, num_heads, head_dim)), 8884 bias=make(batch, num_heads, seq_q, seq_kv), 8885 cu_seqlens_q=None, 8886 cu_seqlens_k=None, 8887 max_seqlen_q=None, 8888 max_seqlen_k=None, 8889 dropout_p=dropout_p, 8890 custom_mask_type=0, # No Mask 8891 compute_log_sumexp=requires_grad, 8892 scale=None, 8893 seqlen_k=None 8894 ) 8895 ) 8896 8897 # jagged (with query/keys offsets) 8898 cu_seqlens_k = torch.arange(-1, 32 * 2 + 1, 2, dtype=torch.int32, device=device) 8899 cu_seqlens_k[-1] = 62 8900 cu_seqlens_k[0] = 0 8901 samples.append( 8902 SampleInput( 8903 make((32, 2, 64)).view(-1, 8, 8).unsqueeze(0), 8904 make((64, 64)).view(-1, 8, 8).unsqueeze(0), 8905 make((64, 64)).view(-1, 8, 8).unsqueeze(0), 8906 bias=None, 8907 cu_seqlens_q=torch.arange(0, 32 * 2 + 2, 2, dtype=torch.int32, device=device), 8908 cu_seqlens_k=cu_seqlens_k, 8909 max_seqlen_q=2, 8910 max_seqlen_k=2, 8911 dropout_p=0.0, 8912 custom_mask_type=0, # No Mask 8913 compute_log_sumexp=requires_grad, 8914 scale=None, 8915 seqlen_k=None, 8916 ) 8917 ) 8918 8919 yield from samples 8920 8921def sample_inputs_flash_attention_forward(op_info, device, dtype, requires_grad, **kwargs): 8922 make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8923 batch, num_heads, head_dim = 4, 4, 8 8924 seq_q = 11 8925 seq_kv = 32 8926 8927 dim_4_q_shape = (batch, num_heads, seq_q, head_dim) 8928 dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim) 8929 8930 qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)] 8931 samples = [] 8932 scales = [None, 1.0] 8933 8934 for qkv_shape, is_causal, dropout_p, scale in product( 8935 qkv_shapes, [True, False], [0.0, 0.5], scales): 8936 shape_q, shape_kv = qkv_shape 8937 samples.append(SampleInput( 8938 make(shape_q).transpose(1, 2), 8939 make(shape_kv).transpose(1, 2), 8940 make(shape_kv).transpose(1, 2), 8941 cum_seq_q=None, 8942 cum_seq_k=None, 8943 max_q=seq_q, 8944 max_k=seq_kv, 8945 dropout_p=dropout_p, 8946 is_causal=is_causal, 8947 return_debug_mask=False, 8948 scale=scale, 8949 )) 8950 8951 yield from samples 8952 8953def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs): 8954 make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8955 8956 shape = (3,) 8957 batched_shape = (2, *shape) 8958 shapes_and_kwargs = [ 8959 (shape, None), 8960 (batched_shape, None), 8961 (shape, dict(keepdim=True)), 8962 (batched_shape, dict(keepdim=True)), 8963 (shape, dict(p=5.0)), 8964 (shape, dict(p=-1.0)), 8965 (shape, dict(eps=1.0)), 8966 ] 8967 8968 return ( 8969 SampleInput(make(shape), args=(make(shape),), kwargs=kwargs) for shape, kwargs in shapes_and_kwargs 8970 ) 8971 8972def sample_inputs_pixel_shuffle(op_info, device, dtype, requires_grad, **kwargs): 8973 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 8974 yield from ( 8975 SampleInput(make_arg((1, 9, 2, 2)), upscale_factor=upscale_factor) 8976 for upscale_factor in (1, 3) 8977 ) 8978 yield from ( 8979 SampleInput(make_arg(shape), upscale_factor=1) 8980 for shape in [ 8981 (1, 0, 1, 1), 8982 (1, 1, 0, 1), 8983 (1, 1, 1, 0), 8984 ] 8985 ) 8986 8987def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs): 8988 make_arg = partial(make_tensor, device=device, 
dtype=dtype, requires_grad=requires_grad) 8989 yield from ( 8990 SampleInput(make_arg((1, 1, 6, 6)), downscale_factor=downscale_factor) 8991 for downscale_factor in (1, 3) 8992 ) 8993 yield from ( 8994 SampleInput(make_arg(shape), downscale_factor=1) 8995 for shape in [ 8996 (1, 0, 1, 1), 8997 (1, 1, 0, 1), 8998 (1, 1, 1, 0), 8999 ] 9000 ) 9001 9002def sample_inputs_channel_shuffle(op_info, device, dtype, requires_grad, **kwargs): 9003 make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 9004 9005 shapes_groups = [ 9006 ((1, 4, 10, 10), 2), 9007 ((2, 6, 8, 8), 3), 9008 ((2, 8, 5, 5), 4), 9009 ] 9010 9011 yield from ( 9012 SampleInput(make_arg(shape), args=(groups,)) 9013 for shape, groups in shapes_groups 9014 ) 9015 9016def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs): 9017 make = partial(make_tensor, device=device, dtype=dtype) 9018 # Lower bounds must be greater than 'eps' defined in gradcheck.py::gradgradcheck() -> eps 9019 # otherwise perturbation calculation causes Tensor value to become negative triggering 9020 # a device-side hardware assertion 9021 make_prob = partial(make, low=1e-6, high=1) 9022 9023 reductions = ("mean", "sum", "none") 9024 9025 shapes_and_kwargs = [ 9026 *[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))], 9027 *[((S, S), dict(reduction=reduction)) for reduction in reductions], 9028 *[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions], 9029 ] 9030 9031 if logits: 9032 shapes_and_kwargs.extend( 9033 [((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions] 9034 ) 9035 9036 for shape, kwargs in shapes_and_kwargs: 9037 yield SampleInput( 9038 (make if logits else make_prob)(shape, requires_grad=requires_grad), 9039 args=(make_prob(shape, requires_grad=requires_grad),), 9040 kwargs=kwargs, 9041 ) 9042 9043def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs): 9044 sample_shapes = [(), (S), (S, S, S)] 9045 atols = [1e-2, 1e-16] 9046 rtols = [1e-1, 0.5] 9047 eps = 1e-8 9048 for s, rtol, atol in product(sample_shapes, rtols, atols): 9049 # close sample 9050 t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) 9051 close = (t + atol).detach().requires_grad_(requires_grad) 9052 yield SampleInput(t, close, rtol=rtol, atol=atol) 9053 9054 # random sample 9055 a = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) 9056 b = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) 9057 yield SampleInput(a, b, rtol=rtol, atol=atol) 9058 9059 9060def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): 9061 yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) 9062 9063 # test COMPLEX_TO_FLOAT promotion 9064 if dtype.is_complex: 9065 make = partial(make_tensor, (), device=device, requires_grad=requires_grad) 9066 yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) 9067 yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) 9068 9069def error_inputs_l1_loss(op_info, device, **kwargs): 9070 make = partial(make_tensor, device=device, dtype=torch.float32) 9071 9072 # invalid reduction value 9073 yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),), 9074 kwargs={'reduction': 'abc'}), 9075 error_type=ValueError, 9076 error_regex='abc is not a valid value for reduction') 9077 # invalid input shapes 9078 yield ErrorInput(SampleInput(make(5, 4), 
args=(make(5,),)), 9079 error_regex=(r'(Attempting to broadcast a dimension of length|' 9080 r'The size of tensor a \(4\) must match the ' 9081 r'size of tensor b \(5\) at non-singleton ' 9082 r'dimension 1)') 9083 ) 9084 9085def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs): 9086 yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs) 9087 9088 make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad) 9089 9090 # This test case always triggers the smooth condition, since absolute difference of input and target 9091 # is smaller than beta 9092 yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5)) 9093 yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0)) 9094 9095def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs): 9096 # kl_div works with inputs in [0, 1] (aka the pdf of a probability measure) 9097 # Then log [0, 1] = (-inf, 0], so this is the log space 9098 make_arg = partial(make_tensor, low=0., device=device, dtype=dtype, requires_grad=requires_grad) 9099 9100 def make_log(shape): 9101 out = torch.nn.functional.log_softmax(make_arg(shape), -1) 9102 out.requires_grad_(requires_grad) 9103 return out 9104 9105 def make_prob(shape): 9106 out = torch.nn.functional.softmax(make_arg(shape), -1) 9107 out.requires_grad_(requires_grad) 9108 return out 9109 9110 shapes = ((2,), (2, 3)) 9111 reductions = ("none", "mean", "batchmean", "sum") 9112 for shape, reduction, log_target in product(shapes, reductions, (True, False)): 9113 input = make_log(shape) 9114 target = make_log(shape) if log_target else make_prob(shape) 9115 yield SampleInput(input, args=(target,), kwargs=dict(reduction=reduction, log_target=log_target)) 9116 9117def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs): 9118 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 9119 9120 yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2)) 9121 yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf"))) 9122 9123def reference_pdist(input, p=2): 9124 pdist = scipy.spatial.distance.pdist 9125 if p == 0: 9126 output = pdist(input, "hamming") * input.shape[1] 9127 elif p == float("inf"): 9128 output = pdist(input, lambda x, y: np.abs(x - y).max()) 9129 else: 9130 output = pdist(input, "minkowski", p=p) 9131 return output.astype(input.dtype) 9132 9133def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs): 9134 make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) 9135 9136 yield SampleInput(make_input(())) 9137 yield SampleInput(make_input((2,))) 9138 yield SampleInput(make_input((2, 2))) 9139 yield SampleInput(make_input((2,)), offset=1) 9140 yield SampleInput(make_input((2,)), offset=-1) 9141 9142def sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs): 9143 unpool_name_to_pool_method_dict = { 9144 'nn.functional.max_unpool1d': torch.nn.functional.max_pool1d, 9145 'nn.functional.max_unpool2d': torch.nn.functional.max_pool2d, 9146 'nn.functional.max_unpool3d': torch.nn.functional.max_pool3d 9147 } 9148 9149 unpool_name_to_dim = { 9150 'nn.functional.max_unpool1d': 1, 9151 'nn.functional.max_unpool2d': 2, 9152 'nn.functional.max_unpool3d': 3 9153 } 9154 9155 unpool_to_pool_name_dict = {k: f'nn.functional.{v.__name__}' for k, v in 
        unpool_name_to_pool_method_dict.items()}

    pool_dim = unpool_name_to_dim[op_info.name]
    pool_method = unpool_name_to_pool_method_dict[op_info.name]

    pool_op_info = copy.copy(op_info)
    pool_op_info.name = unpool_to_pool_name_dict[op_info.name]

    for sample in sample_inputs_max_pool(pool_op_info, device, dtype, requires_grad, **kwargs):
        # shapes (C, ...) do not work as of now,
        # see https://github.com/pytorch/pytorch/issues/68337
        # TODO: remove once the issue is resolved
        if sample.input.dim() != pool_dim + 2:
            continue

        # No dilation > 1 for max_unpool,
        # see https://github.com/pytorch/pytorch/issues/68420
        if sample.kwargs['dilation'] != 1:
            continue

        # Can't unpool without indices
        if sample.kwargs['return_indices']:
            pool, indices = pool_method(sample.input, **sample.kwargs)
            # arg has to be a leaf
            arg = pool.detach().requires_grad_(requires_grad)
            sample_kwargs = {
                'kernel_size': sample.kwargs['kernel_size'],
                'stride': sample.kwargs['stride'],
                'padding': sample.kwargs['padding'],
                # output_size could be None, but we specify it explicitly
                # to compensate for the information lost in pooling due
                # to the floor/ceil operation used to compute the output shapes
                'output_size': sample.input.size()
            }

            yield SampleInput(arg, args=(indices,), kwargs=sample_kwargs)

def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwargs):
    for sample in sample_inputs_max_unpool(op_info, device, dtype, requires_grad, **kwargs):
        indices = sample.args[0]
        # The samples for max_unpool are generated with max_pool.
        # It could be that a single element from the max_pool's
        # input is mapped to several locations in its output.
        # This situation leads to failed gradchecks because
        # the finite difference algorithm perturbs the elements
        # of the output one by one, and not in equivalence classes
        # determined by whether two elements in the output come
        # from the same location in the input (simply put, whether
        # they have the same corresponding index).
        # So, there are two ways to resolve this issue:
        # 1. Extract a perturbation for one element and apply it to all
        #    the elements from the same equivalence class, or
        # 2. Make sure that the equivalence classes are all singletons,
        #    i.e. the index tensor consists of only unique indices.
        # Here we go with solution 2, the easiest of all.
        if indices.unique().numel() == indices.numel():
            yield sample

def sample_inputs_multi_head_attention_forward(opinfo, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    if requires_grad:
        # backward tests would take too long to complete, causing the job to time out.
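        # Note: with requires_grad the sweep below collapses to a single configuration
        # (batched input, shared in-projection weight, a 2d attention mask, one dropout
        # value); the else branch keeps the broader grid for forward-only runs.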
9219 bsz = 2 9220 is_batcheds = (True,) 9221 use_separate_proj_weights = (False,) 9222 emb_sizes = (2,) 9223 src_lens = (XS,) 9224 tgt_lens = (XS,) 9225 heads = (2,) 9226 dropouts = (0.5,) 9227 mask_types = ("2d",) 9228 else: 9229 bsz = 2 9230 is_batcheds = (False, True) 9231 use_separate_proj_weights = (False, True) 9232 emb_sizes = (2, 4) 9233 src_lens = (XS,) 9234 tgt_lens = (XS, S) 9235 heads = (1, 2) 9236 dropouts = (0.0, 0.5) 9237 mask_types = (None, "2d", "3d") 9238 9239 for is_batched, use_separate_proj_weight, mask_type, emb_size, src_len, tgt_len, num_heads, dropout_p in itertools.product( 9240 is_batcheds, use_separate_proj_weights, mask_types, emb_sizes, src_lens, tgt_lens, heads, dropouts 9241 ): 9242 attn_mask = None 9243 if mask_type == "2d": 9244 attn_mask = make_input(src_len, tgt_len) 9245 elif mask_type == "3d": 9246 attn_mask = make_input((bsz if is_batched else 1) * num_heads, src_len, tgt_len) 9247 9248 if is_batched: 9249 q = make_input(src_len, bsz, emb_size) 9250 k = make_input(tgt_len, bsz, emb_size) 9251 v = make_input(tgt_len, bsz, emb_size) 9252 else: 9253 q = make_input(src_len, emb_size) 9254 k = make_input(tgt_len, emb_size) 9255 v = make_input(tgt_len, emb_size) 9256 if use_separate_proj_weight: 9257 in_proj_weight = None 9258 q_proj_weight = make_input(emb_size, emb_size) 9259 k_proj_weight = make_input(emb_size, emb_size) 9260 v_proj_weight = make_input(emb_size, emb_size) 9261 else: 9262 in_proj_weight = make_input(emb_size * 3, emb_size) 9263 q_proj_weight = None 9264 k_proj_weight = None 9265 v_proj_weight = None 9266 9267 bias_k = make_input(emb_size) 9268 bias_v = make_input(emb_size) 9269 in_proj_bias = make_input(emb_size * 3) 9270 out_proj_weight = make_input(emb_size, emb_size) 9271 out_proj_bias = make_input(emb_size) 9272 sample_args = ( 9273 k, v, emb_size, num_heads, in_proj_weight, 9274 in_proj_bias, bias_k, bias_v, False, 9275 dropout_p, out_proj_weight, out_proj_bias 9276 ) 9277 sample_kwargs = { 9278 "q_proj_weight" : q_proj_weight, 9279 "k_proj_weight" : k_proj_weight, 9280 "v_proj_weight" : v_proj_weight, 9281 "attn_mask" : attn_mask, 9282 "training" : True if dropout_p > 0.0 else False, 9283 "use_separate_proj_weight" : use_separate_proj_weight 9284 } 9285 9286 yield SampleInput(q, args=sample_args, kwargs=sample_kwargs) 9287 9288 9289# Includes some values such that N * N won't be a multiple of 4, 9290# which should ensure we test the vectorized and non-vectorized 9291# kernel code paths. 9292NUM_SIZE0_TENSORS = 10000 9293foreach_num_tensors = [20, 23] if not TEST_WITH_SLOW else [23, 30, 300] 9294_foreach_inputs_default_kwargs = {"noncontiguous": False, "same_size": False, "low": None, "high": None} 9295 9296 9297class ForeachRightmostArgType(enum.Enum): 9298 TensorList = enum.auto() 9299 ScalarList = enum.auto() 9300 Scalar = enum.auto() 9301 Tensor = enum.auto() 9302 9303 9304class ForeachSampleInput(SampleInput): 9305 # For TensorList <op> Scalar/Tensor, we compute the reference 9306 # by converting it into TensorList <op> ScalarList/TensorList and 9307 # then converting into multiple Tensor <op> Scalar/Tensor. 
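    # A minimal sketch of the idea (illustrative only): torch._foreach_add(tensors, 2.0)
    # is compared against [torch.add(t, s) for t, s in zip(tensors, [2.0] * len(tensors))],
    # i.e. the rightmost Scalar is first expanded into a ScalarList of matching length.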
9308 # ref_args contains the args converted to TensorList <op> ScalarList/TensorList 9309 ref_args: Any 9310 disable_fastpath: bool 9311 9312 def __init__(self, *args, disable_fastpath=False, ref_args=None, **kwargs): 9313 super().__init__(*args, **kwargs) 9314 self.ref_args = ref_args or self.args 9315 self.disable_fastpath = disable_fastpath 9316 9317 9318class foreach_inputs_sample_func: 9319 def __init__( 9320 self, 9321 arity: int, 9322 rightmost_supports_scalar: bool, 9323 rightmost_supports_scalarlist: bool, 9324 rightmost_supports_tensor: bool = False, 9325 ) -> None: 9326 self.arity = arity 9327 self._set_rightmost_arg_types( 9328 rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor, 9329 ) 9330 self._intersperse_empty = (True, False) 9331 9332 def _set_rightmost_arg_types( 9333 self, 9334 rightmost_supports_scalar: bool, 9335 rightmost_supports_scalarlist: bool, 9336 rightmost_supports_tensor: bool, 9337 ) -> None: 9338 self._rightmost_arg_types = [ForeachRightmostArgType.TensorList] 9339 if self.arity > 1: 9340 if rightmost_supports_scalar: 9341 self._rightmost_arg_types.append(ForeachRightmostArgType.Scalar) 9342 if rightmost_supports_scalarlist: 9343 self._rightmost_arg_types.append(ForeachRightmostArgType.ScalarList) 9344 if rightmost_supports_tensor: 9345 self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) 9346 9347 def _sample_rightmost_arg( 9348 self, 9349 opinfo, 9350 rightmost_arg_type, 9351 device, 9352 dtype, 9353 num_tensors, 9354 allow_higher_dtype_scalars, 9355 **_foreach_inputs_kwargs, 9356 ): 9357 if rightmost_arg_type == ForeachRightmostArgType.TensorList: 9358 return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] 9359 if rightmost_arg_type == ForeachRightmostArgType.Tensor: 9360 return [make_tensor( 9361 (), device=device, dtype=dtype, 9362 noncontiguous=_foreach_inputs_kwargs["noncontiguous"], 9363 requires_grad=_foreach_inputs_kwargs.get("requires_grad", False), 9364 )] 9365 should_use_simpler_scalars = opinfo.name == "_foreach_pow" and dtype in (torch.float16, torch.bfloat16) 9366 9367 def sample_float(): 9368 s = random.random() 9369 if should_use_simpler_scalars: 9370 return 1.0 if s > 0.5 else 2.0 9371 else: 9372 return 1.0 - s 9373 9374 high = 2 if should_use_simpler_scalars else 9 9375 if rightmost_arg_type == ForeachRightmostArgType.ScalarList: 9376 scalarlist_list = [] 9377 scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) 9378 9379 if allow_higher_dtype_scalars or dtype.is_floating_point: 9380 scalarlist_list.append([sample_float() for _ in range(num_tensors)]) 9381 if allow_higher_dtype_scalars or dtype.is_complex: 9382 scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) 9383 scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) 9384 scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) 9385 return scalarlist_list 9386 if rightmost_arg_type == ForeachRightmostArgType.Scalar: 9387 scalars = [] 9388 scalars.append(random.randint(1, high + 1)) 9389 if allow_higher_dtype_scalars or dtype.is_floating_point: 9390 scalars.append(sample_float()) 9391 if allow_higher_dtype_scalars or dtype.is_complex: 9392 scalars.append(complex(sample_float(), sample_float())) 9393 scalars.append(True) 9394 return scalars 9395 raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") 9396 9397 def _should_disable_fastpath(self, opinfo, 
rightmost_arg, rightmost_arg_type, dtype): 9398 if self.arity == 1: 9399 if "foreach_abs" in opinfo.name and dtype in complex_types(): 9400 return True 9401 # unary 9402 if opinfo.ref in (torch.abs, torch.neg): 9403 return False 9404 if opinfo.ref_inplace in (torch.Tensor.zero_,): 9405 return False 9406 return dtype in integral_types_and(torch.bool) 9407 if self.arity < 2 or rightmost_arg_type == ForeachRightmostArgType.Tensor: 9408 return None 9409 if "foreach_pow" in opinfo.name and dtype in integral_types_and(torch.bool): 9410 return True 9411 if any( 9412 foreach_name in opinfo.name 9413 for foreach_name in ("foreach_clamp_max", "foreach_clamp_min", "foreach_maximum", "foreach_minimum") 9414 ) and dtype in integral_types_and(torch.bool): 9415 return True 9416 if rightmost_arg_type == ForeachRightmostArgType.TensorList: 9417 disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) 9418 if "foreach_add" in opinfo.name and dtype == torch.bool: 9419 disable_fastpath = True 9420 return disable_fastpath 9421 elif rightmost_arg_type == ForeachRightmostArgType.Scalar: 9422 disable_fastpath = "foreach_div" in opinfo.name and dtype in integral_types_and(torch.bool) 9423 if isinstance(rightmost_arg, bool): 9424 disable_fastpath |= dtype == torch.bool 9425 if opinfo.ref in (torch.add, torch.mul): 9426 disable_fastpath = False 9427 elif isinstance(rightmost_arg, int): 9428 disable_fastpath |= dtype == torch.bool 9429 elif isinstance(rightmost_arg, float): 9430 disable_fastpath |= dtype in integral_types_and(torch.bool) 9431 elif isinstance(rightmost_arg, complex): 9432 disable_fastpath |= dtype not in complex_types() 9433 else: 9434 raise AssertionError(f"Invalid scalar of type {rightmost_arg_type} - {rightmost_arg}") 9435 return disable_fastpath 9436 elif rightmost_arg_type == ForeachRightmostArgType.ScalarList: 9437 disable_fastpath = opinfo.ref == torch.div and dtype in integral_types_and(torch.bool) 9438 elmt_t = type(rightmost_arg[0]) 9439 has_same_type = all(isinstance(v, elmt_t) for v in rightmost_arg) 9440 if not has_same_type: 9441 return dtype not in complex_types() 9442 if isinstance(rightmost_arg[0], bool): 9443 if ("foreach_add" in opinfo.name or "foreach_mul" in opinfo.name) and dtype == torch.bool: 9444 disable_fastpath = False 9445 elif isinstance(rightmost_arg[0], int): 9446 disable_fastpath |= dtype == torch.bool 9447 elif isinstance(rightmost_arg[0], float): 9448 disable_fastpath |= dtype in integral_types_and(torch.bool) 9449 elif isinstance(rightmost_arg[0], complex): 9450 disable_fastpath |= dtype not in complex_types() 9451 else: 9452 raise AssertionError(f"Invalid scalarlist of {rightmost_arg}") 9453 return disable_fastpath 9454 else: 9455 raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") 9456 9457 def _sample_kwargs(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): 9458 kwargs = {} 9459 if rightmost_arg_type == ForeachRightmostArgType.TensorList and opinfo.supports_alpha_param: 9460 if dtype in integral_types_and(torch.bool): 9461 kwargs["alpha"] = 3 9462 elif dtype.is_complex: 9463 kwargs["alpha"] = complex(3, 3) 9464 else: 9465 kwargs["alpha"] = 3.14 9466 if self.arity > 1: 9467 kwargs["disable_fastpath"] = self._should_disable_fastpath(opinfo, rightmost_arg, rightmost_arg_type, dtype) 9468 return kwargs 9469 9470 def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): 9471 assert "num_input_tensors" not in kwargs 9472 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, 
v in _foreach_inputs_default_kwargs.items()} 9473 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9474 allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) 9475 for rightmost_arg_type in self._rightmost_arg_types: 9476 zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) 9477 zero_size_foreach_inputs_kwargs["zero_size"] = True 9478 input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) 9479 if self.arity > 1: 9480 args = [ 9481 sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) 9482 for _ in range(self.arity - 2) 9483 ] 9484 args.append( 9485 self._sample_rightmost_arg( 9486 opinfo, 9487 ForeachRightmostArgType.TensorList, 9488 device, 9489 dtype, 9490 NUM_SIZE0_TENSORS, 9491 allow_higher_dtype_scalars=allow_higher_dtype_scalars, 9492 **zero_size_foreach_inputs_kwargs, 9493 )[0]) 9494 kwargs = self._sample_kwargs( 9495 opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) 9496 else: 9497 args = [] 9498 kwargs = {} 9499 if opinfo.ref in (torch.abs, torch.neg): 9500 kwargs["disable_fastpath"] = False 9501 else: 9502 kwargs["disable_fastpath"] = dtype in integral_types_and(torch.bool) 9503 yield ForeachSampleInput(input, *args, **kwargs) 9504 9505 def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): 9506 num_input_tensors_specified = "num_input_tensors" in kwargs 9507 num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors 9508 assert isinstance(num_input_tensors, list) 9509 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} 9510 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9511 _foreach_inputs_kwargs["zero_size"] = False 9512 allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) 9513 9514 # add empty tensor interspersion to test fully fixing #100701 9515 for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( 9516 num_input_tensors, self._rightmost_arg_types, self._intersperse_empty): 9517 if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'): 9518 # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy 9519 continue 9520 _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors 9521 input = sample_inputs_foreach( 9522 None, device, dtype, num_tensors, **_foreach_inputs_kwargs) 9523 args = [] 9524 if self.arity > 1: 9525 args = [ 9526 sample_inputs_foreach( 9527 None, device, dtype, num_tensors, **_foreach_inputs_kwargs) 9528 for _ in range(self.arity - 2) 9529 ] 9530 rightmost_arg_list = self._sample_rightmost_arg( 9531 opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, 9532 **_foreach_inputs_kwargs) 9533 for rightmost_arg in rightmost_arg_list: 9534 args.append(rightmost_arg) 9535 kwargs = self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype) 9536 ref_args = args 9537 if rightmost_arg_type in (ForeachRightmostArgType.Scalar, ForeachRightmostArgType.Tensor): 9538 ref_args = args[:-1] + [[args[-1] for _ in range(num_tensors)]] 9539 sample = ForeachSampleInput(input, *args, ref_args=ref_args, **kwargs) 9540 yield sample 9541 args.pop() 9542 else: 9543 yield ForeachSampleInput( 9544 input, 9545 *args, 9546 disable_fastpath=self._should_disable_fastpath(opinfo, None, None, dtype), 9547 ) 9548 9549 9550class 
foreach_max_sample_func(foreach_inputs_sample_func): 9551 def __init__( 9552 self, 9553 arity: int, 9554 rightmost_supports_scalar: bool, 9555 rightmost_supports_scalarlist: bool, 9556 rightmost_supports_tensor: bool = False, 9557 ) -> None: 9558 super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist, rightmost_supports_tensor) 9559 self._intersperse_empty = (False,) 9560 9561 def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): 9562 return [] 9563 9564 def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): 9565 return False 9566 9567 9568class foreach_norm_sample_func(foreach_inputs_sample_func): 9569 def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): 9570 assert "num_input_tensors" not in kwargs 9571 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} 9572 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9573 for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')): 9574 input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) 9575 disable_fastpath = True 9576 if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): 9577 disable_fastpath = False 9578 yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath) 9579 9580 def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): 9581 num_input_tensors = kwargs.pop("num_input_tensors", foreach_num_tensors) 9582 assert isinstance(num_input_tensors, list) 9583 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} 9584 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9585 _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) 9586 9587 for num_tensors, ord, out_dtype in product( 9588 num_input_tensors, 9589 (0, 1, 2, -1, -2, float('inf'), float('-inf')), 9590 (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,), 9591 ): 9592 input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) 9593 disable_fastpath = True 9594 if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): 9595 disable_fastpath = False 9596 yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype) 9597 9598 # Also test nan propagation with a single tensor, but skip autograd testing 9599 if not requires_grad: 9600 nan_inputs = [ 9601 [float('nan')], 9602 [float('nan'), 1.0], 9603 [1.0, float('nan')], 9604 [1.0, 2.0, 3.0, float('nan'), float('nan'), 7.0, float('nan'), float('nan'), -1.5, 6.0], 9605 [7.0, 3.0, float('nan'), float('nan'), -1.5, 6.0], 9606 [3.0, float('nan'), float('nan'), -1.5, 6.0], 9607 ] 9608 for input in nan_inputs: 9609 x = torch.tensor(input, device=device) 9610 disable_fastpath = True 9611 if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): 9612 disable_fastpath = False 9613 yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) 9614 9615 9616class foreach_pointwise_sample_func(foreach_inputs_sample_func): 9617 9618 def __init__( 9619 self, 9620 arity: int = 3, 9621 rightmost_supports_scalar: bool = False, 9622 rightmost_supports_scalarlist: bool = False, 9623 ): 9624 super().__init__(arity, rightmost_supports_scalar, rightmost_supports_scalarlist) 9625 9626 def 
_should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): 9627 return dtype in integral_types_and(torch.bool) and opinfo.ref in (torch.addcmul,) 9628 9629 def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, **kwargs): 9630 assert "num_input_tensors" not in kwargs 9631 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} 9632 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9633 # zero_size tensor 9634 input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) 9635 args = [ 9636 sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) 9637 for _ in range(2) 9638 ] 9639 if "scalars" in kwargs: 9640 del kwargs["scalars"] 9641 kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype)) 9642 yield ForeachSampleInput(input, *args, **kwargs) 9643 9644 def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): 9645 num_input_tensors_specified = "num_input_tensors" in kwargs 9646 num_input_tensors = kwargs.pop("num_input_tensors") if num_input_tensors_specified else foreach_num_tensors 9647 assert isinstance(num_input_tensors, list) 9648 _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} 9649 _foreach_inputs_kwargs["requires_grad"] = requires_grad 9650 allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) 9651 9652 for num_tensors, rightmost_arg_type in itertools.product(num_input_tensors, self._rightmost_arg_types): 9653 input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) 9654 args = [ 9655 sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) 9656 for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) 9657 ] 9658 rightmost_arg_list = self._sample_rightmost_arg( 9659 opinfo, 9660 rightmost_arg_type, 9661 device, 9662 dtype, 9663 num_tensors, 9664 zero_size=False, 9665 allow_higher_dtype_scalars=allow_higher_dtype_scalars, 9666 **_foreach_inputs_kwargs, 9667 ) 9668 for rightmost_arg in rightmost_arg_list: 9669 kwargs = {} 9670 if rightmost_arg_type == ForeachRightmostArgType.TensorList: 9671 args.append(rightmost_arg) 9672 elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]: 9673 kwargs["scalars"] = rightmost_arg 9674 else: 9675 kwargs["value"] = rightmost_arg 9676 kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype)) 9677 assert len(args) == 2, f"{len(args)=}" 9678 sample = ForeachSampleInput(input, *args, **kwargs) 9679 yield sample 9680 if rightmost_arg_type == ForeachRightmostArgType.TensorList: 9681 args.pop() 9682 9683 9684foreach_unary_op_db: List[OpInfo] = [ 9685 ForeachFuncInfo( 9686 'exp', 9687 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9688 backward_requires_result=True, 9689 supports_autograd=True, 9690 supports_inplace_autograd=True, 9691 supports_forward_ad=True, 9692 decorators=( 9693 DecorateInfo( 9694 unittest.expectedFailure, 9695 "TestMeta", 9696 "test_dispatch_meta_inplace", 9697 dtypes=integral_types_and(torch.bool,), 9698 ), 9699 DecorateInfo( 9700 unittest.expectedFailure, 9701 "TestMeta", 9702 "test_dispatch_symbolic_meta_inplace", 9703 dtypes=integral_types_and(torch.bool,), 9704 ), 9705 DecorateInfo( 9706 unittest.expectedFailure, 9707 
"TestMeta", 9708 "test_meta_inplace", 9709 dtypes=integral_types_and(torch.bool,), 9710 ), 9711 ), 9712 ), 9713 ForeachFuncInfo( 9714 'acos', 9715 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9716 supports_autograd=True, 9717 supports_inplace_autograd=True, 9718 supports_forward_ad=True, 9719 decorators=( 9720 DecorateInfo( 9721 unittest.expectedFailure, 9722 "TestMeta", 9723 "test_dispatch_meta_inplace", 9724 dtypes=integral_types_and(torch.bool,), 9725 ), 9726 DecorateInfo( 9727 unittest.expectedFailure, 9728 "TestMeta", 9729 "test_dispatch_symbolic_meta_inplace", 9730 dtypes=integral_types_and(torch.bool,), 9731 ), 9732 DecorateInfo( 9733 unittest.expectedFailure, 9734 "TestMeta", 9735 "test_meta_inplace", 9736 dtypes=integral_types_and(torch.bool,), 9737 ), 9738 ), 9739 ), 9740 ForeachFuncInfo( 9741 'asin', 9742 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9743 supports_autograd=True, 9744 supports_inplace_autograd=True, 9745 supports_forward_ad=True, 9746 decorators=( 9747 DecorateInfo( 9748 unittest.expectedFailure, 9749 "TestMeta", 9750 "test_dispatch_meta_inplace", 9751 dtypes=integral_types_and(torch.bool,), 9752 ), 9753 DecorateInfo( 9754 unittest.expectedFailure, 9755 "TestMeta", 9756 "test_dispatch_symbolic_meta_inplace", 9757 dtypes=integral_types_and(torch.bool,), 9758 ), 9759 DecorateInfo( 9760 unittest.expectedFailure, 9761 "TestMeta", 9762 "test_meta_inplace", 9763 dtypes=integral_types_and(torch.bool,), 9764 ), 9765 ), 9766 ), 9767 ForeachFuncInfo( 9768 'atan', 9769 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9770 supports_autograd=True, 9771 supports_inplace_autograd=True, 9772 supports_forward_ad=True, 9773 decorators=( 9774 DecorateInfo( 9775 unittest.expectedFailure, 9776 "TestMeta", 9777 "test_dispatch_meta_inplace", 9778 dtypes=integral_types_and(torch.bool,), 9779 ), 9780 DecorateInfo( 9781 unittest.expectedFailure, 9782 "TestMeta", 9783 "test_dispatch_symbolic_meta_inplace", 9784 dtypes=integral_types_and(torch.bool,), 9785 ), 9786 DecorateInfo( 9787 unittest.expectedFailure, 9788 "TestMeta", 9789 "test_meta_inplace", 9790 dtypes=integral_types_and(torch.bool,), 9791 ), 9792 ), 9793 ), 9794 ForeachFuncInfo( 9795 'cos', 9796 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9797 supports_autograd=True, 9798 supports_inplace_autograd=True, 9799 supports_forward_ad=True, 9800 decorators=( 9801 DecorateInfo( 9802 unittest.expectedFailure, 9803 "TestMeta", 9804 "test_dispatch_meta_inplace", 9805 dtypes=integral_types_and(torch.bool,), 9806 ), 9807 DecorateInfo( 9808 unittest.expectedFailure, 9809 "TestMeta", 9810 "test_dispatch_symbolic_meta_inplace", 9811 dtypes=integral_types_and(torch.bool,), 9812 ), 9813 DecorateInfo( 9814 unittest.expectedFailure, 9815 "TestMeta", 9816 "test_meta_inplace", 9817 dtypes=integral_types_and(torch.bool,), 9818 ), 9819 ), 9820 ), 9821 ForeachFuncInfo( 9822 'cosh', 9823 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9824 supports_autograd=True, 9825 supports_inplace_autograd=True, 9826 supports_forward_ad=True, 9827 decorators=( 9828 DecorateInfo( 9829 unittest.expectedFailure, 9830 "TestMeta", 9831 "test_dispatch_meta_inplace", 9832 dtypes=integral_types_and(torch.bool,), 9833 ), 9834 DecorateInfo( 9835 unittest.expectedFailure, 9836 "TestMeta", 9837 "test_dispatch_symbolic_meta_inplace", 9838 dtypes=integral_types_and(torch.bool,), 9839 ), 9840 DecorateInfo( 9841 unittest.expectedFailure, 9842 "TestMeta", 9843 "test_meta_inplace", 9844 
dtypes=integral_types_and(torch.bool,), 9845 ), 9846 ), 9847 ), 9848 ForeachFuncInfo( 9849 'log', 9850 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9851 supports_autograd=True, 9852 supports_inplace_autograd=True, 9853 supports_forward_ad=True, 9854 decorators=( 9855 DecorateInfo( 9856 unittest.expectedFailure, 9857 "TestMeta", 9858 "test_dispatch_meta_inplace", 9859 dtypes=integral_types_and(torch.bool,), 9860 ), 9861 DecorateInfo( 9862 unittest.expectedFailure, 9863 "TestMeta", 9864 "test_dispatch_symbolic_meta_inplace", 9865 dtypes=integral_types_and(torch.bool,), 9866 ), 9867 DecorateInfo( 9868 unittest.expectedFailure, 9869 "TestMeta", 9870 "test_meta_inplace", 9871 dtypes=integral_types_and(torch.bool,), 9872 ), 9873 ), 9874 ), 9875 ForeachFuncInfo( 9876 'log10', 9877 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9878 supports_autograd=True, 9879 supports_inplace_autograd=True, 9880 supports_forward_ad=True, 9881 decorators=( 9882 DecorateInfo( 9883 unittest.expectedFailure, 9884 "TestMeta", 9885 "test_dispatch_meta_inplace", 9886 dtypes=integral_types_and(torch.bool,), 9887 ), 9888 DecorateInfo( 9889 unittest.expectedFailure, 9890 "TestMeta", 9891 "test_dispatch_symbolic_meta_inplace", 9892 dtypes=integral_types_and(torch.bool,), 9893 ), 9894 DecorateInfo( 9895 unittest.expectedFailure, 9896 "TestMeta", 9897 "test_meta_inplace", 9898 dtypes=integral_types_and(torch.bool,), 9899 ), 9900 ), 9901 ), 9902 ForeachFuncInfo( 9903 'log2', 9904 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9905 supports_autograd=True, 9906 supports_inplace_autograd=True, 9907 supports_forward_ad=True, 9908 decorators=( 9909 DecorateInfo( 9910 unittest.expectedFailure, 9911 "TestMeta", 9912 "test_dispatch_meta_inplace", 9913 dtypes=integral_types_and(torch.bool,), 9914 ), 9915 DecorateInfo( 9916 unittest.expectedFailure, 9917 "TestMeta", 9918 "test_dispatch_symbolic_meta_inplace", 9919 dtypes=integral_types_and(torch.bool,), 9920 ), 9921 DecorateInfo( 9922 unittest.expectedFailure, 9923 "TestMeta", 9924 "test_meta_inplace", 9925 dtypes=integral_types_and(torch.bool,), 9926 ), 9927 ), 9928 ), 9929 ForeachFuncInfo( 9930 'tan', 9931 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9932 backward_requires_result=True, 9933 supports_autograd=True, 9934 supports_inplace_autograd=True, 9935 supports_forward_ad=True, 9936 decorators=( 9937 # due to https://github.com/pytorch/pytorch/pull/102427 enabling jiterator for complex 9938 DecorateInfo( 9939 unittest.expectedFailure, 9940 "TestMeta", 9941 "test_dispatch_meta_inplace", 9942 dtypes=integral_types_and(torch.bool,), 9943 ), 9944 DecorateInfo( 9945 unittest.expectedFailure, 9946 "TestMeta", 9947 "test_dispatch_symbolic_meta_inplace", 9948 dtypes=integral_types_and(torch.bool,), 9949 ), 9950 DecorateInfo( 9951 unittest.expectedFailure, 9952 "TestMeta", 9953 "test_meta_inplace", 9954 dtypes=integral_types_and(torch.bool,), 9955 ), 9956 DecorateInfo( 9957 toleranceOverride( 9958 { 9959 torch.complex64: tol(atol=3e-04, rtol=2e-05) 9960 } 9961 ), 9962 'TestForeach', 9963 'test_parity', 9964 device_type='cuda' 9965 ), 9966 ), 9967 ), 9968 ForeachFuncInfo( 9969 'tanh', 9970 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 9971 backward_requires_result=True, 9972 supports_autograd=True, 9973 supports_inplace_autograd=True, 9974 supports_forward_ad=True, 9975 decorators=( 9976 DecorateInfo( 9977 unittest.expectedFailure, 9978 "TestMeta", 9979 "test_dispatch_meta_inplace", 9980 
dtypes=integral_types_and(torch.bool,), 9981 ), 9982 DecorateInfo( 9983 unittest.expectedFailure, 9984 "TestMeta", 9985 "test_dispatch_symbolic_meta_inplace", 9986 dtypes=integral_types_and(torch.bool,), 9987 ), 9988 DecorateInfo( 9989 unittest.expectedFailure, 9990 "TestMeta", 9991 "test_meta_inplace", 9992 dtypes=integral_types_and(torch.bool,), 9993 ), 9994 DecorateInfo( 9995 toleranceOverride( 9996 {torch.complex64: tol(atol=5e-03, rtol=1e-04)} 9997 ), 9998 'TestForeach', 9999 'test_parity', 10000 device_type='cuda' 10001 ), 10002 ), 10003 ), 10004 ForeachFuncInfo( 10005 'sin', 10006 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10007 supports_autograd=True, 10008 supports_inplace_autograd=True, 10009 supports_forward_ad=True, 10010 decorators=( 10011 DecorateInfo( 10012 unittest.expectedFailure, 10013 "TestMeta", 10014 "test_dispatch_meta_inplace", 10015 dtypes=integral_types_and(torch.bool,), 10016 ), 10017 DecorateInfo( 10018 unittest.expectedFailure, 10019 "TestMeta", 10020 "test_dispatch_symbolic_meta_inplace", 10021 dtypes=integral_types_and(torch.bool,), 10022 ), 10023 DecorateInfo( 10024 unittest.expectedFailure, 10025 "TestMeta", 10026 "test_meta_inplace", 10027 dtypes=integral_types_and(torch.bool,), 10028 ), 10029 ), 10030 ), 10031 ForeachFuncInfo( 10032 'sinh', 10033 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10034 supports_autograd=True, 10035 supports_inplace_autograd=True, 10036 supports_forward_ad=True, 10037 decorators=( 10038 DecorateInfo( 10039 unittest.expectedFailure, 10040 "TestMeta", 10041 "test_dispatch_meta_inplace", 10042 dtypes=integral_types_and(torch.bool), 10043 ), 10044 DecorateInfo( 10045 unittest.expectedFailure, 10046 "TestMeta", 10047 "test_dispatch_symbolic_meta_inplace", 10048 dtypes=integral_types_and(torch.bool), 10049 ), 10050 DecorateInfo( 10051 unittest.expectedFailure, 10052 "TestMeta", 10053 "test_meta_inplace", 10054 dtypes=integral_types_and(torch.bool), 10055 ), 10056 ), 10057 ), 10058 ForeachFuncInfo( 10059 'neg', 10060 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10061 supports_autograd=True, 10062 supports_inplace_autograd=True, 10063 supports_forward_ad=True, 10064 decorators=( 10065 DecorateInfo( 10066 unittest.expectedFailure, 10067 "TestMeta", 10068 "test_dispatch_meta_inplace", 10069 dtypes=(torch.bool,), 10070 ), 10071 DecorateInfo( 10072 unittest.expectedFailure, 10073 "TestMeta", 10074 "test_dispatch_meta_outplace", 10075 dtypes=(torch.bool,), 10076 ), 10077 DecorateInfo( 10078 unittest.expectedFailure, 10079 "TestMeta", 10080 "test_dispatch_symbolic_meta_inplace", 10081 dtypes=(torch.bool,), 10082 ), 10083 DecorateInfo( 10084 unittest.expectedFailure, 10085 "TestMeta", 10086 "test_dispatch_symbolic_meta_outplace", 10087 dtypes=(torch.bool,), 10088 ), 10089 DecorateInfo( 10090 unittest.expectedFailure, 10091 "TestMeta", 10092 "test_meta_inplace", 10093 dtypes=(torch.bool,), 10094 ), 10095 DecorateInfo( 10096 unittest.expectedFailure, 10097 "TestMeta", 10098 "test_meta_outplace", 10099 dtypes=(torch.bool,), 10100 ), 10101 DecorateInfo( 10102 unittest.expectedFailure, 10103 "TestForeach", 10104 "test_unary_op_tensors_on_different_devices", 10105 device_type="cuda", 10106 dtypes=(torch.bool,), 10107 ), 10108 ), 10109 ), 10110 ForeachFuncInfo( 10111 'sqrt', 10112 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10113 supports_autograd=True, 10114 supports_inplace_autograd=True, 10115 supports_forward_ad=True, 10116 backward_requires_result=True, 10117 
decorators=( 10118 DecorateInfo( 10119 unittest.expectedFailure, 10120 "TestMeta", 10121 "test_dispatch_meta_inplace", 10122 dtypes=integral_types_and(torch.bool), 10123 ), 10124 DecorateInfo( 10125 unittest.expectedFailure, 10126 "TestMeta", 10127 "test_dispatch_symbolic_meta_inplace", 10128 dtypes=integral_types_and(torch.bool), 10129 ), 10130 DecorateInfo( 10131 unittest.expectedFailure, 10132 "TestMeta", 10133 "test_meta_inplace", 10134 dtypes=integral_types_and(torch.bool), 10135 ), 10136 ), 10137 ), 10138 ForeachFuncInfo( 10139 'ceil', 10140 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10141 supports_autograd=True, 10142 supports_inplace_autograd=True, 10143 supports_forward_ad=True, 10144 decorators=( 10145 DecorateInfo( 10146 unittest.expectedFailure, 10147 "TestMeta", 10148 "test_dispatch_meta_inplace", 10149 dtypes=complex_types_and(torch.bool), 10150 ), 10151 DecorateInfo( 10152 unittest.expectedFailure, 10153 "TestMeta", 10154 "test_dispatch_meta_outplace", 10155 dtypes=complex_types_and(torch.bool), 10156 ), 10157 DecorateInfo( 10158 unittest.expectedFailure, 10159 "TestMeta", 10160 "test_dispatch_symbolic_meta_inplace", 10161 dtypes=complex_types_and(torch.bool), 10162 ), 10163 DecorateInfo( 10164 unittest.expectedFailure, 10165 "TestMeta", 10166 "test_dispatch_symbolic_meta_outplace", 10167 dtypes=complex_types_and(torch.bool), 10168 ), 10169 DecorateInfo( 10170 unittest.expectedFailure, 10171 "TestMeta", 10172 "test_meta_inplace", 10173 dtypes=complex_types_and(torch.bool), 10174 ), 10175 DecorateInfo( 10176 unittest.expectedFailure, 10177 "TestMeta", 10178 "test_meta_outplace", 10179 dtypes=complex_types_and(torch.bool), 10180 ), 10181 DecorateInfo( 10182 unittest.expectedFailure, 10183 "TestForeach", 10184 "test_autodiff", 10185 device_type="cuda", 10186 dtypes=(torch.complex128,), 10187 ), 10188 ), 10189 ), 10190 ForeachFuncInfo( 10191 'erf', 10192 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10193 supports_autograd=True, 10194 supports_inplace_autograd=True, 10195 supports_forward_ad=True, 10196 decorators=( 10197 DecorateInfo( 10198 unittest.expectedFailure, 10199 "TestMeta", 10200 "test_dispatch_meta_inplace", 10201 dtypes=integral_types_and(torch.bool) + complex_types(), 10202 ), 10203 DecorateInfo( 10204 unittest.expectedFailure, 10205 "TestMeta", 10206 "test_dispatch_meta_outplace", 10207 dtypes=complex_types(), 10208 ), 10209 DecorateInfo( 10210 unittest.expectedFailure, 10211 "TestMeta", 10212 "test_dispatch_symbolic_meta_inplace", 10213 dtypes=integral_types_and(torch.bool) + complex_types(), 10214 ), 10215 DecorateInfo( 10216 unittest.expectedFailure, 10217 "TestMeta", 10218 "test_dispatch_symbolic_meta_outplace", 10219 dtypes=complex_types(), 10220 ), 10221 DecorateInfo( 10222 unittest.expectedFailure, 10223 "TestMeta", 10224 "test_meta_inplace", 10225 dtypes=integral_types_and(torch.bool) + complex_types(), 10226 ), 10227 DecorateInfo( 10228 unittest.expectedFailure, 10229 "TestMeta", 10230 "test_meta_outplace", 10231 dtypes=complex_types(), 10232 ), 10233 DecorateInfo( 10234 unittest.expectedFailure, 10235 "TestForeach", 10236 "test_autodiff", 10237 device_type="cuda", 10238 dtypes=(torch.complex128,), 10239 ), 10240 ), 10241 ), 10242 ForeachFuncInfo( 10243 'erfc', 10244 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10245 supports_autograd=True, 10246 supports_inplace_autograd=True, 10247 supports_forward_ad=True, 10248 decorators=( 10249 DecorateInfo( 10250 unittest.expectedFailure, 10251 "TestMeta", 
10252 "test_dispatch_meta_inplace", 10253 dtypes=integral_types_and(torch.bool) + complex_types(), 10254 ), 10255 DecorateInfo( 10256 unittest.expectedFailure, 10257 "TestMeta", 10258 "test_dispatch_meta_outplace", 10259 dtypes=complex_types(), 10260 ), 10261 DecorateInfo( 10262 unittest.expectedFailure, 10263 "TestMeta", 10264 "test_dispatch_symbolic_meta_inplace", 10265 dtypes=integral_types_and(torch.bool) + complex_types(), 10266 ), 10267 DecorateInfo( 10268 unittest.expectedFailure, 10269 "TestMeta", 10270 "test_dispatch_symbolic_meta_outplace", 10271 dtypes=complex_types(), 10272 ), 10273 DecorateInfo( 10274 unittest.expectedFailure, 10275 "TestMeta", 10276 "test_meta_inplace", 10277 dtypes=integral_types_and(torch.bool) + complex_types(), 10278 ), 10279 DecorateInfo( 10280 unittest.expectedFailure, 10281 "TestMeta", 10282 "test_meta_outplace", 10283 dtypes=complex_types(), 10284 ), 10285 DecorateInfo( 10286 unittest.expectedFailure, 10287 "TestForeach", 10288 "test_autodiff", 10289 device_type="cuda", 10290 dtypes=(torch.complex128,), 10291 ), 10292 ), 10293 ), 10294 ForeachFuncInfo( 10295 'expm1', 10296 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10297 supports_autograd=True, 10298 supports_inplace_autograd=True, 10299 supports_forward_ad=True, 10300 backward_requires_result=True, 10301 decorators=( 10302 DecorateInfo( 10303 unittest.expectedFailure, 10304 "TestMeta", 10305 "test_dispatch_meta_inplace", 10306 dtypes=integral_types_and(torch.bool), 10307 ), 10308 DecorateInfo( 10309 unittest.expectedFailure, 10310 "TestMeta", 10311 "test_dispatch_symbolic_meta_inplace", 10312 dtypes=integral_types_and(torch.bool), 10313 ), 10314 DecorateInfo( 10315 unittest.expectedFailure, 10316 "TestMeta", 10317 "test_meta_inplace", 10318 dtypes=integral_types_and(torch.bool), 10319 ), 10320 ), 10321 ), 10322 ForeachFuncInfo( 10323 'floor', 10324 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10325 supports_autograd=True, 10326 supports_inplace_autograd=True, 10327 supports_forward_ad=True, 10328 decorators=( 10329 DecorateInfo( 10330 unittest.expectedFailure, 10331 "TestMeta", 10332 "test_dispatch_meta_inplace", 10333 dtypes=complex_types_and(torch.bool), 10334 ), 10335 DecorateInfo( 10336 unittest.expectedFailure, 10337 "TestMeta", 10338 "test_dispatch_meta_outplace", 10339 dtypes=complex_types_and(torch.bool), 10340 ), 10341 DecorateInfo( 10342 unittest.expectedFailure, 10343 "TestMeta", 10344 "test_dispatch_symbolic_meta_inplace", 10345 dtypes=complex_types_and(torch.bool), 10346 ), 10347 DecorateInfo( 10348 unittest.expectedFailure, 10349 "TestMeta", 10350 "test_dispatch_symbolic_meta_outplace", 10351 dtypes=complex_types_and(torch.bool), 10352 ), 10353 DecorateInfo( 10354 unittest.expectedFailure, 10355 "TestMeta", 10356 "test_meta_inplace", 10357 dtypes=complex_types_and(torch.bool), 10358 ), 10359 DecorateInfo( 10360 unittest.expectedFailure, 10361 "TestMeta", 10362 "test_meta_outplace", 10363 dtypes=complex_types_and(torch.bool), 10364 ), 10365 DecorateInfo( 10366 unittest.expectedFailure, 10367 "TestForeach", 10368 "test_autodiff", 10369 device_type="cuda", 10370 dtypes=(torch.complex128,), 10371 ), 10372 ), 10373 ), 10374 ForeachFuncInfo( 10375 'log1p', 10376 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10377 supports_autograd=True, 10378 supports_inplace_autograd=True, 10379 supports_forward_ad=True, 10380 decorators=( 10381 DecorateInfo( 10382 unittest.expectedFailure, 10383 "TestMeta", 10384 "test_dispatch_meta_inplace", 10385 
dtypes=integral_types_and(torch.bool), 10386 ), 10387 DecorateInfo( 10388 unittest.expectedFailure, 10389 "TestMeta", 10390 "test_dispatch_symbolic_meta_inplace", 10391 dtypes=integral_types_and(torch.bool), 10392 ), 10393 DecorateInfo( 10394 unittest.expectedFailure, 10395 "TestMeta", 10396 "test_meta_inplace", 10397 dtypes=integral_types_and(torch.bool), 10398 ), 10399 ), 10400 ), 10401 ForeachFuncInfo( 10402 'round', 10403 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10404 supports_autograd=True, 10405 supports_inplace_autograd=True, 10406 supports_forward_ad=True, 10407 decorators=( 10408 DecorateInfo( 10409 unittest.expectedFailure, 10410 "TestMeta", 10411 "test_dispatch_meta_inplace", 10412 dtypes=complex_types_and(torch.bool), 10413 ), 10414 DecorateInfo( 10415 unittest.expectedFailure, 10416 "TestMeta", 10417 "test_dispatch_meta_outplace", 10418 dtypes=complex_types_and(torch.bool), 10419 ), 10420 DecorateInfo( 10421 unittest.expectedFailure, 10422 "TestMeta", 10423 "test_dispatch_symbolic_meta_inplace", 10424 dtypes=complex_types_and(torch.bool), 10425 ), 10426 DecorateInfo( 10427 unittest.expectedFailure, 10428 "TestMeta", 10429 "test_dispatch_symbolic_meta_outplace", 10430 dtypes=complex_types_and(torch.bool), 10431 ), 10432 DecorateInfo( 10433 unittest.expectedFailure, 10434 "TestMeta", 10435 "test_meta_inplace", 10436 dtypes=complex_types_and(torch.bool), 10437 ), 10438 DecorateInfo( 10439 unittest.expectedFailure, 10440 "TestMeta", 10441 "test_meta_outplace", 10442 dtypes=complex_types_and(torch.bool), 10443 ), 10444 DecorateInfo( 10445 unittest.expectedFailure, 10446 "TestForeach", 10447 "test_autodiff", 10448 device_type="cuda", 10449 dtypes=(torch.complex128,), 10450 ), 10451 ), 10452 ), 10453 ForeachFuncInfo( 10454 'frac', 10455 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10456 supports_autograd=True, 10457 supports_inplace_autograd=True, 10458 supports_forward_ad=True, 10459 decorators=( 10460 DecorateInfo( 10461 unittest.expectedFailure, 10462 "TestMeta", 10463 "test_dispatch_meta_inplace", 10464 dtypes=integral_types_and(torch.bool) + complex_types(), 10465 ), 10466 DecorateInfo( 10467 unittest.expectedFailure, 10468 "TestMeta", 10469 "test_dispatch_meta_outplace", 10470 dtypes=integral_types_and(torch.bool) + complex_types(), 10471 ), 10472 DecorateInfo( 10473 unittest.expectedFailure, 10474 "TestMeta", 10475 "test_dispatch_symbolic_meta_inplace", 10476 dtypes=integral_types_and(torch.bool) + complex_types(), 10477 ), 10478 DecorateInfo( 10479 unittest.expectedFailure, 10480 "TestMeta", 10481 "test_dispatch_symbolic_meta_outplace", 10482 dtypes=integral_types_and(torch.bool) + complex_types(), 10483 ), 10484 DecorateInfo( 10485 unittest.expectedFailure, 10486 "TestMeta", 10487 "test_meta_inplace", 10488 dtypes=integral_types_and(torch.bool) + complex_types(), 10489 ), 10490 DecorateInfo( 10491 unittest.expectedFailure, 10492 "TestMeta", 10493 "test_meta_outplace", 10494 dtypes=integral_types_and(torch.bool) + complex_types(), 10495 ), 10496 DecorateInfo( 10497 unittest.expectedFailure, 10498 "TestForeach", 10499 "test_autodiff", 10500 device_type="cuda", 10501 dtypes=(torch.complex128,), 10502 ), 10503 ), 10504 ), 10505 ForeachFuncInfo( 10506 'reciprocal', 10507 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10508 supports_autograd=True, 10509 supports_inplace_autograd=True, 10510 supports_forward_ad=True, 10511 backward_requires_result=True, 10512 decorators=( 10513 DecorateInfo( 10514 unittest.expectedFailure, 
10515 "TestMeta", 10516 "test_dispatch_meta_inplace", 10517 dtypes=integral_types_and(torch.bool), 10518 ), 10519 DecorateInfo( 10520 unittest.expectedFailure, 10521 "TestMeta", 10522 "test_dispatch_symbolic_meta_inplace", 10523 dtypes=integral_types_and(torch.bool), 10524 ), 10525 DecorateInfo( 10526 unittest.expectedFailure, 10527 "TestMeta", 10528 "test_meta_inplace", 10529 dtypes=integral_types_and(torch.bool), 10530 ), 10531 ), 10532 ), 10533 ForeachFuncInfo( 10534 'sigmoid', 10535 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10536 supports_autograd=True, 10537 supports_inplace_autograd=True, 10538 supports_forward_ad=True, 10539 backward_requires_result=True, 10540 decorators=( 10541 DecorateInfo( 10542 unittest.expectedFailure, 10543 "TestMeta", 10544 "test_dispatch_meta_inplace", 10545 dtypes=integral_types_and(torch.bool), 10546 ), 10547 DecorateInfo( 10548 unittest.expectedFailure, 10549 "TestMeta", 10550 "test_dispatch_symbolic_meta_inplace", 10551 dtypes=integral_types_and(torch.bool), 10552 ), 10553 DecorateInfo( 10554 unittest.expectedFailure, 10555 "TestMeta", 10556 "test_meta_inplace", 10557 dtypes=integral_types_and(torch.bool), 10558 ), 10559 ), 10560 ), 10561 ForeachFuncInfo( 10562 'trunc', 10563 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10564 supports_autograd=True, 10565 supports_inplace_autograd=True, 10566 supports_forward_ad=True, 10567 decorators=( 10568 DecorateInfo( 10569 unittest.expectedFailure, 10570 "TestMeta", 10571 "test_dispatch_meta_inplace", 10572 dtypes=complex_types_and(torch.bool), 10573 ), 10574 DecorateInfo( 10575 unittest.expectedFailure, 10576 "TestMeta", 10577 "test_dispatch_meta_outplace", 10578 dtypes=complex_types_and(torch.bool), 10579 ), 10580 DecorateInfo( 10581 unittest.expectedFailure, 10582 "TestMeta", 10583 "test_dispatch_symbolic_meta_inplace", 10584 dtypes=complex_types_and(torch.bool), 10585 ), 10586 DecorateInfo( 10587 unittest.expectedFailure, 10588 "TestMeta", 10589 "test_dispatch_symbolic_meta_outplace", 10590 dtypes=complex_types_and(torch.bool), 10591 ), 10592 DecorateInfo( 10593 unittest.expectedFailure, 10594 "TestMeta", 10595 "test_meta_inplace", 10596 dtypes=complex_types_and(torch.bool), 10597 ), 10598 DecorateInfo( 10599 unittest.expectedFailure, 10600 "TestMeta", 10601 "test_meta_outplace", 10602 dtypes=complex_types_and(torch.bool), 10603 ), 10604 DecorateInfo( 10605 unittest.expectedFailure, 10606 "TestForeach", 10607 "test_autodiff", 10608 device_type="cuda", 10609 dtypes=(torch.complex128,), 10610 ), 10611 ), 10612 ), 10613 ForeachFuncInfo( 10614 'abs', 10615 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10616 supports_autograd=True, 10617 supports_inplace_autograd=True, 10618 supports_forward_ad=True, 10619 supports_fwgrad_bwgrad=True, 10620 decorators=( 10621 DecorateInfo( 10622 unittest.expectedFailure, 10623 "TestMeta", 10624 "test_dispatch_symbolic_meta_inplace", 10625 dtypes=complex_types(), 10626 ), 10627 DecorateInfo( 10628 unittest.expectedFailure, 10629 "TestMeta", 10630 "test_dispatch_meta_inplace", 10631 dtypes=complex_types(), 10632 ), 10633 DecorateInfo( 10634 unittest.expectedFailure, 10635 "TestMeta", 10636 "test_dispatch_meta_outplace", 10637 device_type="cpu", 10638 dtypes=(torch.bool,), 10639 ), 10640 DecorateInfo( 10641 unittest.expectedFailure, 10642 "TestMeta", 10643 "test_dispatch_symbolic_meta_inplace", 10644 device_type="cpu", 10645 dtypes=(torch.bool,), 10646 ), 10647 DecorateInfo( 10648 unittest.expectedFailure, 10649 "TestMeta", 
10650 "test_dispatch_symbolic_meta_outplace", 10651 device_type="cpu", 10652 dtypes=(torch.bool,), 10653 ), 10654 DecorateInfo( 10655 unittest.expectedFailure, 10656 "TestMeta", 10657 "test_meta_inplace", 10658 device_type="cpu", 10659 dtypes=(torch.bool,), 10660 ), 10661 DecorateInfo( 10662 unittest.expectedFailure, 10663 "TestMeta", 10664 "test_meta_outplace", 10665 device_type="cpu", 10666 dtypes=(torch.bool,), 10667 ), 10668 DecorateInfo( 10669 unittest.expectedFailure, 10670 "TestMeta", 10671 "test_dispatch_meta_inplace", 10672 device_type="cpu", 10673 dtypes=(torch.bool,), 10674 ), 10675 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types()), 10676 ), 10677 ), 10678 ForeachFuncInfo( 10679 'zero', 10680 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10681 supports_autograd=True, 10682 supports_inplace_autograd=True, 10683 supports_forward_ad=True, 10684 supports_out=False, 10685 ), 10686 ForeachFuncInfo( 10687 'sign', 10688 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10689 supports_autograd=True, 10690 supports_inplace_autograd=True, 10691 supports_forward_ad=True, 10692 decorators=( 10693 DecorateInfo( 10694 unittest.expectedFailure, 10695 "TestMeta", 10696 "test_dispatch_meta_inplace", 10697 dtypes=complex_types(), 10698 ), 10699 DecorateInfo( 10700 unittest.expectedFailure, 10701 "TestMeta", 10702 "test_dispatch_meta_outplace", 10703 dtypes=complex_types(), 10704 ), 10705 DecorateInfo( 10706 unittest.expectedFailure, 10707 "TestMeta", 10708 "test_dispatch_symbolic_meta_inplace", 10709 dtypes=complex_types(), 10710 ), 10711 DecorateInfo( 10712 unittest.expectedFailure, 10713 "TestMeta", 10714 "test_dispatch_symbolic_meta_outplace", 10715 dtypes=complex_types(), 10716 ), 10717 DecorateInfo( 10718 unittest.expectedFailure, 10719 "TestMeta", 10720 "test_meta_inplace", 10721 dtypes=complex_types(), 10722 ), 10723 DecorateInfo( 10724 unittest.expectedFailure, 10725 "TestMeta", 10726 "test_meta_outplace", 10727 dtypes=complex_types(), 10728 ), 10729 DecorateInfo( 10730 unittest.expectedFailure, 10731 "TestForeach", 10732 "test_autodiff", 10733 device_type="cuda", 10734 dtypes=(torch.complex128,), 10735 ), 10736 ), 10737 ), 10738 ForeachFuncInfo( 10739 'lgamma', 10740 sample_inputs_func=foreach_inputs_sample_func(1, False, False), 10741 supports_autograd=True, 10742 supports_inplace_autograd=True, 10743 supports_forward_ad=True, 10744 decorators=( 10745 DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", 10746 "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool)), 10747 # DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", 10748 # "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool)), 10749 DecorateInfo(unittest.skip("In-place lgamma not supported for integral tensors"), "TestMeta", 10750 "test_meta_inplace", dtypes=integral_types_and(torch.bool)), 10751 DecorateInfo( 10752 unittest.expectedFailure, 10753 "TestMeta", 10754 "test_dispatch_meta_inplace", 10755 dtypes=complex_types() + integral_types_and(torch.bool), 10756 ), 10757 DecorateInfo( 10758 unittest.expectedFailure, 10759 "TestMeta", 10760 "test_dispatch_meta_outplace", 10761 dtypes=complex_types(), 10762 ), 10763 DecorateInfo( 10764 unittest.expectedFailure, 10765 "TestMeta", 10766 "test_dispatch_symbolic_meta_inplace", 10767 dtypes=complex_types() + integral_types_and(torch.bool), 10768 ), 10769 DecorateInfo( 10770 
unittest.expectedFailure, 10771 "TestMeta", 10772 "test_dispatch_symbolic_meta_outplace", 10773 dtypes=complex_types(), 10774 ), 10775 DecorateInfo( 10776 unittest.expectedFailure, 10777 "TestMeta", 10778 "test_meta_inplace", 10779 dtypes=complex_types() + integral_types_and(torch.bool), 10780 ), 10781 DecorateInfo( 10782 unittest.expectedFailure, 10783 "TestMeta", 10784 "test_meta_outplace", 10785 dtypes=complex_types(), 10786 ), 10787 DecorateInfo( 10788 unittest.expectedFailure, 10789 "TestForeach", 10790 "test_autodiff", 10791 device_type="cuda", 10792 dtypes=(torch.complex128,), 10793 ), 10794 ), 10795 ), 10796] 10797 10798foreach_binary_op_db: List[OpInfo] = [ 10799 ForeachFuncInfo( 10800 "add", 10801 sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), 10802 supports_alpha_param=True, 10803 supports_autograd=True, 10804 supports_inplace_autograd=True, 10805 supports_forward_ad=True, 10806 decorators=( 10807 # These tests fail with aten._local_scalar_dense not being implemented. 10808 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), 10809 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 10810 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), 10811 # Samples have complex types and inplace only works if the dtype is complex. 10812 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10813 dtypes=(torch.bool,)), 10814 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10815 dtypes=(torch.bool,)), 10816 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10817 dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), 10818 ), 10819 ), 10820 ForeachFuncInfo( 10821 "sub", 10822 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 10823 supports_alpha_param=True, 10824 supports_autograd=True, 10825 supports_inplace_autograd=True, 10826 supports_forward_ad=True, 10827 decorators=( 10828 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), 10829 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), 10830 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), 10831 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), 10832 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), 10833 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), 10834 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), 10835 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 10836 ), 10837 ), 10838 ForeachFuncInfo( 10839 "mul", 10840 sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), 10841 supports_autograd=True, 10842 supports_inplace_autograd=True, 10843 supports_forward_ad=True, 10844 decorators=( 10845 # Samples have complex types and inplace only works if the dtype is complex. 
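# (Presumably only torch.bool is expected to fail below because a bool tensor cannot hold the promoted result of the scalar/type-promotion samples.)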
10846 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10847 dtypes=(torch.bool,)), 10848 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10849 dtypes=(torch.bool,)), 10850 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), 10851 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10852 dtypes=(torch.bool,)), 10853 ), 10854 ), 10855 ForeachFuncInfo( 10856 "div", 10857 sample_inputs_func=foreach_inputs_sample_func(2, True, True, True), 10858 supports_autograd=True, 10859 supports_inplace_autograd=True, 10860 supports_forward_ad=True, 10861 decorators=( 10862 # Samples have complex types and inplace only works if the dtype is complex. 10863 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10864 dtypes=integral_types_and(torch.bool)), 10865 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10866 dtypes=integral_types_and(torch.bool)), 10867 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 10868 dtypes=integral_types_and(torch.bool)), 10869 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10870 dtypes=integral_types_and(torch.bool)), 10871 ), 10872 ), 10873 ForeachFuncInfo( 10874 "clamp_min", 10875 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 10876 supports_autograd=True, 10877 supports_inplace_autograd=True, 10878 supports_forward_ad=True, 10879 decorators=( 10880 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10881 dtypes=complex_types_and(torch.bool)), 10882 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10883 dtypes=complex_types_and(torch.bool)), 10884 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 10885 dtypes=complex_types_and(torch.bool)), 10886 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 10887 dtypes=complex_types_and(torch.bool)), 10888 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 10889 dtypes=complex_types_and(torch.bool)), 10890 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", 10891 dtypes=complex_types_and(torch.bool)), 10892 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10893 dtypes=complex_types_and(torch.bool)), 10894 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 10895 dtypes=complex_types_and(torch.bool)), 10896 DecorateInfo( 10897 unittest.expectedFailure, 10898 "TestForeach", 10899 "test_autodiff", 10900 device_type="cuda", 10901 dtypes=(torch.complex128,), 10902 ), 10903 DecorateInfo( 10904 unittest.expectedFailure, 10905 "TestForeach", 10906 "test_binary_op_scalar_with_overlapping_tensors", 10907 dtypes=complex_types(), 10908 ), 10909 ), 10910 ), 10911 ForeachFuncInfo( 10912 "clamp_max", 10913 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 10914 supports_autograd=True, 10915 supports_inplace_autograd=True, 10916 supports_forward_ad=True, 10917 decorators=( 10918 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10919 dtypes=complex_types_and(torch.bool)), 10920 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10921 
dtypes=complex_types_and(torch.bool)), 10922 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 10923 dtypes=complex_types_and(torch.bool)), 10924 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 10925 dtypes=complex_types_and(torch.bool)), 10926 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 10927 dtypes=complex_types_and(torch.bool)), 10928 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", 10929 dtypes=complex_types_and(torch.bool)), 10930 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10931 dtypes=complex_types_and(torch.bool)), 10932 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 10933 dtypes=complex_types_and(torch.bool)), 10934 DecorateInfo( 10935 unittest.expectedFailure, 10936 "TestForeach", 10937 "test_autodiff", 10938 device_type="cuda", 10939 dtypes=(torch.complex128,), 10940 ), 10941 DecorateInfo( 10942 unittest.expectedFailure, 10943 "TestForeach", 10944 "test_binary_op_scalar_with_overlapping_tensors", 10945 dtypes=complex_types(), 10946 ), 10947 ), 10948 ), 10949 # note(crcrpar): forward ad not implemented. 10950 ForeachFuncInfo( 10951 "minimum", 10952 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 10953 supports_autograd=True, 10954 supports_inplace_autograd=False, 10955 supports_forward_ad=False, 10956 decorators=( 10957 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10958 dtypes=complex_types_and(torch.bool)), 10959 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10960 dtypes=complex_types_and(torch.bool)), 10961 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 10962 dtypes=complex_types_and(torch.bool)), 10963 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 10964 dtypes=complex_types_and(torch.bool)), 10965 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 10966 dtypes=complex_types_and(torch.bool)), 10967 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", 10968 dtypes=complex_types_and(torch.bool)), 10969 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 10970 dtypes=complex_types_and(torch.bool)), 10971 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 10972 dtypes=complex_types_and(torch.bool)), 10973 DecorateInfo( 10974 unittest.expectedFailure, 10975 "TestForeach", 10976 "test_autodiff", 10977 device_type="cuda", 10978 dtypes=(torch.complex128,), 10979 ), 10980 DecorateInfo( 10981 unittest.expectedFailure, 10982 "TestForeach", 10983 "test_binary_op_scalar_with_overlapping_tensors", 10984 dtypes=complex_types(), 10985 ), 10986 ), 10987 ), 10988 # note(crcrpar): forward ad not implemented. 
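# `maximum` mirrors `minimum` above: in-place autograd and forward AD are disabled, and the complex-dtype cases are expected failures since ordering comparisons are not defined for complex tensors.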
10989 ForeachFuncInfo( 10990 "maximum", 10991 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 10992 supports_autograd=True, 10993 supports_forward_ad=False, 10994 supports_inplace_autograd=False, 10995 decorators=( 10996 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 10997 dtypes=complex_types_and(torch.bool)), 10998 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 10999 dtypes=complex_types_and(torch.bool)), 11000 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 11001 dtypes=complex_types_and(torch.bool)), 11002 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 11003 dtypes=complex_types_and(torch.bool)), 11004 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 11005 dtypes=complex_types_and(torch.bool)), 11006 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", 11007 dtypes=complex_types_and(torch.bool)), 11008 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 11009 dtypes=complex_types_and(torch.bool)), 11010 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 11011 dtypes=complex_types_and(torch.bool)), 11012 DecorateInfo( 11013 unittest.expectedFailure, 11014 "TestForeach", 11015 "test_autodiff", 11016 device_type="cuda", 11017 dtypes=(torch.complex128,), 11018 ), 11019 DecorateInfo( 11020 unittest.expectedFailure, 11021 "TestForeach", 11022 "test_binary_op_scalar_with_overlapping_tensors", 11023 dtypes=complex_types(), 11024 ), 11025 ), 11026 ), 11027 ForeachFuncInfo( 11028 "pow", 11029 supports_alpha_param=False, 11030 supports_scalar_self_arg=True, 11031 sample_inputs_func=foreach_inputs_sample_func(2, True, True), 11032 supports_autograd=True, 11033 supports_inplace_autograd=True, 11034 supports_forward_ad=True, 11035 decorators=( 11036 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), 11037 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 11038 dtypes=(torch.bool,)), 11039 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), 11040 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), 11041 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 11042 dtypes=(torch.bool,)), 11043 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 11044 dtypes=(torch.bool,)), 11045 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 11046 dtypes=(torch.bool,)), 11047 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 11048 dtypes=(torch.bool,),), 11049 DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), 11050 DecorateInfo( 11051 unittest.skip("failed starting on ROCm 6.2"), 11052 "TestForeach", 11053 "test_parity", 11054 device_type="cuda", 11055 dtypes=(torch.complex64,), 11056 active_if=TEST_WITH_ROCM), 11057 DecorateInfo( 11058 unittest.expectedFailure, 11059 "TestForeach", 11060 "test_binary_op_with_scalar_self_support", 11061 device_type="cuda", 11062 dtypes=(torch.bool,), 11063 active_if=lambda kwargs: kwargs["is_fastpath"], 11064 ), 11065 ), 11066 backward_requires_result=True, 11067 ), 11068 ForeachFuncInfo( 11069 
"copy", 11070 sample_inputs_func=foreach_inputs_sample_func(2, False, False), 11071 supports_out=False, 11072 supports_forward_ad=False, 11073 supports_autograd=False, 11074 supports_inplace_autograd=False, 11075 ) 11076] 11077 11078foreach_pointwise_op_db: List[ForeachFuncInfo] = [ 11079 ForeachFuncInfo( 11080 "addcmul", 11081 sample_inputs_func=foreach_pointwise_sample_func(4, True, True), 11082 supports_autograd=True, 11083 supports_inplace_autograd=True, 11084 supports_forward_ad=True, 11085 decorators=( 11086 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), 11087 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 11088 dtypes=(torch.bool,)), 11089 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), 11090 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 11091 dtypes=(torch.bool,)), 11092 # # Samples have complex types and inplace only works if the dtype is complex. 11093 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), 11094 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 11095 dtypes=(torch.bool,)), 11096 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), 11097 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 11098 dtypes=integral_types() + complex_types_and(torch.bool)), 11099 ), 11100 ), 11101 ForeachFuncInfo( 11102 "addcdiv", 11103 sample_inputs_func=foreach_pointwise_sample_func(4, True, True), 11104 supports_autograd=True, 11105 supports_inplace_autograd=True, 11106 supports_forward_ad=True, 11107 decorators=( 11108 # Samples have complex types and inplace only works if the dtype is complex. 
11109 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", 11110 dtypes=integral_types_and(torch.bool)), 11111 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", 11112 dtypes=integral_types_and(torch.bool)), 11113 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", 11114 dtypes=integral_types_and(torch.bool)), 11115 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", 11116 dtypes=integral_types() + complex_types_and(torch.bool)), 11117 # fails with div_cpu is not implemented with ComplexHalf 11118 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", 11119 dtypes=integral_types_and(torch.bool)), 11120 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", 11121 dtypes=integral_types_and(torch.bool)), 11122 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", 11123 dtypes=integral_types_and(torch.bool)), 11124 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 11125 dtypes=integral_types() + complex_types_and(torch.bool)), 11126 ), 11127 ), 11128] 11129 11130foreach_reduce_op_db: List[ForeachFuncInfo] = [ 11131 ForeachFuncInfo( 11132 "max", 11133 sample_inputs_func=foreach_max_sample_func(1, False, False), 11134 supports_autograd=True, 11135 supports_inplace_autograd=True, 11136 supports_forward_ad=True, 11137 decorators=( 11138 # no complex support for ordering ops like max 11139 DecorateInfo( 11140 unittest.expectedFailure, 11141 "TestForeach", 11142 "test_autodiff", 11143 dtypes=(torch.complex128, torch.complex64), 11144 ), 11145 DecorateInfo( 11146 unittest.expectedFailure, 11147 "TestForeach", 11148 "test_foreach_reduce_large_input", 11149 dtypes=(torch.complex128, torch.complex64), 11150 ), 11151 DecorateInfo( 11152 unittest.expectedFailure, 11153 "TestMeta", 11154 "test_dispatch_symbolic_meta_outplace", 11155 dtypes=(torch.complex128, torch.complex64), 11156 ), 11157 DecorateInfo( 11158 unittest.expectedFailure, 11159 "TestMeta", 11160 "test_meta_outplace", 11161 dtypes=(torch.complex128, torch.complex64), 11162 ), 11163 DecorateInfo( 11164 unittest.expectedFailure, 11165 "TestMeta", 11166 "test_dispatch_meta_outplace", 11167 dtypes=(torch.complex128, torch.complex64), 11168 ), 11169 ), 11170 ), 11171 ForeachFuncInfo( 11172 "norm", 11173 sample_inputs_func=foreach_norm_sample_func(1, False, False), 11174 supports_autograd=True, 11175 supports_inplace_autograd=True, 11176 supports_forward_ad=True, 11177 decorators=( 11178 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), 11179 DecorateInfo( 11180 unittest.expectedFailure, 11181 "TestMeta", 11182 "test_dispatch_meta_outplace", 11183 dtypes=integral_types_and(torch.bool), 11184 ), 11185 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), 11186 DecorateInfo( 11187 unittest.expectedFailure, 11188 "TestMeta", 11189 "test_dispatch_symbolic_meta_outplace", 11190 dtypes=integral_types_and(torch.bool), 11191 ), 11192 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), 11193 DecorateInfo( 11194 unittest.expectedFailure, 11195 "TestMeta", 11196 "test_meta_outplace", 11197 dtypes=integral_types_and(torch.bool), 11198 ), 11199 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), 11200 DecorateInfo( 11201 unittest.expectedFailure, 11202 
"TestForeach", 11203 "test_foreach_reduce_large_input", 11204 device_type="cuda", 11205 dtypes=integral_types_and(torch.bool), 11206 ), 11207 ), 11208 ), 11209] 11210 11211foreach_other_op_db: List[ForeachFuncInfo] = [ 11212 ForeachFuncInfo( 11213 "lerp", 11214 sample_inputs_func=foreach_inputs_sample_func(3, True, False), 11215 supports_autograd=True, 11216 supports_inplace_autograd=True, 11217 supports_forward_ad=True, 11218 decorators=( 11219 DecorateInfo( 11220 unittest.expectedFailure, 11221 "TestMeta", 11222 "test_dispatch_meta_inplace", 11223 dtypes=integral_types_and(torch.bool), 11224 ), 11225 DecorateInfo( 11226 unittest.expectedFailure, 11227 "TestMeta", 11228 "test_dispatch_meta_outplace", 11229 dtypes=integral_types_and(torch.bool), 11230 ), 11231 DecorateInfo( 11232 unittest.expectedFailure, 11233 "TestMeta", 11234 "test_dispatch_symbolic_meta_outplace", 11235 dtypes=integral_types_and(torch.bool), 11236 ), 11237 DecorateInfo( 11238 unittest.expectedFailure, 11239 "TestMeta", 11240 "test_dispatch_symbolic_meta_inplace", 11241 dtypes=integral_types_and(torch.bool), 11242 ), 11243 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool)), 11244 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=integral_types_and(torch.bool)), 11245 ), 11246 ), 11247] 11248 11249def reference_sign(x): 11250 if x.dtype == np.bool_: 11251 # `np.sign` doesn't support `bool`. 11252 # >>> np.sign(True) 11253 # ufunc 'sign' did not contain a loop 11254 # with signature matching types dtype('bool') -> dtype('bool') 11255 return np.sign(x, dtype=np.uint8).astype(np.bool_) 11256 return np.sign(x) 11257 11258 11259def reference_sgn(x): 11260 # NumPy doesn't have an equivalent to `torch.sgn` when the dtype is complex. 11261 # For complex inputs, `np.sign` returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j. 11262 # while `torch.sgn` returns, 0 if abs(input) == 0 else input/abs(input) 11263 if x.dtype not in [np.complex64, np.complex128]: 11264 return reference_sign(x) 11265 11266 out = (x / np.abs(x)) 11267 if out.ndim == 0: 11268 # Handle x == 0 case 11269 if (x == 0): 11270 # Can't assign to np.complex object 11271 # So make a new one. 11272 return np.array(complex(0, 0), dtype=x.dtype) 11273 return out 11274 11275 # Handle x == 0 case 11276 mask = (x == 0) 11277 out[mask] = complex(0, 0) 11278 return out 11279 11280 11281def reference_sigmoid(x): 11282 # 'scipy.special.expit' not supported for the input types 11283 if x.dtype in [np.complex64, np.complex128]: 11284 return (1 / (1 + np.exp(-x))) 11285 return scipy.special.expit(x) 11286 11287 11288def reference_logsigmoid(x): 11289 return np.where( 11290 x < 0, 11291 x - np.log1p(np.exp(x)), 11292 -np.log1p(np.exp(-x))) 11293 11294 11295def reference_hardsigmoid(x): 11296 intermediate = x / 6 + 0.5 11297 y = np.clip(intermediate, 0, None) 11298 return np.where(y > 1, 1, y).astype(x.dtype) 11299 11300 11301def reference_lgamma(x): 11302 # scipy.special.gammaln returns `-inf` when input is `-inf`. 11303 # While Pytorch, C and C++, all return `inf` when input is `-inf`. 
11304 # Reference: 11305 # https://en.cppreference.com/w/cpp/numeric/math/lgamma 11306 # https://en.cppreference.com/w/c/numeric/math/lgamma 11307 11308 # To handle the above discrepancy, 11309 # we replace -inf with inf so values 11310 # that were originally -inf map to inf as expected 11311 if x.dtype.kind == 'f': 11312 x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x) 11313 11314 out = scipy.special.gammaln(x) 11315 11316 if x.dtype == np.float16: 11317 # `scipy.special.gammaln` returns output of float32 when input is float16, 11318 # while `torch.lgamma` preserves `float16`. But due to smaller range of float16, 11319 # Pytorch version outputs `inf` while SciPy returns finite values. 11320 out = out.astype(np.float16) 11321 11322 return out 11323 11324 11325def reference_mvlgamma(x, d): 11326 if x.dtype == np.float16: 11327 return scipy.special.multigammaln(x, d).astype(np.float16) 11328 11329 return scipy.special.multigammaln(x, d) 11330 11331def reference_softplus(input, beta=1, threshold=20): 11332 non_linear = input * beta <= threshold 11333 output = input.copy() 11334 output[non_linear] = np.log(1 + np.exp(beta * input[non_linear])) / beta 11335 return output 11336 11337def reference_gelu(X, *, approximate='none'): 11338 def _gelu_ref(X): 11339 return X * stats.norm.cdf(X) 11340 11341 def _tanh_gelu_ref(X): 11342 M_SQRT_2_PI = math.sqrt(2 / math.pi) 11343 Z = M_SQRT_2_PI * (X + 0.044715 * np.power(X, 3.0)) 11344 return 0.5 * X * (1.0 + np.tanh(Z)) 11345 11346 if approximate == 'tanh': 11347 return _tanh_gelu_ref(X) 11348 else: 11349 return _gelu_ref(X) 11350 11351 11352def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray: 11353 if num_classes == -1: 11354 num_classes = int(np.amax(a) + 1) 11355 11356 idcs = a.reshape(-1) + np.arange(0, a.size, dtype=np.int64) * num_classes 11357 one_hot = np.zeros((a.size, num_classes), dtype=a.dtype) 11358 np.put(one_hot, idcs, 1) 11359 return one_hot.reshape(*a.shape, -1) 11360 11361 11362def reference_mse_loss(input, target, reduction="mean"): 11363 se = (input - target) ** 2 11364 if reduction == "mean": 11365 return np.mean(se) 11366 elif reduction == "sum": 11367 return np.sum(se) 11368 else: # reduction == "none" 11369 return se 11370 11371 11372def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): 11373 return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] 11374 11375 11376def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight, bias, eps): 11377 feature_size = np.prod(normalized_shape) 11378 inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] 11379 mean = inp_view.mean(axis=-1, keepdims=True) 11380 var = inp_view.var(axis=-1, ddof=0, keepdims=True) 11381 Y = (inp_view - mean) / np.sqrt(var + eps) 11382 if weight is None and bias is not None: 11383 Y = Y + bias.reshape(-1) 11384 elif weight is not None and bias is None: 11385 Y = Y * weight.reshape(-1) 11386 elif weight is not None and bias is not None: 11387 Y = Y * weight.reshape(-1) + bias.reshape(-1) 11388 axis = inp.ndim - len(normalized_shape) 11389 stat_shape = inp.shape[:axis] + (1,) * len(normalized_shape) 11390 return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) 11391 11392 11393def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None): 11394 if eps is None: 11395 eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps 11396 
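# RMSNorm reference: y = x / sqrt(mean(x**2) + eps), optionally scaled by `weight`; unlike LayerNorm there is no mean subtraction.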
feature_size = np.prod(normalized_shape) 11397 inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] 11398 rms = np.sqrt((inp_view**2).mean(axis=-1, keepdims=True) + eps) 11399 Y = inp_view / rms 11400 if weight is not None: 11401 Y = Y * weight.reshape(-1) 11402 return Y.reshape(*inp.shape) 11403 11404 11405def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5): 11406 inp_view = inp 11407 if np.prod(inp.shape) != 0: 11408 inp_view = inp.reshape((inp.shape[0], num_groups, -1)) 11409 mean = inp_view.mean(axis=-1, keepdims=True) 11410 var = inp_view.var(axis=-1, ddof=0, keepdims=True) 11411 Y = (inp_view - mean) / np.sqrt(var + eps) 11412 Y = Y.reshape(inp.shape) 11413 if weight is not None: 11414 # weight is a vector of length equal to the channel 11415 if len(Y.shape) > 2: 11416 weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) 11417 Y = Y * weight 11418 if bias is not None: 11419 # bias is a vector of length equal to the channel 11420 if len(Y.shape) > 2: 11421 bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) 11422 Y = Y + bias 11423 return Y 11424 11425 11426# using a custom reference function since numpy only has a string side arg (instead of right and side) and doesn't 11427# have an out_int32 arg. Additionally, numpy doesn't support searchsorted with ND arrays, so this splits those into 11428# stacked 1D cases 11429def reference_searchsorted(sorted_sequence, boundary, out_int32=False, right=False, side='left', sorter=None): 11430 side = 'right' if (right or side == 'right') else 'left' 11431 if len(sorted_sequence.shape) == 1 : 11432 ret = np.searchsorted(sorted_sequence, boundary, side=side, sorter=sorter) 11433 return ret.astype(np.int32) if out_int32 else ret 11434 elif sorted_sequence.shape[0] == 0: 11435 if sorter is not None: 11436 sorter = sorter.flatten() 11437 ret = np.searchsorted(sorted_sequence.flatten(), boundary.flatten(), side=side, sorter=sorter) 11438 ret = ret.astype(np.int32) if out_int32 else ret 11439 return ret.reshape(boundary.shape) 11440 else: 11441 # numpy searchsorted only supports 1D inputs so we split up ND inputs 11442 orig_shape = boundary.shape 11443 num_splits = np.prod(sorted_sequence.shape[:-1]) 11444 splits = range(0, num_splits) 11445 sorted_sequence, boundary = sorted_sequence.reshape(num_splits, -1), boundary.reshape(num_splits, -1) 11446 if sorter is not None: 11447 sorter = sorter.reshape(num_splits, -1) 11448 11449 split_sequence = [sorted_sequence[i] for i in splits] 11450 split_boundary = [boundary[i] for i in splits] 11451 split_sorter = [sorter[i] if (sorter is not None) else None for i in splits] 11452 11453 split_ret = [np.searchsorted(s_seq, b, side=side, sorter=s_sort) 11454 for (s_seq, b, s_sort) in zip(split_sequence, split_boundary, split_sorter)] 11455 split_ret = [i.astype(np.int32) for i in split_ret] if out_int32 else split_ret 11456 return np.stack(split_ret).reshape(orig_shape) 11457 11458def loss_reference_reduction_wrapper(fn): 11459 def wrapper(input, target, *, size_average=None, reduce=None, reduction="mean", **other_kwargs): 11460 if size_average is not None or reduce is not None: 11461 raise RuntimeError( 11462 "The keyword arguments 'size_average' and 'reduce' are deprecated and not supported by this wrapper" 11463 ) 11464 output = fn(input, target, **other_kwargs) 11465 if reduction == "mean": 11466 return np.mean(output) 11467 elif reduction == "sum": 11468 return np.sum(output) 11469 else: # reduction == 
"none" 11470 return output 11471 11472 return wrapper 11473 11474@loss_reference_reduction_wrapper 11475def reference_smooth_l1_loss(input, target, beta=1.0): 11476 diff = input - target 11477 abs_diff = np.abs(diff) 11478 above_threshold = abs_diff >= beta 11479 11480 loss = np.empty_like(input) 11481 loss[above_threshold] = abs_diff[above_threshold] - 0.5 * beta 11482 loss[~above_threshold] = diff[~above_threshold] ** 2 / (2 * beta) 11483 11484 return loss 11485 11486def reference_std_var(f): 11487 """Forwards unbiased/correction kwargs as NumPy's equivalent ddof""" 11488 g = reference_reduction_numpy(f) 11489 11490 @wraps(g) 11491 def wrapper(x: np.ndarray, *args, **kwargs): 11492 assert not ('unbiased' in kwargs and 'correction' in kwargs) 11493 11494 if 'unbiased' in kwargs: 11495 kwargs['ddof'] = int(kwargs.pop('unbiased')) 11496 elif 'correction' in kwargs: 11497 kwargs['ddof'] = kwargs.pop('correction') 11498 11499 return g(x, *args, **kwargs) 11500 11501 return wrapper 11502 11503def generate_std_var_kwargs(t: torch.Tensor, **kwargs): 11504 """Generates unbiased/correction kwargs for std/var operators""" 11505 yield ((), {'unbiased': True}) 11506 yield ((), {'unbiased': False}) 11507 11508 # Currently, calling std with correction is only enabled when 11509 # both dim and keepdim are provided. 11510 if 'dim' in kwargs and 'keepdim' in kwargs: 11511 yield ((), {'correction': 0}) 11512 yield ((), {'correction': 1}) 11513 11514 numel = torch.tensor(t.shape)[kwargs.get('dim')].prod() 11515 yield ((), {'correction': numel // 2}) 11516 11517def error_inputs_mean(op_info, device, is_ref=False, **kwargs): 11518 if is_ref: 11519 err_msg1 = (r"mean\(\): could not infer output dtype. " 11520 r"Input dtype must be either a floating point or complex dtype. " 11521 r"Got: torch.int64") 11522 else: 11523 err_msg1 = (r"mean\(\): could not infer output dtype. " 11524 r"Input dtype must be either a floating point or complex dtype. " 11525 r"Got: Long") 11526 yield ErrorInput( 11527 SampleInput(make_tensor((3, 4, 5), dtype=torch.int64, device=device), []), 11528 error_regex=err_msg1, 11529 ) 11530 11531 if is_ref: 11532 err_msg2 = (r"mean\(\): could not infer output dtype. " 11533 r"Optional dtype must be either a floating point or complex dtype. " 11534 r"Got: torch.int64") 11535 else: 11536 err_msg2 = (r"mean\(\): could not infer output dtype. " 11537 r"Optional dtype must be either a floating point or complex dtype. " 11538 r"Got: Long") 11539 yield ErrorInput( 11540 SampleInput( 11541 make_tensor((3, 4, 5), dtype=torch.float32, device=device), 11542 [], 11543 dtype=torch.int64), 11544 error_regex=err_msg2 11545 ) 11546 11547 if is_ref: 11548 err_msg3 = "Expected out tensor to have dtype torch.float64, but got torch.float32 instead" 11549 else: 11550 err_msg3 = "Expected out tensor to have dtype double, but got float instead" 11551 yield ErrorInput( 11552 SampleInput( 11553 make_tensor((3, 4, 5), dtype=torch.int64, device=device), 11554 [], 11555 dtype=torch.float64, 11556 out=make_tensor([], dtype=torch.float32, device=device), 11557 ), 11558 error_regex=err_msg3 11559 ) 11560 11561# numpy implementation of torch.flatten 11562# unfortunately there's no np.flatten. 
we figure out the desired shape and call np.reshape 11563def reference_flatten(input, start_dim=0, end_dim=-1): 11564 in_shape = input.shape 11565 in_rank = len(in_shape) 11566 for d in start_dim, end_dim: 11567 if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): 11568 raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}") 11569 end_dim = end_dim if end_dim >= 0 else in_rank + end_dim 11570 start_dim = start_dim if start_dim >= 0 else in_rank + start_dim 11571 if in_rank == 0: 11572 end_dim = start_dim 11573 if end_dim < start_dim: 11574 raise RuntimeError("flatten() has invalid args: start_dim cannot come after end_dim") 11575 flatten_bit_dim = functools.reduce(operator.mul, in_shape[start_dim:end_dim + 1], 1) 11576 out_shape = in_shape[:start_dim] + (flatten_bit_dim,) + in_shape[end_dim + 1:] 11577 return np.reshape(input, out_shape) 11578 11579 11580def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): 11581 yield SampleInput(make_tensor((S,), dtype=dtype, device=device, requires_grad=requires_grad)) 11582 yield SampleInput(make_tensor((), dtype=dtype, device=device, requires_grad=requires_grad)) 11583 11584 11585# Operator database (sorted alphabetically) 11586op_db: List[OpInfo] = [ 11587 UnaryUfuncInfo('abs', 11588 aliases=('absolute', ), 11589 ref=np.abs, 11590 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), 11591 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 11592 skips=( 11593 DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', 11594 'test_inplace_grad', dtypes=(torch.cdouble,)), 11595 DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestBwdGradients', 11596 'test_inplace_gradgrad', dtypes=(torch.cdouble,)), 11597 DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), 'TestFwdGradients', 11598 'test_inplace_forward_mode_AD', dtypes=(torch.cdouble,)), 11599 DecorateInfo(unittest.skip("In-place abs not supported for complex tensors"), "TestSparseUnaryUfuncs", 11600 "test_inplace", dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), 11601 # Reference: https://github.com/pytorch/pytorch/issues/49224 11602 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 11603 dtypes=[torch.int8], active_if=TEST_WITH_ASAN), 11604 # TODO: Fix test_out_arg_all_dtypes as torch.empty_like(expected_output) where expected_output=op(input) 11605 # We can break the logic of the loop over all possible types but it is OK. 
11606 # https://github.com/pytorch/pytorch/blob/master/test/test_unary_ufuncs.py#L440-L449 11607 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes', 11608 dtypes=[torch.cfloat, torch.cdouble]), 11609 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', 11610 dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), 11611 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', 11612 dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), 11613 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', 11614 dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), 11615 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace_all_strides', 11616 dtypes=(torch.cdouble, torch.cfloat, torch.chalf)), 11617 ), 11618 supports_fwgrad_bwgrad=True, 11619 assert_autodiffed=True, 11620 supports_sparse=True, 11621 supports_sparse_csr=True, 11622 supports_sparse_csc=True, 11623 supports_sparse_bsr=True, 11624 supports_sparse_bsc=True, 11625 supports_forward_ad=True), 11626 # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952) 11627 UnaryUfuncInfo('acos', 11628 aliases=('arccos', ), 11629 ref=np.arccos, 11630 domain=(-1, 1), 11631 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 11632 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 11633 assert_autodiffed=True, 11634 supports_forward_ad=True, 11635 supports_fwgrad_bwgrad=True, 11636 promotes_int_to_float=True, 11637 decorators=(precisionOverride({torch.float16: 1e-2, 11638 torch.bfloat16: 1e-1, 11639 torch.complex64: 1e-2}),), 11640 skips=( 11641 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 11642 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11643 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 11644 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11645 # Failing with wrong imaginary sign on at least some Windows jobs 11646 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 11647 device_type='cuda', dtypes=[torch.cdouble], 11648 active_if=IS_WINDOWS), 11649 # Failing with wrong imaginary sign on at least some Windows jobs 11650 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 11651 device_type='cuda', dtypes=[torch.cdouble], 11652 active_if=IS_WINDOWS), 11653 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 11654 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 11655 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 11656 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 11657 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad', 11658 dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11659 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_method_grad', 11660 dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11661 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_inplace_grad', 11662 dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11663 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', 11664 dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11665 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 
'test_inplace_forward_mode_AD', 11666 dtypes=[torch.cdouble], active_if=IS_WINDOWS),)), 11667 # NOTE: the derivative for inplace acosh is not implemented 11668 UnaryUfuncInfo('acosh', 11669 aliases=('arccosh', ), 11670 ref=np.arccosh, 11671 domain=(1, None), 11672 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 11673 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 11674 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 11675 supports_inplace_autograd=False, 11676 supports_forward_ad=True, 11677 supports_fwgrad_bwgrad=True, 11678 promotes_int_to_float=True, 11679 skips=( 11680 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 11681 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11682 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 11683 device_type='cuda', dtypes=[torch.cdouble], active_if=IS_WINDOWS), 11684 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 11685 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 11686 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 11687 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 11688 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 11689 device_type='cuda', dtypes=[torch.cdouble], 11690 active_if=IS_WINDOWS), 11691 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 11692 device_type='cuda', dtypes=[torch.cdouble], 11693 active_if=IS_WINDOWS), 11694 # Failing with wrong imaginary sign on at least some Windows jobs 11695 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 11696 device_type='cuda', dtypes=[torch.cdouble], 11697 active_if=IS_WINDOWS), 11698 ), 11699 # acosh is not defined at x < 1 (real) 11700 reference_numerics_filter=NumericsFilter( 11701 condition=lambda x: (x < 1 if not x.is_complex() else torch.zeros_like(x, dtype=torch.bool)), 11702 safe_val=2)), 11703 BinaryUfuncInfo('add', 11704 # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate 11705 ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ 11706 else np.add(input, np.multiply(alpha, other)), 11707 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, 11708 torch.float16, torch.chalf), 11709 assert_autodiffed=True, 11710 sample_inputs_func=sample_inputs_add_sub, 11711 supports_fwgrad_bwgrad=True, 11712 supports_forward_ad=True, 11713 supports_two_python_scalars=True, 11714 decorators=( 11715 DecorateInfo( 11716 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 11717 'TestBinaryUfuncs', 'test_reference_numerics'), 11718 ), 11719 skips=( 11720 # boolean alpha not handled properly 11721 DecorateInfo(unittest.expectedFailure, 11722 'TestNNCOpInfo', 11723 'test_nnc_correctness', 11724 dtypes=(torch.bool,)), 11725 DecorateInfo(unittest.skip("Skipped!"), 11726 'TestCommon', 11727 'test_numpy_refs', 11728 dtypes=(torch.complex128,)), 11729 DecorateInfo(unittest.skip("Skipped!"), 11730 'TestBinaryUfuncs', 11731 'test_reference_numerics_extremal_values', 11732 dtypes=(torch.complex64, torch.complex128)), 11733 )), 11734 OpInfo('item', 11735 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs), 11736 ref=np.ndarray.item, 11737 method_variant=None, 11738 
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool), 11739 supports_out=False, 11740 supports_autograd=False, 11741 error_inputs_func=error_inputs_item, 11742 sample_inputs_func=sample_inputs_item, 11743 skips=( 11744 # Error testing item function variant 11745 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 11746 dtypes=(torch.float32, torch.complex64)), 11747 # FX failed to normalize op - add the op to the op_skip list. 11748 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11749 # RuntimeError: Composite compliance check failed with the above error. 11750 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), 11751 # Booleans mismatch: AssertionError: False is not true 11752 DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'), 11753 # Booleans mismatch: AssertionError: False is not true 11754 DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'), 11755 )), 11756 OpInfo('arange', 11757 dtypes=all_types_and(torch.bfloat16, torch.float16), 11758 supports_out=True, 11759 supports_autograd=False, 11760 is_factory_function=True, 11761 error_inputs_func=error_inputs_arange, 11762 sample_inputs_func=sample_inputs_arange, 11763 skips=( 11764 # https://github.com/pytorch/pytorch/issues/81774 11765 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11766 11767 # Tests that assume input is a tensor or sequence of tensors 11768 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 11769 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 11770 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 11771 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 11772 11773 # Lazy tensor failures 11774 DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), 11775 DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), 11776 DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), 11777 11778 # Exception raised from analyzeImpl at ../torch/csrc/jit/ir/alias_analysis.cpp:608 11779 # We don't have an op for aten::arange but it isn't a special case. 11780 # Argument types: bool, bool, bool, int, int, Device, boo 11781 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), 11782 11783 # Captured graph does not contain aten::arange (succeeds on complex!) 11784 # g: graph(): 11785 # %25 : Long(1, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value={1}]() 11786 # return (%25) 11787 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 11788 11789 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 
11790 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 11791 )), 11792 OpInfo('cauchy', 11793 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.cauchy_, inp, *args, **kwargs), 11794 inplace_variant=torch.Tensor.cauchy_, 11795 dtypes=floating_types_and(torch.float16, torch.bfloat16), 11796 supports_out=False, 11797 supports_autograd=False, 11798 allow_cow_input_materialize_forward=[0], 11799 sample_inputs_func=sample_inputs_cauchy, 11800 error_inputs_func=error_inputs_cauchy, 11801 skips=( 11802 # Tests that assume input tensor has a meaningful effect on output tensor 11803 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 11804 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 11805 11806 # AssertionError: JIT Test does not execute any logic 11807 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 11808 11809 # AssertionError: Tensor-likes are not close! 11810 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 11811 11812 # FX failed to normalize op - add the op to the op_skip list. 11813 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11814 11815 # vmap: calling random operator not supported 11816 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), 11817 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), 11818 11819 DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), 11820 11821 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), 11822 )), 11823 OpInfo('exponential', 11824 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.exponential_, inp, *args, **kwargs), 11825 inplace_variant=torch.Tensor.exponential_, 11826 dtypes=floating_types_and(torch.float16, torch.bfloat16), 11827 supports_out=False, 11828 supports_autograd=False, 11829 allow_cow_input_materialize_forward=[0], 11830 sample_inputs_func=sample_inputs_exponential, 11831 error_inputs_func=error_inputs_exponential, 11832 skips=( 11833 # Tests that assume input tensor has a meaningful effect on output tensor 11834 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 11835 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 11836 11837 # AssertionError: JIT Test does not execute any logic 11838 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 11839 11840 # AssertionError: Tensor-likes are not close! 11841 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 11842 11843 # FX failed to normalize op - add the op to the op_skip list. 
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),

               # vmap: calling random operator not supported
               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
               DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),

               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
           )),
    OpInfo('geometric',
           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.geometric_, inp, *args, **kwargs),
           inplace_variant=torch.Tensor.geometric_,
           dtypes=floating_types_and(torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8),
           supports_out=False,
           supports_autograd=False,
           allow_cow_input_materialize_forward=[0],
           sample_inputs_func=sample_inputs_geometric,
           error_inputs_func=error_inputs_geometric,
           skips=(
               # Tests that assume input tensor has a meaningful effect on output tensor
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),

               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),

               # AssertionError: Tensor-likes are not close!
               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),

               # FX failed to normalize op - add the op to the op_skip list.
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),

               # vmap: calling random operator not supported
               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
               DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),

               DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'),
           )),
    OpInfo('log_normal',
           op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.log_normal_, inp, *args, **kwargs),
           inplace_variant=torch.Tensor.log_normal_,
           dtypes=floating_types_and(torch.float16, torch.bfloat16),
           supports_out=False,
           supports_autograd=False,
           allow_cow_input_materialize_forward=[0],
           sample_inputs_func=sample_inputs_log_normal,
           error_inputs_func=error_inputs_log_normal,
           skips=(
               # Tests that assume input tensor has a meaningful effect on output tensor
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
               DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),

               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),

               # AssertionError: Tensor-likes are not close!
               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
               # FX failed to normalize op - add the op to the op_skip list.
11902 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11903 11904 # vmap: calling random operator not supported 11905 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), 11906 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), 11907 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), 11908 )), 11909 OpInfo('normal', 11910 variant_test_name='in_place', 11911 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.normal_, inp, *args, **kwargs), 11912 inplace_variant=torch.Tensor.normal_, 11913 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 11914 supports_out=False, 11915 supports_autograd=False, 11916 allow_cow_input_materialize_forward=[0], 11917 sample_inputs_func=sample_inputs_normal, 11918 error_inputs_func=error_inputs_normal, 11919 skips=( 11920 # Tests that assume input is a tensor or sequence of tensors 11921 DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), 11922 11923 # Tests that assume input tensor has a meaningful effect on output tensor 11924 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 11925 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 11926 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 11927 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 11928 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), 11929 # AssertionError: JIT Test does not execute any logic 11930 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 11931 # AssertionError: Tensor-likes are not close! 11932 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 11933 # FX failed to normalize op - add the op to the op_skip list. 11934 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11935 # vmap: calling random operator not supported 11936 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), 11937 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), 11938 )), 11939 OpInfo('uniform', 11940 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.uniform_, inp, *args, **kwargs), 11941 method_variant=None, 11942 inplace_variant=torch.Tensor.uniform_, 11943 dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), 11944 supports_out=False, 11945 supports_autograd=False, 11946 is_factory_function=False, 11947 allow_cow_input_materialize_forward=[0], 11948 sample_inputs_func=sample_inputs_uniform, 11949 error_inputs_func=error_inputs_uniform, 11950 skips=( 11951 # FX failed to normalize op - add the op to the op_skip list. 
11952 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 11953 # Tests that assume input tensor has a meaningful effect on output tensor 11954 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 11955 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 11956 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 11957 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 11958 # AssertionError: JIT Test does not execute any logic 11959 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 11960 # aten.uniform was not decomposed 11961 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), 11962 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 11963 )), 11964 BinaryUfuncInfo('clamp_max', 11965 ref=_clamp_max_numpy, 11966 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 11967 supports_forward_ad=True, 11968 supports_rhs_python_scalar=False, 11969 supports_fwgrad_bwgrad=True, 11970 rhs_make_tensor_kwargs=dict(exclude_zero=False), 11971 skips=( 11972 # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' 11973 DecorateInfo(unittest.expectedFailure, 11974 'TestBinaryUfuncs', 11975 'test_type_promotion', 11976 device_type='cuda'), 11977 # dispatch to lazy test failed 11978 DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), 11979 # test error disabled since rhs non-tensor python scalar is supported 11980 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), 11981 )), 11982 BinaryUfuncInfo('clamp_min', 11983 ref=_clamp_min_numpy, 11984 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 11985 supports_forward_ad=True, 11986 supports_rhs_python_scalar=False, 11987 supports_fwgrad_bwgrad=True, 11988 rhs_make_tensor_kwargs=dict(exclude_zero=False), 11989 skips=( 11990 # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' 11991 DecorateInfo(unittest.expectedFailure, 11992 'TestBinaryUfuncs', 11993 'test_type_promotion', 11994 device_type='cuda'), 11995 # dispatch to lazy test failed 11996 DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_dispatched_to_lazy'), 11997 # test error disabled since rhs non-tensor python scalar is supported 11998 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors'), 11999 )), 12000 BinaryUfuncInfo('mul', 12001 aliases=('multiply',), 12002 dtypes=all_types_and_complex_and(torch.chalf, torch.float16, torch.bfloat16, torch.bool), 12003 assert_autodiffed=True, 12004 supports_forward_ad=True, 12005 supports_fwgrad_bwgrad=True, 12006 supports_two_python_scalars=True, 12007 error_inputs_sparse_func=error_inputs_sparse_mul, 12008 sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_coo), 12009 sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csr), 12010 sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_csc), 12011 sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsr), 12012 sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_mul, layout=torch.sparse_bsc)), 12013 BinaryUfuncInfo('sub', 12014 # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate 12015 ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)), 
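                    # e.g. torch.sub(a, b, alpha=2) is documented to compute a - 2 * b, which the
                    # reference above reproduces as np.subtract(a, np.multiply(2, b))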
12016 aliases=('subtract',), 12017 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf), 12018 assert_autodiffed=True, 12019 supports_forward_ad=True, 12020 supports_fwgrad_bwgrad=True, 12021 sample_inputs_func=sample_inputs_add_sub, 12022 supports_two_python_scalars=True, 12023 decorators=( 12024 DecorateInfo( 12025 toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), 12026 torch.bfloat16: tol(atol=1e-5, rtol=5e-3), 12027 torch.complex32: tol(atol=1e-5, rtol=1e-3)}), 12028 'TestBinaryUfuncs', 'test_reference_numerics'), 12029 DecorateInfo( 12030 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 12031 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), 12032 DecorateInfo( 12033 toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), 12034 'TestDecomp', 'test_comprehensive', device_type='cpu'), 12035 DecorateInfo( 12036 toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), 12037 'TestDecomp', 'test_quick', device_type='cpu'), 12038 ), 12039 skips=( 12040 DecorateInfo(unittest.skip("Skipped!"), 12041 'TestBinaryUfuncs', 12042 'test_reference_numerics', 12043 dtypes=(torch.uint8,)), 12044 DecorateInfo(unittest.skip("Skipped!"), 12045 'TestBinaryUfuncs', 12046 'test_reference_numerics_small_values', 12047 dtypes=(torch.uint8,)), 12048 )), 12049 OpInfo('addmm', 12050 # This addmm OpInfo is for when alpha and beta are not both equal to 1. 12051 # alpha=beta=1 is tested in the following opinfo, because that special case will 12052 # trigger addmm being decomposed by a jit pass. 12053 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12054 dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12055 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12056 assert_autodiffed=True, 12057 supports_forward_ad=True, 12058 supports_fwgrad_bwgrad=True, 12059 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 12060 sample_inputs_func=sample_inputs_addmm, 12061 skips=( 12062 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12063 DecorateInfo( 12064 unittest.skip("Skipped!"), 12065 'TestSchemaCheckModeOpInfo', 12066 'test_schema_correctness', 12067 dtypes=(torch.complex64, torch.complex128)), 12068 )), 12069 OpInfo('addmm', 12070 # When alpha=beta=1 as compile-time constants, JIT will decompose addmm into mm and add. 
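           # i.e. the decomposed path computes beta * input + alpha * (mat1 @ mat2) with
           # beta=alpha=1, which reduces to input + mat1 @ mat2 (hence the
           # autodiff_nonfusible_nodes=['aten::add', 'aten::mm'] below)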
12071 variant_test_name='decomposed', 12072 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12073 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12074 assert_autodiffed=True, 12075 supports_forward_ad=True, 12076 supports_fwgrad_bwgrad=True, 12077 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 12078 autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], 12079 sample_inputs_func=partial(sample_inputs_addmm, alpha=1, beta=1), 12080 skips=( 12081 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12082 DecorateInfo( 12083 unittest.skip("Skipped!"), 12084 'TestSchemaCheckModeOpInfo', 12085 'test_schema_correctness', 12086 dtypes=(torch.complex64, torch.complex128)), 12087 # https://github.com/pytorch/pytorch/issues/71784 12088 DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', 12089 device_type='cpu', dtypes=(torch.float16,)), 12090 )), 12091 OpInfo('addmv', 12092 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 12093 dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, 12094 torch.bfloat16), 12095 supports_forward_ad=True, 12096 supports_fwgrad_bwgrad=True, 12097 decorators=[ 12098 DecorateInfo( 12099 toleranceOverride({torch.half: tol(atol=1e-5, rtol=3e-3)}), 12100 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), 12101 ], 12102 sample_inputs_func=sample_inputs_addmv), 12103 OpInfo('addbmm', 12104 ref=lambda M, batch1, batch2, beta=1, alpha=1: np.add(np.multiply(np.asarray(beta, dtype=M.dtype), M), 12105 np.multiply(np.asarray(alpha, dtype=batch1.dtype), 12106 np.sum(np.matmul(batch1, batch2), axis=0))), 12107 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 12108 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, 12109 *[torch.bfloat16] 12110 if SM53OrLater or TEST_WITH_ROCM else []), 12111 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 12112 gradcheck_fast_mode=True, 12113 supports_forward_ad=True, 12114 supports_fwgrad_bwgrad=True, 12115 decorators=[ 12116 DecorateInfo( 12117 toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=1.3e-05), 12118 torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 12119 'TestCommon', 'test_numpy_refs'), 12120 # MPS has slightly worse precision. Is this acceptable? 
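                # (the override just below loosens float32 to atol=rtol=1.3e-4 for
                # TestCommon.test_numpy_ref_mps, an order of magnitude above the 1.3e-5
                # used for the regular test_numpy_refs comparison above)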
12121 DecorateInfo( 12122 toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-04), 12123 torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 12124 'TestCommon', 'test_numpy_ref_mps'), 12125 DecorateInfo( 12126 toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), 12127 'TestConsistency', 12128 'test_output_match', 12129 ), 12130 DecorateInfo( 12131 toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}), 12132 'TestCommon', 'test_out'), 12133 DecorateInfo( 12134 toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}), 12135 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), 12136 ], 12137 skips=( 12138 # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 12139 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), 12140 # addbmm does not correctly warn when resizing out= inputs 12141 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 12142 # https://github.com/pytorch/pytorch/issues/55907 12143 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 12144 ), 12145 sample_inputs_func=sample_inputs_addbmm), 12146 OpInfo('baddbmm', 12147 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 12148 dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, 12149 torch.bfloat16), 12150 backward_dtypesIfCUDA=floating_types_and(torch.float16, 12151 *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else [], 12152 torch.complex64, torch.complex128), 12153 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 12154 gradcheck_fast_mode=True, 12155 supports_forward_ad=True, 12156 supports_fwgrad_bwgrad=True, 12157 decorators=[ 12158 DecorateInfo( 12159 toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 12160 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), 12161 DecorateInfo( 12162 toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 12163 'TestMathBits', 'test_conj_view', device_type='cuda'), 12164 ], 12165 sample_inputs_func=sample_inputs_baddbmm, 12166 skips=( 12167 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12168 DecorateInfo( 12169 unittest.skip("Skipped!"), 12170 'TestSchemaCheckModeOpInfo', 12171 'test_schema_correctness', 12172 dtypes=(torch.complex64, torch.complex128)), 12173 )), 12174 OpInfo('dot', 12175 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12176 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12177 assert_autodiffed=True, 12178 sample_inputs_func=sample_inputs_dot_vdot, 12179 error_inputs_func=error_inputs_dot_vdot, 12180 supports_forward_ad=True, 12181 supports_fwgrad_bwgrad=True, 12182 skips=( 12183 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12184 DecorateInfo( 12185 unittest.skip("Skipped!"), 12186 'TestSchemaCheckModeOpInfo', 12187 'test_schema_correctness', 12188 dtypes=(torch.complex64, torch.complex128)), 12189 )), 12190 OpInfo('vdot', 12191 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12192 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12193 sample_inputs_func=sample_inputs_dot_vdot, 12194 error_inputs_func=error_inputs_dot_vdot, 12195 supports_forward_ad=True, 12196 supports_fwgrad_bwgrad=True, 12197 skips=( 12198 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 
12199 DecorateInfo( 12200 unittest.skip("Skipped!"), 12201 'TestSchemaCheckModeOpInfo', 12202 'test_schema_correctness', 12203 dtypes=(torch.complex64, torch.complex128)), 12204 )), 12205 OpInfo('bmm', 12206 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12207 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, 12208 *[torch.bfloat16] 12209 if SM53OrLater or TEST_WITH_ROCM else []), 12210 assert_autodiffed=True, 12211 assert_jit_shape_analysis=True, 12212 supports_forward_ad=True, 12213 supports_fwgrad_bwgrad=True, 12214 skips=( 12215 # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 12216 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), 12217 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), 12218 "TestCommon", "test_out") 12219 ), 12220 sample_inputs_func=sample_inputs_bmm), 12221 OpInfo('mv', 12222 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12223 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12224 assert_autodiffed=True, 12225 supports_forward_ad=True, 12226 supports_fwgrad_bwgrad=True, 12227 sample_inputs_func=sample_inputs_mv), 12228 OpInfo('addr', 12229 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 12230 # Reference: https://github.com/pytorch/pytorch/issues/50747 12231 supports_forward_ad=True, 12232 supports_fwgrad_bwgrad=True, 12233 skips=( 12234 # Reference: https://github.com/pytorch/pytorch/issues/50747 12235 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 12236 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), 12237 ), 12238 sample_inputs_func=sample_inputs_addr, 12239 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 12240 OpInfo('addcmul', 12241 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12242 assert_autodiffed=True, 12243 supports_forward_ad=True, 12244 supports_fwgrad_bwgrad=True, 12245 skips=( 12246 # TODO: update sample inputs with for_inplace_variant kwarg to support this test 12247 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 12248 ), 12249 sample_inputs_func=sample_inputs_addcmul_addcdiv, 12250 reference_inputs_func=partial( 12251 reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), 12252 OpInfo('addcdiv', 12253 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12254 supports_forward_ad=True, 12255 supports_fwgrad_bwgrad=True, 12256 skips=( 12257 # TODO: update sample inputs with for_inplace_variant kwarg to support this test 12258 DecorateInfo(unittest.expectedFailure, 12259 'TestCommon', 12260 'test_variant_consistency_eager'), 12261 ), 12262 sample_inputs_func=sample_inputs_addcmul_addcdiv, 12263 reference_inputs_func=partial( 12264 reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)), 12265 UnaryUfuncInfo('asin', 12266 aliases=('arcsin', ), 12267 ref=np.arcsin, 12268 domain=(-1, 1), 12269 supports_sparse=True, 12270 supports_sparse_csr=True, 12271 supports_sparse_csc=True, 12272 supports_sparse_bsr=True, 12273 supports_sparse_bsc=True, 12274 supports_forward_ad=True, 12275 supports_fwgrad_bwgrad=True, 12276 promotes_int_to_float=True, 12277 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12278 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12279 
assert_autodiffed=True, 12280 decorators=[ 12281 DecorateInfo( 12282 toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), 12283 'TestUnaryUfuncs', device_type='cuda' 12284 ), 12285 DecorateInfo( 12286 toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}), 12287 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' 12288 ), 12289 precisionOverride({torch.bfloat16: 1e-2}), 12290 ], 12291 skips=( 12292 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12293 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12294 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12295 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12296 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12297 device_type='cuda', dtypes=[torch.cdouble], 12298 active_if=IS_WINDOWS), 12299 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12300 device_type='cuda', dtypes=[torch.cdouble], 12301 active_if=IS_WINDOWS), 12302 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 12303 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 12304 )), 12305 # NOTE: derivative for inplace asinh is not implemented 12306 UnaryUfuncInfo('asinh', 12307 aliases=('arcsinh', ), 12308 ref=np.arcsinh, 12309 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12310 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12311 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 12312 supports_inplace_autograd=False, 12313 supports_forward_ad=True, 12314 supports_fwgrad_bwgrad=True, 12315 supports_sparse=True, 12316 supports_sparse_csr=True, 12317 supports_sparse_csc=True, 12318 supports_sparse_bsr=True, 12319 supports_sparse_bsc=True, 12320 promotes_int_to_float=True, 12321 skips=( 12322 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12323 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12324 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12325 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12326 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 12327 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12328 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 12329 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12330 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12331 device_type='cuda', dtypes=[torch.cdouble], 12332 active_if=IS_WINDOWS), 12333 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12334 device_type='cuda', dtypes=[torch.cdouble], 12335 active_if=IS_WINDOWS), 12336 DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 12337 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 12338 )), 12339 UnaryUfuncInfo('atan', 12340 aliases=('arctan', ), 12341 ref=np.arctan, 12342 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12343 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12344 assert_autodiffed=True, 12345 supports_forward_ad=True, 12346 supports_fwgrad_bwgrad=True, 12347 supports_sparse=True, 12348 supports_sparse_csr=True, 12349 supports_sparse_csc=True, 12350 supports_sparse_bsr=True, 12351 supports_sparse_bsc=True, 12352 promotes_int_to_float=True, 12353 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 12354 skips=( 12355 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12356 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12357 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12358 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12359 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 12360 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12361 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12362 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 12363 active_if=IS_WINDOWS), 12364 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12365 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 12366 active_if=IS_WINDOWS), 12367 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 12368 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 12369 )), 12370 BinaryUfuncInfo('atan2', 12371 aliases=('arctan2',), 12372 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), 12373 supports_forward_ad=True, 12374 supports_fwgrad_bwgrad=True, 12375 promotes_int_to_float=True, 12376 supports_rhs_python_scalar=False, 12377 skips=( 12378 # Incorrectly attempts to use a scalar for the second argument 12379 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), 12380 )), 12381 UnaryUfuncInfo('atanh', 12382 aliases=('arctanh', ), 12383 ref=np.arctanh, 12384 domain=(-1, 1), 12385 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12386 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12387 decorators=[ 12388 precisionOverride({torch.bfloat16: 1e-2}), 12389 DecorateInfo( 12390 toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}), 12391 "TestInductorOpInfo", 12392 "test_comprehensive", 12393 device_type="cuda" 12394 ), 12395 ], 12396 supports_inplace_autograd=False, 12397 supports_forward_ad=True, 12398 supports_fwgrad_bwgrad=True, 12399 supports_sparse=True, 12400 supports_sparse_csr=True, 12401 supports_sparse_csc=True, 12402 supports_sparse_bsr=True, 12403 supports_sparse_bsc=True, 12404 promotes_int_to_float=True, 12405 skips=( 12406 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 12407 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12408 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12409 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12410 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12411 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 12412 
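                       # As with asin/asinh/atan above, the complex CUDA comparisons below are
                       # skipped only on Windows, where large/extremal values have disagreed with
                       # the NumPy reference (see the acos skips for the same pattern)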
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12413 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 12414 active_if=IS_WINDOWS), 12415 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12416 device_type='cuda', dtypes=[torch.cfloat], 12417 active_if=IS_WINDOWS), 12418 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 12419 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 12420 )), 12421 OpInfo('allclose', 12422 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 12423 ref=np.allclose, 12424 supports_autograd=False, 12425 supports_forward_ad=False, 12426 sample_inputs_func=sample_inputs_allclose, 12427 skips=( 12428 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 12429 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 12430 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 12431 ), 12432 supports_out=False), 12433 OpInfo('broadcast_to', 12434 ref=np.broadcast_to, 12435 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12436 supports_out=False, 12437 supports_forward_ad=True, 12438 supports_fwgrad_bwgrad=True, 12439 # See https://github.com/pytorch/pytorch/pull/78358 12440 check_batched_forward_grad=False, 12441 sample_inputs_func=sample_inputs_broadcast_to), 12442 OpInfo('broadcast_shapes', 12443 op=torch.broadcast_shapes, 12444 ref=np.broadcast_shapes if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None, 12445 dtypes=_dispatch_dtypes((torch.float32,)), 12446 supports_out=False, 12447 supports_gradgrad=False, 12448 assert_autodiffed=False, 12449 supports_autograd=False, 12450 supports_scripting=False, 12451 sample_inputs_func=sample_inputs_broadcast_shapes, 12452 skips=( 12453 # https://github.com/pytorch/pytorch/issues/64997 12454 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 12455 # skip dtype tests since broadcast_shape is not device dependent. 12456 # having dtypes limited to torch.float32 would cause test_dtypes to report unexpected success 12457 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), 12458 # skip these tests since we have non tensor input 12459 DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), 12460 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 12461 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 12462 )), 12463 OpInfo('broadcast_tensors', 12464 ref=np.broadcast_arrays, 12465 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12466 sample_inputs_func=sample_inputs_broadcast_tensors, 12467 reference_inputs_func=reference_inputs_broadcast_tensors, 12468 supports_out=False, 12469 supports_forward_ad=True, 12470 supports_fwgrad_bwgrad=True, 12471 # See https://github.com/pytorch/pytorch/pull/78358 12472 check_batched_forward_grad=False, 12473 skips=( 12474 # https://github.com/pytorch/pytorch/issues/64997 12475 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 12476 # JIT does not support variadic tensors. 12477 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 12478 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, 12479 # please report a bug to PyTorch. 
12480 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), 12481 )), 12482 OpInfo('block_diag', 12483 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 12484 supports_out=False, 12485 supports_forward_ad=True, 12486 supports_fwgrad_bwgrad=True, 12487 # Default batching rule in core doesn't work for ops with TensorList args 12488 check_batched_forward_grad=False, 12489 skips=( 12490 # https://github.com/pytorch/pytorch/issues/64997 12491 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 12492 # JIT does not support variadic tensors. 12493 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 12494 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, 12495 # please report a bug to PyTorch. 12496 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), 12497 ), 12498 sample_inputs_func=sample_inputs_block_diag), 12499 UnaryUfuncInfo('bitwise_not', 12500 ref=np.bitwise_not, 12501 dtypes=integral_types_and(torch.bool), 12502 operator_variant=operator.invert, 12503 supports_autograd=False), 12504 BinaryUfuncInfo('bitwise_left_shift', 12505 op=torch.bitwise_left_shift, 12506 dtypes=integral_types(), 12507 dtypesIfCUDA=integral_types(), 12508 operator_variant=operator.lshift, 12509 inplace_operator_variant=operator.ilshift, 12510 supports_autograd=False, 12511 supports_one_python_scalar=True, 12512 rhs_make_tensor_kwargs=dict(low=0), 12513 skips=( 12514 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 12515 # https://github.com/pytorch/pytorch/issues/70904 12516 DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), 12517 )), 12518 BinaryUfuncInfo('bitwise_right_shift', 12519 op=torch.bitwise_right_shift, 12520 dtypes=integral_types(), 12521 dtypesIfCUDA=integral_types(), 12522 operator_variant=operator.rshift, 12523 inplace_operator_variant=operator.irshift, 12524 supports_autograd=False, 12525 supports_one_python_scalar=True, 12526 rhs_make_tensor_kwargs=dict(low=0), 12527 skips=( 12528 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 12529 # https://github.com/pytorch/pytorch/issues/70904 12530 DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), 12531 )), 12532 OpInfo('combinations', 12533 op=torch.combinations, 12534 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12535 supports_forward_ad=True, 12536 supports_fwgrad_bwgrad=True, 12537 # See https://github.com/pytorch/pytorch/pull/78358 12538 check_batched_forward_grad=False, 12539 supports_out=False, 12540 sample_inputs_func=sample_inputs_combinations), 12541 OpInfo('cartesian_prod', 12542 op=torch.cartesian_prod, 12543 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12544 supports_out=False, 12545 supports_forward_ad=True, 12546 supports_fwgrad_bwgrad=True, 12547 # See https://github.com/pytorch/pytorch/pull/78358 12548 check_batched_forward_grad=False, 12549 sample_inputs_func=sample_inputs_cartesian_prod, 12550 skips=( 12551 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 12552 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 12553 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 
'test_normalize_operator_exhaustive'), 12554 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 12555 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 12556 DecorateInfo(unittest.expectedFailure, 12557 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 12558 )), 12559 OpInfo('cdist', 12560 dtypes=floating_types(), 12561 supports_out=False, 12562 supports_gradgrad=False, 12563 assert_autodiffed=False, 12564 sample_inputs_func=sample_inputs_cdist), 12565 UnaryUfuncInfo('ceil', 12566 ref=np.ceil, 12567 dtypes=all_types_and(torch.half, torch.bfloat16), 12568 supports_forward_ad=True, 12569 supports_fwgrad_bwgrad=True, 12570 skips=( 12571 DecorateInfo(unittest.expectedFailure, 12572 'TestNNCOpInfo', 12573 'test_nnc_correctness', 12574 dtypes=tuple(t for t in integral_types() if t != torch.uint8)), 12575 ), 12576 supports_sparse=True, 12577 supports_sparse_csr=True, 12578 supports_sparse_csc=True, 12579 supports_sparse_bsr=True, 12580 supports_sparse_bsc=True, 12581 assert_autodiffed=True), 12582 OpInfo('cholesky', 12583 dtypes=floating_and_complex_types(), 12584 sample_inputs_func=sample_inputs_linalg_cholesky, 12585 gradcheck_wrapper=gradcheck_wrapper_hermitian_input, 12586 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],), 12587 OpInfo('cholesky_inverse', 12588 dtypes=floating_and_complex_types(), 12589 backward_dtypes=floating_and_complex_types(), 12590 # https://github.com/pytorch/pytorch/issues/80411 12591 gradcheck_fast_mode=True, 12592 supports_fwgrad_bwgrad=True, 12593 supports_forward_ad=True, 12594 check_batched_gradgrad=True, 12595 sample_inputs_func=sample_inputs_linalg_cholesky_inverse, 12596 gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal, 12597 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], 12598 skips=( 12599 # Strides are not the same! 
Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),) 12600 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)), 12601 OpInfo('cholesky_solve', 12602 op=torch.cholesky_solve, 12603 dtypes=floating_and_complex_types(), 12604 sample_inputs_func=sample_inputs_cholesky_solve, 12605 check_batched_gradgrad=False, 12606 supports_forward_ad=True, 12607 supports_fwgrad_bwgrad=True, 12608 gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), 12609 decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]), 12610 OpInfo('chunk', 12611 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 12612 sample_inputs_func=sample_inputs_chunk, 12613 reference_inputs_func=reference_inputs_chunk, 12614 supports_forward_ad=True, 12615 supports_fwgrad_bwgrad=True, 12616 supports_out=False), 12617 OpInfo('unsafe_chunk', 12618 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 12619 sample_inputs_func=sample_inputs_chunk, 12620 check_batched_forward_grad=False, 12621 reference_inputs_func=reference_inputs_chunk, 12622 supports_forward_ad=True, 12623 supports_fwgrad_bwgrad=True, 12624 supports_out=False), 12625 OpInfo('clone', 12626 ref=np.copy, 12627 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 12628 sample_inputs_func=sample_inputs_clone_contiguous, 12629 reference_inputs_func=reference_inputs_clone_contiguous, 12630 supports_forward_ad=True, 12631 supports_fwgrad_bwgrad=True, 12632 supports_out=False, 12633 skips=( 12634 # TypeError: _copy_dispatcher() got an unexpected keyword argument 'memory_format' 12635 # (NumPy reference needs to be extended with memory_format) 12636 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'), 12637 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), 12638 ),), 12639 OpInfo('contiguous', 12640 op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs), 12641 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 12642 sample_inputs_func=sample_inputs_clone_contiguous, 12643 reference_inputs_func=reference_inputs_clone_contiguous, 12644 supports_forward_ad=True, 12645 supports_fwgrad_bwgrad=True, 12646 autodiff_fusible_nodes=['aten::contiguous'], 12647 assert_jit_shape_analysis=True, 12648 supports_out=False, 12649 skips=( 12650 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 12651 )), 12652 OpInfo('sum_to_size', 12653 op=lambda x, *args, **kwargs: x.sum_to_size(*args, **kwargs), 12654 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12655 sample_inputs_func=sample_inputs_sum_to_size, 12656 error_inputs_func=error_inputs_sum_to_size, 12657 supports_forward_ad=True, 12658 supports_fwgrad_bwgrad=True, 12659 supports_out=False, 12660 skips=( 12661 # lambda impl 12662 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 12663 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)), 12664 )), 12665 OpInfo('clamp', 12666 aliases=('clip',), 12667 ref=_clamp_numpy, 12668 dtypes=all_types_and(torch.bfloat16, torch.half), 12669 sample_inputs_func=sample_inputs_clamp, 12670 reference_inputs_func=partial(reference_inputs_elementwise_ternary, sample_inputs_func=sample_inputs_clamp), 12671 assert_autodiffed=True, 12672 supports_forward_ad=True, 12673 
supports_fwgrad_bwgrad=True, 12674 skips=( 12675 # NNC appear to not handle boolean clamp 12676 DecorateInfo(unittest.expectedFailure, 12677 'TestNNCOpInfo', 12678 'test_nnc_correctness', 12679 dtypes=(torch.bool,)), 12680 # MPS does not support float64, while numpy does internal computations in float64. 12681 # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264 12682 DecorateInfo(unittest.expectedFailure, 12683 'TestCommon', 12684 'test_numpy_ref_mps'), 12685 )), 12686 UnaryUfuncInfo('positive', 12687 ref=np.positive, 12688 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), 12689 supports_out=False, 12690 supports_forward_ad=True, 12691 supports_fwgrad_bwgrad=True, 12692 supports_sparse=True, 12693 supports_sparse_csr=True, 12694 supports_sparse_csc=True, 12695 supports_sparse_bsr=True, 12696 supports_sparse_bsc=True, 12697 ), 12698 UnaryUfuncInfo('conj', 12699 ref=np.conj, 12700 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, 12701 torch.half, torch.chalf), 12702 supports_sparse=True, 12703 supports_forward_ad=True, 12704 supports_fwgrad_bwgrad=True, 12705 # See https://github.com/pytorch/pytorch/pull/78358 12706 check_batched_forward_grad=False, 12707 supports_out=False), 12708 UnaryUfuncInfo('conj_physical', 12709 decomp_aten_name='_conj_physical', 12710 ref=np.conj, 12711 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, 12712 torch.half, torch.chalf), 12713 supports_forward_ad=True, 12714 supports_fwgrad_bwgrad=True, 12715 supports_sparse=True, 12716 supports_sparse_csr=True, 12717 supports_sparse_csc=True, 12718 supports_sparse_bsr=True, 12719 supports_sparse_bsc=True, 12720 skips=( 12721 # RuntimeError: inputSet && outputSet 12722 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":118, 12723 # please report a bug to PyTorch. 12724 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )), 12725 DecorateInfo(unittest.skip("Skipped! 
conj_physical_ not implemented for sparse"), 12726 'TestSparseUnaryUfuncs', 'test_inplace'), 12727 )), 12728 OpInfo('resolve_conj', 12729 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12730 sample_inputs_func=sample_inputs_view_as_real, 12731 supports_forward_ad=True, 12732 supports_fwgrad_bwgrad=True, 12733 supports_out=False, 12734 ), 12735 OpInfo('resolve_neg', 12736 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 12737 sample_inputs_func=sample_inputs_view_as_real, 12738 supports_forward_ad=True, 12739 supports_fwgrad_bwgrad=True, 12740 supports_out=False, 12741 ), 12742 OpInfo('view_as_real', 12743 dtypes=complex_types(), 12744 supports_forward_ad=True, 12745 supports_out=False, 12746 supports_fwgrad_bwgrad=True, 12747 sample_inputs_func=sample_inputs_view_as_real, 12748 test_conjugated_samples=False, 12749 ), 12750 OpInfo('view_as_complex', 12751 dtypes=floating_types_and(torch.half), 12752 supports_out=False, 12753 supports_forward_ad=True, 12754 supports_fwgrad_bwgrad=True, 12755 test_neg_view=False, 12756 sample_inputs_func=sample_inputs_view_as_complex, 12757 skips=( 12758 # RuntimeError: Tensor must have a last dimension with stride 1 12759 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), 12760 # RuntimeError: "eq_cpu" not implemented for 'ComplexHalf' 12761 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.half,)), 12762 # RuntimeError: view size is not compatible with input tensor's size and stride 12763 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 12764 )), 12765 BinaryUfuncInfo('complex', 12766 dtypes=floating_types_and(torch.half), 12767 supports_forward_ad=True, 12768 supports_fwgrad_bwgrad=True, 12769 supports_rhs_python_scalar=False, 12770 error_inputs_func=error_inputs_complex, 12771 skips=( 12772 # Tests don't account for complex's type promotion semantics 12773 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 12774 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='mps'), 12775 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),)), 12776 BinaryUfuncInfo('copysign', 12777 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 12778 promotes_int_to_float=True, 12779 # https://github.com/pytorch/pytorch/issues/80411 12780 gradcheck_fast_mode=True, 12781 supports_forward_ad=True, 12782 supports_fwgrad_bwgrad=True), 12783 OpInfo('corrcoef', 12784 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 12785 sample_inputs_func=sample_inputs_corrcoef, 12786 supports_forward_ad=True, 12787 supports_fwgrad_bwgrad=True, 12788 # See https://github.com/pytorch/pytorch/pull/78358 12789 check_batched_forward_grad=False, 12790 skips=( 12791 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12792 DecorateInfo( 12793 unittest.skip("Skipped!"), 12794 'TestSchemaCheckModeOpInfo', 12795 'test_schema_correctness', 12796 dtypes=(torch.complex64, torch.complex128)), 12797 ), 12798 supports_out=False), 12799 UnaryUfuncInfo('cos', 12800 ref=np.cos, 12801 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12802 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12803 assert_autodiffed=True, 12804 handles_large_floats=False, 12805 supports_forward_ad=True, 12806 
supports_fwgrad_bwgrad=True, 12807 promotes_int_to_float=True, 12808 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 12809 skips=( 12810 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12811 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), 12812 # This fails on CUDA but passes on ROCm 12813 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12814 dtypes=(torch.cdouble,), device_type='cuda'), 12815 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12816 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 12817 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12818 device_type='cpu', 12819 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 12820 # AssertionError: Tensor-likes are not close! 12821 # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) 12822 # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) 12823 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', 12824 device_type='cuda', 12825 dtypes=(torch.chalf,), active_if=IS_WINDOWS), 12826 )), 12827 UnaryUfuncInfo('cosh', 12828 ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), 12829 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12830 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 12831 assert_autodiffed=True, 12832 supports_forward_ad=True, 12833 supports_fwgrad_bwgrad=True, 12834 promotes_int_to_float=True, 12835 skips=( 12836 # Reference: https://github.com/pytorch/pytorch/issues/48641 12837 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12838 device_type='cpu', dtypes=[torch.int8]), 12839 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12840 dtypes=[torch.cdouble]), 12841 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12842 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 12843 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12844 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 12845 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 12846 device_type='cpu', 12847 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 12848 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 12849 device_type='cpu', 12850 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 12851 # AssertionError: Tensor-likes are not close! 
12852 # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) 12853 # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) 12854 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', 12855 device_type='cuda', 12856 dtypes=(torch.chalf,), active_if=IS_WINDOWS), 12857 )), 12858 OpInfo('cov', 12859 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 12860 sample_inputs_func=sample_inputs_cov, 12861 error_inputs_func=error_inputs_cov, 12862 supports_out=False, 12863 supports_forward_ad=True, 12864 supports_fwgrad_bwgrad=True, 12865 skips=( 12866 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 12867 DecorateInfo( 12868 unittest.skip("Skipped!"), 12869 'TestSchemaCheckModeOpInfo', 12870 'test_schema_correctness', 12871 dtypes=(torch.complex64, torch.complex128)), 12872 # Float did not match double 12873 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), 12874 # Jacobian mismatch 12875 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), 12876 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), 12877 DecorateInfo(unittest.skip("Barely fails"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 12878 # JIT test not working for tensor kwargs (https://github.com/pytorch/pytorch/issues/58507) 12879 # RuntimeError: 12880 # undefined value tensor: 12881 # File "<string>", line 3 12882 # def the_method(i0): 12883 # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950 12884 # ~~~~~~ <--- HERE 12885 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 12886 DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}), 12887 "TestInductorOpInfo", "test_comprehensive", device_type="cpu"), 12888 )), 12889 OpInfo('cross', 12890 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 12891 sample_inputs_func=sample_inputs_cross, 12892 supports_fwgrad_bwgrad=True, 12893 supports_out=True, 12894 supports_forward_ad=True), 12895 OpInfo('cumsum', 12896 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 12897 supports_forward_ad=True, 12898 supports_fwgrad_bwgrad=True, 12899 skips=( 12900 # cumsum does not handle correctly out= dtypes 12901 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 12902 ), 12903 sample_inputs_func=sample_inputs_cumulative_ops), 12904 OpInfo('cumprod', 12905 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 12906 supports_forward_ad=True, 12907 supports_fwgrad_bwgrad=True, 12908 skips=( 12909 # cumprod does not handle correctly out= dtypes 12910 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 12911 ), 12912 # gradgradcheck fails in fast_mode=True: #56275 12913 sample_inputs_func=sample_inputs_cumprod, 12914 gradcheck_fast_mode=False), 12915 OpInfo('cummax', 12916 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 12917 sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), 12918 supports_forward_ad=True, 12919 supports_fwgrad_bwgrad=True, 12920 skips=( 12921 ), 12922 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 12923 OpInfo('cummin', 12924 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 12925 sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False), 12926 supports_forward_ad=True, 12927 
supports_fwgrad_bwgrad=True, 12928 skips=( 12929 ), 12930 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 12931 UnaryUfuncInfo('deg2rad', 12932 ref=np.radians, 12933 decorators=(precisionOverride({torch.bfloat16: 7e-1, 12934 torch.float16: 7e-1}),), 12935 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 12936 supports_forward_ad=True, 12937 supports_fwgrad_bwgrad=True, 12938 supports_sparse=True, 12939 supports_sparse_csr=True, 12940 supports_sparse_csc=True, 12941 supports_sparse_bsr=True, 12942 supports_sparse_bsc=True, 12943 promotes_int_to_float=True), 12944 OpInfo('diff', 12945 op=torch.diff, 12946 # np.diff has np._NoValue as default values for prepend and append, compare_with_reference breaks if prepend/append 12947 # are set as None when converting to numpy 12948 ref=lambda input, n=1, dim=-1, prepend=np._NoValue, append=np._NoValue: ( 12949 np.diff(input, n, dim, np._NoValue if prepend is None else prepend, np._NoValue if append is None else append) 12950 ), 12951 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 12952 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 12953 gradcheck_fast_mode=True, 12954 supports_forward_ad=True, 12955 supports_fwgrad_bwgrad=True, 12956 sample_inputs_func=sample_inputs_diff, 12957 error_inputs_func=error_inputs_diff, 12958 # See https://github.com/pytorch/pytorch/pull/78358 12959 check_batched_forward_grad=False, 12960 skips=( 12961 )), 12962 BinaryUfuncInfo('div', 12963 aliases=('divide',), 12964 variant_test_name='no_rounding_mode', 12965 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 12966 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 12967 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 12968 gradcheck_fast_mode=True, 12969 supports_forward_ad=True, 12970 promotes_int_to_float=True, 12971 supports_fwgrad_bwgrad=True, 12972 supports_two_python_scalars=True, 12973 assert_autodiffed=True, 12974 rhs_make_tensor_kwargs=dict(exclude_zero=True),), 12975 BinaryUfuncInfo('div', 12976 aliases=('divide',), 12977 variant_test_name='trunc_rounding', 12978 dtypes=all_types_and(torch.half, torch.bfloat16), 12979 sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="trunc")), 12980 # https://github.com/pytorch/pytorch/issues/80411 12981 gradcheck_fast_mode=True, 12982 supports_forward_ad=True, 12983 supports_fwgrad_bwgrad=True, 12984 supports_two_python_scalars=True, 12985 assert_autodiffed=True, 12986 rhs_make_tensor_kwargs=dict(exclude_zero=True), 12987 decorators=( 12988 # See https://github.com/pytorch/pytorch/issues/111126 12989 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 12990 ), 12991 skips=( 12992 # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div 12993 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), 12994 # FIXME: 12995 # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for 12996 # output 0 with respect to input 1, 12997 # numerical:tensor(-17746.9307, dtype=torch.float64) 12998 # analytical:tensor(0., dtype=torch.float64) 12999 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 13000 'test_fn_grad', device_type='cpu', 13001 dtypes=(torch.float64,)), 13002 )), 13003 BinaryUfuncInfo('div', 13004 aliases=('divide',), 13005 variant_test_name='floor_rounding', 13006 dtypes=all_types_and(torch.half, torch.bfloat16), 13007 
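                    # Illustrative sketch (comment only, not run by the suite): the three 'div' variants
                    # in this file differ only in the rounding_mode their sample inputs request. The
                    # values below are for orientation and assume default float32 tensors.
                    # >>> a, b = torch.tensor(-7.), torch.tensor(2.)
                    # >>> torch.div(a, b)                         # 'no_rounding_mode' variant: true division
                    # tensor(-3.5000)
                    # >>> torch.div(a, b, rounding_mode='trunc')  # 'trunc_rounding': rounds toward zero
                    # tensor(-3.)
                    # >>> torch.div(a, b, rounding_mode='floor')  # 'floor_rounding': rounds toward -inf
                    # tensor(-4.)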
sample_inputs_func=partial(sample_inputs_elementwise_binary, sample_kwargs=dict(rounding_mode="floor")), 13008 # https://github.com/pytorch/pytorch/issues/80411 13009 gradcheck_fast_mode=True, 13010 supports_forward_ad=True, 13011 supports_fwgrad_bwgrad=True, 13012 supports_two_python_scalars=True, 13013 assert_autodiffed=True, 13014 rhs_make_tensor_kwargs=dict(exclude_zero=True), 13015 decorators=( 13016 # See https://github.com/pytorch/pytorch/issues/111126 13017 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 13018 ), 13019 skips=( 13020 # RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div 13021 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'), 13022 # FIXME: 13023 # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for 13024 # output 0 with respect to input 1, 13025 # numerical:tensor(-17746.9307, dtype=torch.float64) 13026 # analytical:tensor(0., dtype=torch.float64) 13027 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 13028 'test_fn_grad', 13029 dtypes=(torch.float64,), 13030 device_type='cpu'), 13031 )), 13032 BinaryUfuncInfo('true_divide', 13033 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13034 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 13035 supports_forward_ad=True, 13036 promotes_int_to_float=True, 13037 supports_fwgrad_bwgrad=True, 13038 supports_two_python_scalars=True, 13039 rhs_make_tensor_kwargs=dict(exclude_zero=True)), 13040 OpInfo('equal', 13041 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 13042 ref=lambda input, other: (input == other).all(), 13043 sample_inputs_func=sample_inputs_equal, 13044 supports_autograd=False, 13045 supports_tracing=False, 13046 skips=( 13047 )), 13048 UnaryUfuncInfo('exp', 13049 ref=np_unary_ufunc_integer_promotion_wrapper(np.exp), 13050 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13051 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 13052 skips=( 13053 # Reference: https://github.com/pytorch/pytorch/issues/48010 13054 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 13055 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 13056 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 13057 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 13058 ), 13059 assert_autodiffed=True, 13060 supports_forward_ad=True, 13061 supports_fwgrad_bwgrad=True, 13062 promotes_int_to_float=True), 13063 OpInfo('expand', 13064 op=lambda self, shape: self.expand(shape), 13065 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13066 sample_inputs_func=sample_inputs_expand, 13067 supports_forward_ad=True, 13068 supports_fwgrad_bwgrad=True, 13069 assert_jit_shape_analysis=True, 13070 supports_out=False, 13071 skips=( 13072 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 13073 )), 13074 OpInfo('expand_as', 13075 op=lambda self, other: self.expand_as(other), 13076 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13077 supports_forward_ad=True, 13078 supports_fwgrad_bwgrad=True, 13079 sample_inputs_func=sample_inputs_expand_as, 13080 supports_out=False, 13081 skips=( 13082 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),), 13083 ), 
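    # Illustrative sketch (comment only): 'expand' above returns a stride-0 view and therefore
    # cannot support out=, while 'expand_copy' below materializes the result and can. The strides
    # shown are indicative.
    # >>> base = torch.tensor([[1.], [2.]])
    # >>> base.expand(2, 3).stride()          # expanded dim has stride 0 -> shares storage with base
    # (1, 0)
    # >>> torch.expand_copy(base, (2, 3)).stride()   # freshly allocated, contiguous result
    # (3, 1)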
13084 OpInfo('expand_copy', 13085 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13086 sample_inputs_func=sample_inputs_expand, 13087 supports_forward_ad=True, 13088 supports_fwgrad_bwgrad=True, 13089 assert_jit_shape_analysis=True, 13090 supports_out=True, 13091 skips=( 13092 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 13093 )), 13094 OpInfo('diag', 13095 ref=np.diag, 13096 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 13097 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 13098 supports_forward_ad=True, 13099 supports_fwgrad_bwgrad=True, 13100 check_batched_forward_grad=False, 13101 sample_inputs_func=sample_inputs_diag, 13102 error_inputs_func=error_inputs_diag), 13103 OpInfo('diag_embed', 13104 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 13105 supports_out=False, 13106 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13107 gradcheck_fast_mode=True, 13108 supports_forward_ad=True, 13109 supports_fwgrad_bwgrad=True, 13110 sample_inputs_func=sample_inputs_diagonal_diag_embed, 13111 reference_inputs_func=reference_inputs_diagonal_diag_embed, 13112 error_inputs_func=error_inputs_diagonal_diag_embed), 13113 OpInfo('diagonal', 13114 aten_backward_name='diagonal_backward', 13115 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 13116 supports_out=False, 13117 supports_forward_ad=True, 13118 supports_fwgrad_bwgrad=True, 13119 sample_inputs_func=sample_inputs_diagonal_diag_embed, 13120 reference_inputs_func=reference_inputs_diagonal_diag_embed, 13121 error_inputs_func=error_inputs_diagonal_diag_embed), 13122 OpInfo('diagonal_copy', 13123 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 13124 supports_forward_ad=True, 13125 supports_fwgrad_bwgrad=True, 13126 sample_inputs_func=sample_inputs_diagonal_diag_embed, 13127 reference_inputs_func=reference_inputs_diagonal_diag_embed, 13128 error_inputs_func=error_inputs_diagonal_diag_embed), 13129 OpInfo('diagonal_scatter', 13130 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 13131 supports_out=False, 13132 supports_forward_ad=True, 13133 supports_fwgrad_bwgrad=True, 13134 sample_inputs_func=sample_inputs_diagonal_scatter), 13135 OpInfo('alias_copy', 13136 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 13137 sample_inputs_func=sample_inputs_alias_copy, 13138 supports_forward_ad=True, 13139 supports_fwgrad_bwgrad=True, 13140 supports_out=True), 13141 BinaryUfuncInfo('eq', 13142 ref=np.equal, 13143 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 13144 always_returns_bool=True, 13145 supports_autograd=False, 13146 sample_inputs_func=sample_inputs_comparison_ops, 13147 skips=( 13148 )), 13149 BinaryUfuncInfo('fmax', 13150 op=torch.fmax, 13151 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 13152 supports_forward_ad=True, 13153 supports_fwgrad_bwgrad=True, 13154 supports_rhs_python_scalar=False, 13155 skips=( 13156 # RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' 13157 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 13158 )), 13159 BinaryUfuncInfo('fmin', 13160 op=torch.fmin, 13161 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 13162 
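                    # Illustrative sketch (comment only): fmax/fmin follow IEEE-754 semantics and prefer
                    # the non-NaN operand, whereas maximum/minimum (tested further below) propagate NaN.
                    # Values shown are indicative.
                    # >>> a = torch.tensor([1., float('nan')])
                    # >>> b = torch.tensor([2., 3.])
                    # >>> torch.fmin(a, b)      # NaN is ignored when the other element is a number
                    # tensor([1., 3.])
                    # >>> torch.minimum(a, b)   # NaN propagates
                    # tensor([1., nan])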
supports_forward_ad=True, 13163 supports_fwgrad_bwgrad=True, 13164 supports_rhs_python_scalar=False, 13165 skips=( 13166 # RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' 13167 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 13168 )), 13169 BinaryUfuncInfo('fmod', 13170 ref=np.fmod, 13171 dtypes=all_types_and(torch.float16, torch.bfloat16), 13172 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 13173 # https://github.com/pytorch/pytorch/issues/80411 13174 gradcheck_fast_mode=True, 13175 supports_forward_ad=True, 13176 supports_fwgrad_bwgrad=True, 13177 assert_autodiffed=None, 13178 rhs_make_tensor_kwargs={'exclude_zero': True}, 13179 decorators=( 13180 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13181 'test_contig_vs_every_other', 13182 dtypes=(torch.bfloat16,)), 13183 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13184 'test_non_contig', 13185 dtypes=(torch.bfloat16,)), 13186 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13187 'test_reference_numerics', 13188 dtypes=(torch.bfloat16,)), 13189 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13190 'test_reference_numerics_small_values', 13191 dtypes=(torch.uint8,)), 13192 # FIXME: 13193 # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for 13194 # output 0 with respect to input 1, 13195 # numerical:tensor(101.6283, dtype=torch.float64) 13196 # analytical:tensor(-18.3575, dtype=torch.float64) 13197 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 13198 'test_fn_grad', 13199 dtypes=(torch.float64,), 13200 device_type='cpu'), 13201 )), 13202 BinaryUfuncInfo('remainder', 13203 ref=np.remainder, 13204 dtypes=all_types_and(torch.float16, torch.bfloat16), 13205 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 13206 # https://github.com/pytorch/pytorch/issues/80411 13207 gradcheck_fast_mode=True, 13208 supports_forward_ad=True, 13209 supports_fwgrad_bwgrad=True, 13210 assert_autodiffed=None, 13211 operator_variant=operator.mod, 13212 inplace_operator_variant=operator.imod, 13213 supports_one_python_scalar=True, 13214 rhs_make_tensor_kwargs={'exclude_zero': True}, 13215 decorators=( 13216 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13217 'test_contig_vs_every_other', 13218 dtypes=(torch.bfloat16,)), 13219 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13220 'test_non_contig', 13221 dtypes=(torch.bfloat16,)), 13222 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13223 'test_reference_numerics', 13224 dtypes=(torch.bfloat16,)), 13225 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 13226 'test_reference_numerics_small_values', 13227 dtypes=(torch.uint8,)), 13228 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 13229 'test_nnc_correctness', 13230 dtypes=(torch.bfloat16,)), 13231 # Fails on XLA 13232 # False is not true : Tensors failed to compare as equal! 
13233 # Attempted to compare equality of tensors with different dtypes 13234 DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), 13235 # FIXME: 13236 # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for 13237 # output 0 with respect to input 1, 13238 # numerical:tensor(102.4676, dtype=torch.float64) 13239 # analytical:tensor(-17.5182, dtype=torch.float64) 13240 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 13241 'test_fn_grad', device_type='cpu', 13242 dtypes=(torch.float64,)), 13243 DecorateInfo( 13244 toleranceOverride({ 13245 torch.float16: tol(atol=5e-4, rtol=3e-3), 13246 }), 13247 "TestInductorOpInfo", 13248 "test_comprehensive", 13249 device_type="cuda" 13250 ), 13251 )), 13252 UnaryUfuncInfo('frac', 13253 ref=lambda x: np.modf(x)[0], 13254 dtypes=floating_types_and(torch.bfloat16, torch.float16), 13255 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 13256 assert_autodiffed=True, 13257 supports_forward_ad=True, 13258 supports_fwgrad_bwgrad=True, 13259 supports_sparse=True, 13260 supports_sparse_csr=True, 13261 supports_sparse_csc=True, 13262 supports_sparse_bsr=True, 13263 supports_sparse_bsc=True, 13264 skips=( 13265 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 13266 dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), 13267 # 76047 13268 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', 13269 dtypes=(torch.bfloat16, torch.float32, torch.float64)), 13270 )), 13271 OpInfo('stft', 13272 decorators=[ 13273 skipCPUIfNoFFT, 13274 DecorateInfo(unittest.skip("Skipped! stft does not match the native function"), 13275 'TestJit', 'test_variant_consistency_jit'), 13276 ], 13277 dtypes=floating_and_complex_types(), 13278 sample_inputs_func=sample_inputs_stft, 13279 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13280 gradcheck_fast_mode=True, 13281 supports_forward_ad=True, 13282 supports_fwgrad_bwgrad=True, 13283 check_batched_forward_grad=False, 13284 check_batched_grad=False, 13285 check_batched_gradgrad=False, 13286 supports_out=False, 13287 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 13288 ), 13289 OpInfo('istft', 13290 dtypes=complex_types(), 13291 sample_inputs_func=sample_inputs_istft, 13292 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13293 gradcheck_fast_mode=True, 13294 supports_forward_ad=True, 13295 supports_fwgrad_bwgrad=True, 13296 check_batched_forward_grad=False, 13297 check_batched_grad=False, 13298 check_batched_gradgrad=False, 13299 supports_out=False, 13300 decorators=( 13301 DecorateInfo(unittest.skip("Skipped! 
istft does not match the native function"), 13302 'TestJit', 'test_variant_consistency_jit'), 13303 ), 13304 skips=( 13305 skipCPUIfNoFFT, 13306 # gradcheck fails on ROCm (gh-68429) 13307 # grad is computed improperly (probably for weights tensor) 13308 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), 13309 # Pre-existing condition (calls .item); needs to be fixed 13310 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), 13311 )), 13312 UnaryUfuncInfo('floor', 13313 ref=np.floor, 13314 dtypes=all_types_and(torch.half, torch.bfloat16), 13315 supports_forward_ad=True, 13316 supports_fwgrad_bwgrad=True, 13317 skips=( 13318 DecorateInfo(unittest.expectedFailure, 13319 'TestNNCOpInfo', 13320 'test_nnc_correctness', 13321 dtypes=tuple(t for t in integral_types() if t != torch.uint8)), 13322 ), 13323 supports_sparse=True, 13324 supports_sparse_csr=True, 13325 supports_sparse_csc=True, 13326 supports_sparse_bsr=True, 13327 supports_sparse_bsc=True, 13328 assert_autodiffed=True), 13329 OpInfo('flip', 13330 op=torch.flip, 13331 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13332 sample_inputs_func=sample_inputs_flip, 13333 supports_forward_ad=True, 13334 supports_fwgrad_bwgrad=True, 13335 supports_out=False), 13336 OpInfo('fliplr', 13337 op=torch.fliplr, 13338 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13339 sample_inputs_func=sample_inputs_fliplr_flipud, 13340 error_inputs_func=error_inputs_fliplr, 13341 supports_forward_ad=True, 13342 supports_fwgrad_bwgrad=True, 13343 supports_out=False), 13344 OpInfo('flipud', 13345 op=torch.flipud, 13346 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13347 sample_inputs_func=sample_inputs_fliplr_flipud, 13348 error_inputs_func=error_inputs_flipud, 13349 supports_forward_ad=True, 13350 supports_fwgrad_bwgrad=True, 13351 supports_out=False), 13352 OpInfo('sparse.sampled_addmm', 13353 dtypes=floating_and_complex_types(), 13354 supports_autograd=True, 13355 sample_inputs_func=sample_inputs_sparse_sampled_addmm, 13356 decorators=[ 13357 skipCUDAIf(not ((_get_torch_cuda_version() >= (11, 3)) 13358 or (_get_torch_rocm_version() >= (5, 2))), 13359 "cusparseSDDMM was added in 11.2.1"), 13360 skipCPUIfNoMklSparse, ], 13361 skips=( 13362 # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous 13363 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 13364 # RuntimeError: Sparse CSR tensors do not have strides. 
13365 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), 13366 DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), 13367 # RuntimeError: sampled_addmm: Expected result to have sparse csr layout, but got Strided 13368 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'), 13369 # RuntimeError: Sparse CSR tensors do not have strides 13370 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 13371 # RuntimeError: Sparse CSR tensors do not have strides 13372 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), 13373 # RuntimeError: Sparse CSR tensors do not have strides 13374 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'), 13375 # RuntimeError: Sparse CSR tensors do not have strides 13376 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 13377 # RuntimeError: Sparse CSR tensors do not have strides 13378 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 13379 # RuntimeError: Sparse CSR tensors do not have strides 13380 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 13381 # RuntimeError: Sparse CSR tensors do not have strides 13382 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 13383 # RuntimeError: unsupported memory format option Preserve 13384 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 13385 # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype 13386 # RuntimeError: Sparse CSR tensors do not have strides 13387 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 13388 # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... 13389 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), 13390 # RuntimeError: sparse_mask does not support automatic differentiation for outputs with complex dtype. 13391 # RuntimeError: Sparse CSR tensors do not have is_contiguous 13392 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), 13393 # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ... 13394 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 13395 # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend. 
            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
        )),
    OpInfo('sparse.mm',
           dtypes=floating_types_and(torch.bfloat16, torch.float16),
           variant_test_name='reduce',
           supports_autograd=True,
           supports_out=False,
           supports_gradgrad=False,
           supports_forward_ad=False,
           sample_inputs_func=sample_inputs_sparse_mm_reduce,
           decorators=[onlyCPU],
           skips=(
               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
               # RuntimeError: Sparse CSR tensors do not have strides.
               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # RuntimeError: unsupported memory format option Preserve
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
               # RuntimeError: Sparse CSR tensors do not have is_contiguous
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
               # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
               # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
           )),
    UnaryUfuncInfo('i0',
                   ref=np_unary_ufunc_integer_promotion_wrapper(
                       scipy.special.i0) if TEST_SCIPY else None,
                   aliases=('special.i0',),
                   decorators=(precisionOverride({torch.bfloat16: 3e-1,
                                                  torch.float16: 5e-1}),),
                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                   backward_dtypes=floating_types(),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   promotes_int_to_float=True,
                   sample_inputs_func=sample_inputs_i0_i1,
                   skips=(
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=(torch.int8,)),
                   )),
    BinaryUfuncInfo('floor_divide',
                    ref=_floor_divide_np,
                    dtypes=all_types_and(torch.half, torch.bfloat16),
                    supports_autograd=False,
                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
                    supports_two_python_scalars=True,
                    skips=(
                        # AssertionError: Results of original model and exported/imported version of model differed
                        DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
                        # bfloat16 floor_divide compared with a float32 reference works inconsistently
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
                                     dtypes=(torch.bfloat16,)),
                        # int8 floor divide has different results for -128 // -1 vs. NumPy
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_small_values',
                                     dtypes=(torch.int8,)),
                        # The following tests fail on some jobs
                        DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values',
                                     dtypes=(torch.float16,)),
                        DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
                                     'TestBinaryUfuncs', 'test_reference_numerics'),
                    )),
    UnaryUfuncInfo('frexp',
                   op=torch.frexp,
                   ref=np.frexp,
                   dtypes=floating_types_and(torch.half, torch.bfloat16),
                   decorators=[],
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   skips=(
                       # Skips the tests below because torch.frexp returns a tuple-like (mantissa, exponent) output,
                       # while these tests currently require the output to be a single tensor.
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency'),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),

                       # Skips test_reference_numerics due to an error in Windows CI.
                       # np.frexp returns the exponent as np.intc dtype on Windows,
                       # and np.intc does not have a corresponding torch dtype
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                    active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    active_if=IS_WINDOWS),
                   )),
    UnaryUfuncInfo('log1p',
                   ref=np.log1p,
                   aliases=('special.log1p',),
                   domain=(-1, None),
                   dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   supports_sparse=True,
                   supports_sparse_csr=True,
                   supports_sparse_csc=True,
                   supports_sparse_bsr=True,
                   supports_sparse_bsc=True,
                   assert_autodiffed=True,
                   promotes_int_to_float=True),
    BinaryUfuncInfo('ge',
                    ref=np.greater_equal,
                    aliases=('greater_equal',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    OpInfo('geqrf',
           dtypes=floating_and_complex_types(),
           sample_inputs_func=sample_inputs_linalg_qr_geqrf,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
           supports_autograd=False,
           skips=(
               # FIXME: geqrf can't forward with complex inputs that require grad
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'),
               # Strides are not the same!
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
           )),
    BinaryUfuncInfo('gt',
                    ref=np.greater,
                    aliases=('greater',),
                    dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
                    always_returns_bool=True,
                    supports_autograd=False,
                    skips=(
                    )),
    UnaryUfuncInfo('imag',
                   ref=np.imag,
                   dtypes=complex_types_and(torch.chalf),
                   supports_out=False,
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   # See https://github.com/pytorch/pytorch/issues/66357
                   # RuntimeError: view_as_real doesn't work on unresolved conjugated tensors.
                   check_batched_forward_grad=False,
                   skips=(
                       # Skip since real and imag don't have out variants.
                       DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
                   )),
    OpInfo('gradient',
           dtypes=floating_and_complex_types_and(torch.int8, torch.int16,
                                                 torch.int32, torch.int64,
                                                 torch.bfloat16, torch.half),
           supports_out=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           # See https://github.com/pytorch/pytorch/pull/78358
           check_batched_forward_grad=False,
           skips=(
               DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
               # The following tests give a runtime error with an undefined value tensor;
               # see discussion: https://github.com/pytorch/pytorch/issues/56660
               # RuntimeError:
               # Arguments for call are not valid.
13581 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950 13582 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 13583 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 13584 ), 13585 supports_inplace_autograd=False, 13586 sample_inputs_func=sample_inputs_gradient, 13587 error_inputs_func=error_inputs_gradient), 13588 OpInfo('isin', 13589 dtypes=all_types(), 13590 dtypesIfCUDA=all_types_and(torch.half), 13591 supports_autograd=False, 13592 sample_inputs_func=sample_inputs_isin), 13593 OpInfo('kthvalue', 13594 dtypes=all_types_and(torch.bfloat16, torch.float16), 13595 supports_forward_ad=True, 13596 supports_fwgrad_bwgrad=True, 13597 sample_inputs_func=sample_inputs_kthvalue, 13598 error_inputs_func=error_inputs_kthvalue), 13599 BinaryUfuncInfo('le', 13600 ref=np.less_equal, 13601 aliases=('less_equal',), 13602 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), 13603 always_returns_bool=True, 13604 supports_autograd=False, 13605 skips=( 13606 )), 13607 OpInfo('linspace', 13608 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 13609 is_factory_function=True, 13610 supports_out=True, 13611 supports_autograd=False, 13612 error_inputs_func=error_inputs_linspace, 13613 sample_inputs_func=sample_inputs_linspace, 13614 skips=( 13615 # FX failed to normalize op - add the op to the op_skip list. 13616 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 13617 # Tests that assume input is a tensor or sequence of tensors 13618 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 13619 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 13620 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 13621 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 13622 13623 # Same failure as arange: cannot find linspace in captured graph 13624 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 13625 13626 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 13627 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 13628 # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API 13629 # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! 13630 # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. 13631 # CUDA driver allocated memory was 1254555648 and is now 1242955776. 13632 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 13633 dtypes=(torch.cfloat,), device_type="cuda"), 13634 )), 13635 OpInfo('linspace', 13636 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 13637 is_factory_function=True, 13638 supports_out=True, 13639 supports_autograd=False, 13640 error_inputs_func=error_inputs_linspace, 13641 sample_inputs_func=sample_inputs_linspace_tensor_overload, 13642 variant_test_name="tensor_overload", 13643 skips=( 13644 # FX failed to normalize op - add the op to the op_skip list. 
13645 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 13646 # TypeError: 'int' object is not subscriptable 13647 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 13648 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 13649 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 13650 13651 # Same failure as arange: cannot find linspace in captured graph 13652 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 13653 13654 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 13655 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 13656 # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API 13657 # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! 13658 # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. 13659 # CUDA driver allocated memory was 1254555648 and is now 1242955776. 13660 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 13661 dtypes=(torch.cfloat,), device_type="cuda"), 13662 )), 13663 OpInfo('logspace', 13664 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 13665 is_factory_function=True, 13666 supports_out=True, 13667 supports_autograd=False, 13668 error_inputs_func=error_inputs_linspace, 13669 sample_inputs_func=sample_inputs_logspace, 13670 skips=( 13671 # FX failed to normalize op - add the op to the op_skip list. 13672 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 13673 # Tests that assume input is a tensor or sequence of tensors 13674 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 13675 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 13676 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 13677 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 13678 # Same failure as arange: cannot find linspace in captured graph 13679 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 13680 13681 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 13682 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 13683 13684 # Off-by-one issue when casting floats to ints 13685 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', 13686 dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), 13687 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', 13688 dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), 13689 # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API 13690 # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! 13691 # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. 13692 # CUDA driver allocated memory was 1254555648 and is now 1242955776. 
13693 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 13694 dtypes=(torch.cfloat,), device_type="cuda"), 13695 )), 13696 OpInfo('logspace', 13697 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 13698 is_factory_function=True, 13699 supports_out=True, 13700 supports_autograd=False, 13701 error_inputs_func=error_inputs_linspace, 13702 sample_inputs_func=sample_inputs_logspace_tensor_overload, 13703 variant_test_name="tensor_overload", 13704 skips=( 13705 # FX failed to normalize op - add the op to the op_skip list. 13706 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 13707 # TypeError: 'int' object is not subscriptable 13708 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 13709 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 13710 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 13711 # Same failure as arange: cannot find linspace in captured graph 13712 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 13713 13714 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 13715 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 13716 13717 # Off-by-one issue when casting floats to ints 13718 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', 13719 dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), 13720 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', 13721 dtypes=(torch.int16, torch.int32, torch.int64), device_type="cuda"), 13722 # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API 13723 # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! 13724 # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. 13725 # CUDA driver allocated memory was 1254555648 and is now 1242955776. 
13726 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 13727 dtypes=(torch.cfloat,), device_type="cuda"), 13728 )), 13729 UnaryUfuncInfo('log', 13730 ref=np.log, 13731 domain=(0, None), 13732 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13733 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 13734 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.chalf), 13735 assert_autodiffed=True, 13736 supports_forward_ad=True, 13737 supports_fwgrad_bwgrad=True, 13738 promotes_int_to_float=True, 13739 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 13740 skips=( 13741 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 13742 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 13743 active_if=IS_WINDOWS), 13744 ), 13745 # log(z)->-inf for |z|->0 13746 reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), 13747 UnaryUfuncInfo('log10', 13748 ref=np.log10, 13749 domain=(0, None), 13750 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 13751 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13752 assert_autodiffed=True, 13753 supports_forward_ad=True, 13754 supports_fwgrad_bwgrad=True, 13755 promotes_int_to_float=True, 13756 skips=( 13757 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 13758 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 13759 active_if=IS_WINDOWS), 13760 ), 13761 # log10(z)->-inf for |z|->0 13762 reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), 13763 UnaryUfuncInfo('log2', 13764 ref=np.log2, 13765 domain=(0, None), 13766 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13767 assert_autodiffed=True, 13768 supports_forward_ad=True, 13769 supports_fwgrad_bwgrad=True, 13770 promotes_int_to_float=True, 13771 decorators=(precisionOverride({torch.bfloat16: 1e-1}),), 13772 skips=( 13773 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 13774 dtypes=[torch.cfloat, torch.cdouble]), 13775 ), 13776 # log2(z)->-inf for |z|->0 13777 reference_numerics_filter=NumericsFilter(condition=lambda x: torch.abs(x) < 0.1, safe_val=1)), 13778 BinaryUfuncInfo('ldexp', 13779 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13780 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13781 gradcheck_fast_mode=True, 13782 supports_forward_ad=True, 13783 supports_fwgrad_bwgrad=True, 13784 supports_inplace_autograd=False, 13785 promotes_int_to_float=True, 13786 supports_out=True, 13787 supports_rhs_python_scalar=False, 13788 skips=( 13789 # RuntimeError: mul(): functions with out=... 
arguments don't support 13790 # automatic differentiation, but one of the arguments requires grad 13791 # https://github.com/pytorch/pytorch/issues/68966 13792 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 13793 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 13794 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 13795 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 13796 ), 13797 decorators=[ 13798 DecorateInfo( 13799 toleranceOverride({ 13800 torch.complex64: tol(atol=1e-05, rtol=1e-05) 13801 }), 13802 'TestCommon', device_type='cpu', 13803 ), 13804 ], ), 13805 BinaryUfuncInfo('logaddexp', 13806 dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), 13807 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 13808 supports_forward_ad=True, 13809 supports_fwgrad_bwgrad=True, 13810 supports_rhs_python_scalar=False, 13811 skips=( 13812 # TODO: FIXME: RuntimeError: not implemented for 'ComplexFloat' 13813 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), 13814 )), 13815 OpInfo('logaddexp2', 13816 dtypes=floating_types_and(torch.bfloat16, torch.half), 13817 supports_forward_ad=True, 13818 supports_fwgrad_bwgrad=True, 13819 sample_inputs_func=sample_inputs_logaddexp), 13820 UnaryUfuncInfo('logical_not', 13821 ref=np.logical_not, 13822 decorators=(precisionOverride({torch.bfloat16: 7e-1, 13823 torch.float16: 5e-1}),), 13824 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13825 supports_autograd=False, 13826 skips=( 13827 # The function variant always returns BoolTensor 13828 # while the inplace variant preserves the input dtype. 13829 # >>> t = torch.randn(3) 13830 # >>> torch.logical_not(t) 13831 # tensor([False, False, False]) 13832 # >>> torch.logical_not(t).dtype 13833 # torch.bool 13834 # >>> t.logical_not_().dtype 13835 # torch.float32 13836 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_variant_consistency', 13837 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), 13838 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 13839 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16)), 13840 )), 13841 BinaryUfuncInfo('lt', 13842 ref=np.less, 13843 aliases=('less',), 13844 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), 13845 always_returns_bool=True, 13846 supports_autograd=False, 13847 skips=( 13848 )), 13849 OpInfo('lu_unpack', 13850 op=torch.lu_unpack, 13851 dtypes=floating_and_complex_types(), 13852 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13853 gradcheck_fast_mode=True, 13854 supports_forward_ad=True, 13855 supports_fwgrad_bwgrad=True, 13856 skips=(skipCPUIfNoLapack,), 13857 sample_inputs_func=sample_inputs_lu_unpack), 13858 OpInfo('lu', 13859 op=torch.lu, 13860 dtypes=floating_and_complex_types(), 13861 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13862 gradcheck_fast_mode=True, 13863 supports_forward_ad=True, 13864 supports_fwgrad_bwgrad=True, 13865 # https://github.com/pytorch/pytorch/issues/66357 13866 check_batched_forward_grad=False, 13867 sample_inputs_func=sample_inputs_lu, 13868 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 13869 skips=( 13870 # we skip jit tests because `lu` is a torch function 13871 # RuntimeError: 13872 # 'Tensor (inferred)' object has no attribute or method 'lu'.: 13873 # 
File "<string>", line 3 13874 # def the_method(i0): 13875 # return i0.lu(True, True) 13876 # ~~~~~ <--- HERE 13877 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 13878 # RuntimeError not raised: Expected RuntimeError when calling with input.device=cpu and out.device=cuda 13879 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 13880 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 13881 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 13882 )), 13883 OpInfo('lu_solve', 13884 op=torch.lu_solve, 13885 dtypes=floating_and_complex_types(), 13886 supports_forward_ad=True, 13887 # See https://github.com/pytorch/pytorch/issues/66357 13888 check_batched_forward_grad=False, 13889 supports_fwgrad_bwgrad=True, 13890 sample_inputs_func=sample_inputs_lu_solve, 13891 skips=( 13892 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', 13893 device_type='mps', dtypes=[torch.float32]), 13894 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 13895 device_type='mps', dtypes=[torch.float32]), 13896 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 13897 device_type='mps', dtypes=[torch.float32]), 13898 DecorateInfo(unittest.skip("Tests different backward paths"), 13899 "TestCommon", "test_floating_inputs_are_differentiable"),), 13900 decorators=[skipCPUIfNoLapack, skipCUDAIfNoMagmaAndNoCusolver]), 13901 OpInfo('masked_fill', 13902 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 13903 sample_inputs_func=sample_inputs_masked_fill, 13904 error_inputs_func=error_inputs_masked_fill, 13905 supports_forward_ad=True, 13906 supports_fwgrad_bwgrad=True, 13907 check_batched_forward_grad=False, 13908 supports_out=False), 13909 OpInfo('masked_scatter', 13910 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13911 sample_inputs_func=sample_inputs_masked_scatter, 13912 error_inputs_func=error_inputs_masked_scatter, 13913 supports_forward_ad=True, 13914 supports_fwgrad_bwgrad=True, 13915 # https://github.com/pytorch/pytorch/issues/66357 13916 check_batched_forward_grad=False, 13917 supports_out=False, 13918 skips=( 13919 )), 13920 OpInfo('masked_select', 13921 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 13922 supports_forward_ad=True, 13923 supports_fwgrad_bwgrad=True, 13924 sample_inputs_func=sample_inputs_masked_select, 13925 error_inputs_func=error_inputs_masked_select, 13926 skips=( 13927 # Compiler issue on ROCm. 
Might need to skip until ROCm5.5 13928 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', 13929 dtypes=[torch.bool], active_if=TEST_WITH_ROCM), 13930 )), 13931 OpInfo('matrix_exp', 13932 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 13933 aliases=('linalg.matrix_exp',), 13934 sample_inputs_func=sample_inputs_matrix_exp, 13935 # Needs to construct a 2nx2n matrix by copy_ ing into it 13936 check_batched_grad=False, 13937 check_batched_gradgrad=False, 13938 supports_forward_ad=True, 13939 supports_fwgrad_bwgrad=True, 13940 # https://github.com/pytorch/pytorch/issues/66357 13941 check_batched_forward_grad=False, 13942 skips=( 13943 # mexp does not support bf16 and fp16 13944 DecorateInfo(unittest.skip('Skipped!'), 'TestInductorOpInfo', 'test_comprehensive', 13945 dtypes=[torch.half], device_type="cpu"), 13946 ), 13947 supports_out=False, 13948 ), 13949 OpInfo('matmul', 13950 aliases=('linalg.matmul',), 13951 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 13952 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, 13953 *[torch.bfloat16] 13954 if SM53OrLater or TEST_WITH_ROCM else []), 13955 assert_autodiffed=True, 13956 assert_jit_shape_analysis=True, 13957 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 13958 gradcheck_fast_mode=True, 13959 supports_forward_ad=True, 13960 supports_fwgrad_bwgrad=True, 13961 check_batched_forward_grad=False, 13962 sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=False), 13963 decorators=[ 13964 # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 13965 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), 13966 # ROCm intermittently fails the test with standard atol/rtol 13967 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), 13968 'TestCommon', 'test_noncontiguous_samples', device_type='cuda', 13969 active_if=TEST_WITH_ROCM), 13970 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=0)}), 13971 'TestCommon', 'test_out', device_type='cuda', 13972 active_if=TEST_WITH_ROCM), 13973 # mv for the sample with shapes (S, S, M, M), (M,) has some variance in the 13974 # backward on CPU 13975 DecorateInfo(toleranceOverride({torch.float32: tol(atol=0, rtol=1e-5)}), 13976 'TestCommon', 'test_noncontiguous_samples', 13977 device_type='cpu'), 13978 DecorateInfo( 13979 toleranceOverride({ 13980 torch.float32: tol(atol=1e-5, rtol=1e-5), 13981 torch.complex64: tol(atol=1e-5, rtol=1e-5), 13982 }), 13983 "TestDecomp", "test_comprehensive", device_type="cuda", 13984 ), 13985 ], 13986 skips=( 13987 # Strides are not the same! 13988 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 13989 # https://github.com/pytorch/pytorch/issues/67470 13990 DecorateInfo(unittest.skip("67470!"), 13991 'TestCommon', 'test_noncontiguous_samples', 13992 device_type='cpu', dtypes=(torch.long,)), 13993 # AssertionError: False is not true : Tensors failed to compare as equal! 
13994 DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', 13995 device_type='xla', dtypes=(torch.long,)), 13996 # https://github.com/pytorch/pytorch/issues/71774 13997 DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', 13998 device_type='cpu', dtypes=(torch.long,)), 13999 )), 14000 OpInfo('max', 14001 variant_test_name='reduction_with_dim', 14002 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14003 sample_inputs_func=sample_inputs_max_min_reduction_with_dim, 14004 supports_fwgrad_bwgrad=True, 14005 skips=( 14006 ), 14007 supports_forward_ad=True), 14008 OpInfo('max', 14009 variant_test_name='reduction_no_dim', 14010 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14011 supports_out=True, 14012 supports_forward_ad=True, 14013 supports_fwgrad_bwgrad=True, 14014 sample_inputs_func=sample_inputs_max_min_reduction_no_dim, 14015 skips=( 14016 )), 14017 OpInfo('median', 14018 dtypes=all_types_and(torch.bfloat16, torch.float16), 14019 # TODO: some signatures of median do support out 14020 supports_out=False, 14021 supports_forward_ad=True, 14022 supports_fwgrad_bwgrad=True, 14023 error_inputs_func=error_inputs_median, 14024 sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), 14025 OpInfo('nanmedian', 14026 dtypes=all_types_and(torch.bfloat16, torch.float16), 14027 # TODO: some signatures of nanmedian do support out 14028 supports_out=False, 14029 supports_forward_ad=True, 14030 supports_fwgrad_bwgrad=True, 14031 sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)), 14032 OpInfo('var_mean', 14033 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 14034 sample_inputs_func=sample_inputs_std_var, 14035 # TODO: some signatures of var_mean do support out 14036 supports_out=False, 14037 supports_forward_ad=True, 14038 check_batched_forward_grad=False, 14039 supports_fwgrad_bwgrad=True, 14040 decorators=( 14041 DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), 14042 "TestDecomp", "test_comprehensive", device_type="cuda"), 14043 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), 14044 "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), 14045 )), 14046 OpInfo('var_mean', 14047 variant_test_name='unbiased', 14048 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 14049 sample_inputs_func=sample_inputs_std_var_unbiased, 14050 # TODO: some signatures of var_mean do support out 14051 supports_out=False, 14052 supports_forward_ad=True, 14053 check_batched_forward_grad=False, 14054 supports_fwgrad_bwgrad=True, 14055 decorators=( 14056 DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), 14057 "TestDecomp", "test_comprehensive", device_type="cuda"), 14058 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), 14059 "TestInductorOpInfo", "test_comprehensive", device_type="cuda"), 14060 )), 14061 OpInfo('std_mean', 14062 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 14063 sample_inputs_func=sample_inputs_std_var, 14064 # TODO: some signatures of std_mean do support out 14065 supports_out=False, 14066 supports_forward_ad=True, 14067 check_batched_forward_grad=False, 14068 supports_fwgrad_bwgrad=True, 14069 decorators=( 14070 DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}), 14071 "TestDecomp", "test_comprehensive", device_type="cuda"), 14072 )), 14073 OpInfo('std_mean', 14074 variant_test_name='unbiased', 
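           # Illustrative sketch (comment only, not run by the suite): the 'unbiased' variants
           # (via sample_inputs_std_var_unbiased) are meant to target the legacy boolean overload,
           # where unbiased=True applies Bessel's correction (denominator n - 1) and
           # unbiased=False divides by n. Values below are indicative.
           # >>> x = torch.tensor([1., 2., 3.])
           # >>> torch.std_mean(x, True)    # (std, mean) with n - 1 in the denominator
           # (tensor(1.), tensor(2.))
           # >>> torch.std_mean(x, False)   # population std, denominator n
           # (tensor(0.8165), tensor(2.))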
14075 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 14076 sample_inputs_func=sample_inputs_std_var_unbiased, 14077 # TODO: some signatures of var_mean do support out 14078 supports_out=False, 14079 supports_forward_ad=True, 14080 check_batched_forward_grad=False, 14081 supports_fwgrad_bwgrad=True, 14082 decorators=( 14083 DecorateInfo( 14084 toleranceOverride({ 14085 torch.float16: tol(atol=4e-5, rtol=9e-3), 14086 torch.float64: tol(atol=2e-7, rtol=2e-7), 14087 }), 14088 "TestDecomp", 14089 "test_comprehensive", 14090 device_type="cuda" 14091 ), 14092 DecorateInfo( 14093 toleranceOverride({ 14094 torch.float16: tol(atol=4e-5, rtol=9e-3), 14095 torch.float64: tol(atol=2e-7, rtol=2e-7), 14096 }), 14097 "TestInductorOpInfo", 14098 "test_comprehensive", 14099 device_type="cuda" 14100 ), 14101 )), 14102 OpInfo('meshgrid', 14103 variant_test_name='variadic_tensors', 14104 ref=np.meshgrid, 14105 dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), 14106 sample_inputs_func=partial(sample_inputs_meshgrid, variant='variadic'), 14107 skips=[ 14108 # JIT does not support variadic tensors. 14109 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 14110 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, 14111 # please report a bug to PyTorch. 14112 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 14113 # meshgrid is defined in torch.functional to take a 14114 # variadic list of tensors. Variadic parameters are not 14115 # compatible with the normalize operator tests. 14116 DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 14117 # Skip operator schema test because this is a functional and not an operator 14118 DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 14119 ], 14120 supports_out=False, 14121 supports_fwgrad_bwgrad=True, 14122 supports_forward_ad=True, 14123 # See https://github.com/pytorch/pytorch/pull/78358 14124 check_batched_forward_grad=False,), 14125 OpInfo('meshgrid', 14126 variant_test_name='list_of_tensors', 14127 # Unlike the variant above, we do not use np.meshgrid as a 14128 # ref since it does not officially support list of numpy 14129 # arrays. 14130 dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.float16), 14131 sample_inputs_func=partial(sample_inputs_meshgrid, variant='list'), 14132 skips=[ 14133 # meshgrid is defined in torch.functional to take a 14134 # variadic list of tensors. Variadic parameters are not 14135 # compatible with the normalize operator tests. 
14136 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 14137 ], 14138 assert_autodiffed=True, 14139 supports_out=False, 14140 autodiff_nonfusible_nodes=[], 14141 supports_fwgrad_bwgrad=True, 14142 supports_forward_ad=True, 14143 # See https://github.com/pytorch/pytorch/pull/78358 14144 check_batched_forward_grad=False,), 14145 OpInfo('min', 14146 variant_test_name='reduction_with_dim', 14147 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14148 sample_inputs_func=sample_inputs_max_min_reduction_with_dim, 14149 supports_fwgrad_bwgrad=True, 14150 supports_forward_ad=True, 14151 skips=( 14152 )), 14153 OpInfo('min', 14154 variant_test_name='reduction_no_dim', 14155 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14156 supports_out=True, 14157 supports_forward_ad=True, 14158 supports_fwgrad_bwgrad=True, 14159 sample_inputs_func=sample_inputs_max_min_reduction_no_dim, 14160 skips=( 14161 )), 14162 OpInfo('quantile', 14163 dtypes=floating_types(), 14164 sample_inputs_func=sample_inputs_reduction_quantile, 14165 supports_forward_ad=True, 14166 supports_fwgrad_bwgrad=True, 14167 # See https://github.com/pytorch/pytorch/issues/66357 14168 # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which 14169 # does not have a batching rule in core 14170 check_batched_forward_grad=False), 14171 OpInfo('nanquantile', 14172 dtypes=floating_types(), 14173 sample_inputs_func=sample_inputs_reduction_quantile, 14174 supports_forward_ad=True, 14175 supports_fwgrad_bwgrad=True, 14176 # See https://github.com/pytorch/pytorch/issues/66357 14177 # Relies on copy_ to broadcast, but the forward AD path calls broadcast_to which 14178 # does not have a batching rule in core 14179 check_batched_forward_grad=False), 14180 BinaryUfuncInfo( 14181 'max', 14182 aliases=('maximum',), 14183 variant_test_name='binary', 14184 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14185 supports_forward_ad=True, 14186 supports_fwgrad_bwgrad=True, 14187 assert_autodiffed=True, 14188 ref=np.maximum, 14189 supports_rhs_python_scalar=False, 14190 skips=( 14191 # Incorrectly attempts to use a scalar for the second argument 14192 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), 14193 # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' 14194 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), 14195 )), 14196 BinaryUfuncInfo( 14197 'maximum', 14198 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14199 supports_forward_ad=True, 14200 supports_fwgrad_bwgrad=True, 14201 ref=np.maximum, 14202 supports_rhs_python_scalar=False, 14203 skips=( 14204 # TODO: FIXME: RuntimeError: "max_elementwise_cuda" not implemented for 'ComplexFloat' 14205 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion', device_type='cuda'), 14206 )), 14207 BinaryUfuncInfo( 14208 'min', 14209 aliases=('minimum',), 14210 variant_test_name='binary', 14211 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14212 supports_forward_ad=True, 14213 supports_fwgrad_bwgrad=True, 14214 assert_autodiffed=True, 14215 ref=np.minimum, 14216 supports_rhs_python_scalar=False, 14217 skips=( 14218 # Incorrectly attempts to use a scalar for the second argument 14219 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), 14220 # TODO: FIXME: RuntimeError: "min_elementwise_cuda" 
not implemented for 'ComplexFloat' 14221 DecorateInfo(unittest.expectedFailure, 14222 'TestBinaryUfuncs', 14223 'test_type_promotion', 14224 device_type='cuda'), 14225 )), 14226 BinaryUfuncInfo( 14227 'minimum', 14228 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 14229 supports_forward_ad=True, 14230 supports_fwgrad_bwgrad=True, 14231 ref=np.minimum, 14232 supports_rhs_python_scalar=False, 14233 skips=( 14234 # TODO: FIXME: RuntimeError: "min_elementwise_cuda" not implemented for 'ComplexFloat' 14235 DecorateInfo(unittest.expectedFailure, 14236 'TestBinaryUfuncs', 14237 'test_type_promotion', 14238 device_type='cuda'), 14239 ), 14240 ), 14241 BinaryUfuncInfo('logical_and', 14242 ref=np.logical_and, 14243 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 14244 supports_autograd=False, 14245 always_returns_bool=True, 14246 supports_rhs_python_scalar=False), 14247 BinaryUfuncInfo('logical_or', 14248 ref=np.logical_or, 14249 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 14250 supports_autograd=False, 14251 always_returns_bool=True, 14252 supports_rhs_python_scalar=False), 14253 BinaryUfuncInfo('logical_xor', 14254 ref=np.logical_xor, 14255 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 14256 supports_autograd=False, 14257 always_returns_bool=True, 14258 supports_rhs_python_scalar=False, 14259 skips=( 14260 )), 14261 BinaryUfuncInfo('bitwise_and', 14262 ref=np.bitwise_and, 14263 dtypes=integral_types_and(torch.bool), 14264 operator_variant=operator.and_, 14265 inplace_operator_variant=operator.iand, 14266 supports_autograd=False, 14267 supports_one_python_scalar=True, 14268 skips=( 14269 # RuntimeError: "bitwise_and_cuda" not implemented for 'Half' 14270 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 14271 'test_type_promotion', device_type='cuda'), 14272 )), 14273 BinaryUfuncInfo('bitwise_or', 14274 ref=np.bitwise_or, 14275 dtypes=integral_types_and(torch.bool), 14276 operator_variant=operator.or_, 14277 inplace_operator_variant=operator.ior, 14278 supports_autograd=False, 14279 supports_one_python_scalar=True, 14280 skips=( 14281 # TODO: FIXME: RuntimeError: "bitwise_or_cuda" not implemented for 'Half' 14282 DecorateInfo(unittest.expectedFailure, 14283 'TestBinaryUfuncs', 14284 'test_type_promotion', 14285 device_type='cuda'), 14286 )), 14287 BinaryUfuncInfo('bitwise_xor', 14288 ref=np.bitwise_xor, 14289 dtypes=integral_types_and(torch.bool), 14290 operator_variant=operator.xor, 14291 inplace_operator_variant=operator.ixor, 14292 supports_autograd=False, 14293 supports_one_python_scalar=True, 14294 skips=( 14295 # TODO: FIXME: RuntimeError: "bitwise_xor_cuda" not implemented for 'Half' 14296 DecorateInfo(unittest.expectedFailure, 14297 'TestBinaryUfuncs', 14298 'test_type_promotion', 14299 device_type='cuda'), 14300 )), 14301 BinaryUfuncInfo('heaviside', 14302 ref=lambda a, b: ( 14303 # necessary because np.heaviside incorrectly returns float64 when passed args of dtype int64 14304 np.int64(np.heaviside(a, b)) if a.dtype == np.int64 and b.dtype == np.int64 else np.heaviside(a, b) 14305 ), 14306 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 14307 supports_autograd=False, 14308 supports_rhs_python_scalar=False, 14309 skips=( 14310 # RuntimeError: heaviside is not yet implemented for tensors with different dtypes. 
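# Context for the `ref` above: per the comment there, np.heaviside promotes int64 inputs
# to float64 while torch.heaviside keeps the integer dtype, roughly
#   np.heaviside(np.array([1], dtype=np.int64), np.array([0], dtype=np.int64)).dtype  # -> float64
# hence the explicit np.int64(...) cast in the reference.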
14311 DecorateInfo(unittest.expectedFailure, 14312 'TestBinaryUfuncs', 14313 'test_type_promotion'), 14314 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 14315 # PyTorch's heaviside does not appear to propagate NaNs 14316 DecorateInfo(unittest.skip("Skipped!"), 14317 'TestBinaryUfuncs', 14318 'test_reference_numerics_extremal_values'), 14319 )), 14320 BinaryUfuncInfo('lcm', 14321 ref=np.lcm, 14322 dtypes=integral_types_and(), 14323 supports_autograd=False, 14324 supports_rhs_python_scalar=False), 14325 BinaryUfuncInfo('gcd', 14326 ref=np.gcd, 14327 dtypes=integral_types_and(), 14328 supports_autograd=False, 14329 supports_rhs_python_scalar=False, 14330 skips=( 14331 DecorateInfo(unittest.expectedFailure, 14332 'TestBinaryUfuncs', 14333 'test_reference_numerics_small_values', 14334 dtypes=(torch.int8,)),)), 14335 BinaryUfuncInfo('isclose', 14336 ref=np.isclose, 14337 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 14338 sample_inputs_func=sample_inputs_isclose, 14339 error_inputs_func=error_inputs_isclose, 14340 supports_autograd=False, 14341 supports_out=False, 14342 supports_rhs_python_scalar=False, 14343 skips=( 14344 DecorateInfo(unittest.expectedFailure, 14345 'TestCommon', 14346 'test_numpy_refs', dtypes=(torch.complex128,)), 14347 # RuntimeError: Short did not match Int 14348 DecorateInfo(unittest.expectedFailure, 14349 'TestBinaryUfuncs', 14350 'test_type_promotion'), 14351 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 14352 DecorateInfo(unittest.skip("Skipped!"), 14353 'TestBinaryUfuncs', 14354 'test_reference_numerics_extremal_values'), 14355 )), 14356 # `softmax` supports different dtypes based on whether `dtype` argument, 14357 # is passed or not. Hence two OpInfo entries, one with dtype and other without. 14358 # https://github.com/pytorch/pytorch/issues/68752 14359 OpInfo('softmax', 14360 aliases=('special.softmax', 'nn.functional.softmax',), 14361 aten_name='softmax', 14362 aten_backward_name='_softmax_backward_data', 14363 dtypes=floating_types_and(torch.half, torch.bfloat16), 14364 sample_inputs_func=sample_inputs_softmax_variant, 14365 assert_jit_shape_analysis=True, 14366 assert_autodiffed=True, 14367 supports_forward_ad=True, 14368 supports_fwgrad_bwgrad=True, 14369 supports_out=True), 14370 OpInfo('softmax', 14371 aliases=('special.softmax', 'nn.functional.softmax',), 14372 variant_test_name="with_dtype", 14373 aten_name='softmax', 14374 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 14375 sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), 14376 assert_autodiffed=True, 14377 supports_forward_ad=True, 14378 supports_fwgrad_bwgrad=True, 14379 supports_out=True), 14380 OpInfo( 14381 '_softmax_backward_data', 14382 op=torch.ops.aten._softmax_backward_data, 14383 aten_name='_softmax_backward_data', 14384 dtypes=floating_types_and(torch.bfloat16, torch.float16), 14385 sample_inputs_func=sample_inputs_softmax_backward_data, 14386 assert_autodiffed=True, 14387 supports_forward_ad=True, 14388 supports_fwgrad_bwgrad=True, 14389 supports_out=False, 14390 skips=( 14391 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), 14392 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 14393 ), 14394 ), 14395 # `softmin` supports different dtypes based on whether `dtype` argument, 14396 # is passed or not. 
Hence two OpInfo entries, one with dtype and other without. 14397 # https://github.com/pytorch/pytorch/issues/68752 14398 OpInfo('nn.functional.softmin', 14399 aten_name='softmin', 14400 dtypes=floating_types_and(torch.half, torch.bfloat16), 14401 sample_inputs_func=sample_inputs_softmax_variant, 14402 assert_jit_shape_analysis=False, 14403 assert_autodiffed=False, 14404 supports_forward_ad=True, 14405 supports_fwgrad_bwgrad=True, 14406 supports_out=False), 14407 OpInfo('nn.functional.softmin', 14408 variant_test_name="with_dtype", 14409 aten_name='softmin', 14410 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 14411 sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), 14412 assert_autodiffed=False, 14413 supports_forward_ad=True, 14414 supports_fwgrad_bwgrad=True, 14415 supports_out=False), 14416 OpInfo( 14417 "nn.functional.cross_entropy", 14418 dtypes=floating_types_and(torch.float16, torch.bfloat16), 14419 sample_inputs_func=sample_inputs_cross_entropy, 14420 supports_out=False, 14421 supports_forward_ad=True, 14422 supports_fwgrad_bwgrad=True, 14423 decorators=( 14424 DecorateInfo( 14425 toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}), 14426 "TestJit", 14427 "test_variant_consistency_jit", 14428 device_type="cpu", 14429 ), 14430 ), 14431 skips=( 14432 # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 1536 14433 # test_ops.TestJitCUDA.test_variant_consistency_jit_nn_functional_cross_entropy_cuda_float32 leaked 14434 # 1536 bytes CUDA memory on device 0 14435 DecorateInfo( 14436 unittest.expectedFailure, 14437 "TestJit", 14438 "test_variant_consistency_jit", 14439 device_type="cuda", 14440 ), 14441 DecorateInfo(unittest.skip("FP16 cross_entropy cases have not been enabled on MPS yet"), 14442 dtypes=(torch.half,), device_type="mps"), 14443 14444 ) 14445 ), 14446 OpInfo('nn.functional.normalize', 14447 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 14448 sample_inputs_func=sample_inputs_normalize, 14449 supports_forward_ad=True, 14450 supports_fwgrad_bwgrad=True), 14451 OpInfo('aminmax', 14452 ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), 14453 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 14454 decorators=(onlyNativeDeviceTypes,), 14455 supports_autograd=False, 14456 sample_inputs_func=sample_inputs_aminmax, 14457 error_inputs_func=error_inputs_aminmax_amax_amin), 14458 OpInfo('as_strided', 14459 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 14460 supports_out=False, 14461 supports_forward_ad=True, 14462 supports_fwgrad_bwgrad=True, 14463 # vmap does not support inplace views 14464 check_inplace_batched_forward_grad=False, 14465 sample_inputs_func=sample_inputs_as_strided, 14466 skips=( 14467 # Note: This xfail is fine -- it's inherent to how as_strided works 14468 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), 14469 # AssertionError: False is not true : Scalars failed to compare as equal!
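# The "storage_offset" skips below refer to samples along these lines (rough sketch
# only; the actual inputs come from sample_inputs_as_strided):
#   t = torch.randn(20)
#   t.as_strided((2, 2), (1, 2), storage_offset=3)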
14470 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 14471 'TestCommon', 'test_variant_consistency_eager'), 14472 # Not close 14473 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 14474 'TestCommon', 'test_complex_half_reference_testing'), 14475 # Not close 14476 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), 14477 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), 14478 DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), 14479 DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), 14480 )), 14481 OpInfo('as_strided', 14482 variant_test_name='partial_views', 14483 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 14484 supports_out=False, 14485 supports_forward_ad=True, 14486 supports_fwgrad_bwgrad=True, 14487 # vmap does not support inplace views 14488 check_inplace_batched_forward_grad=False, 14489 sample_inputs_func=sample_inputs_as_strided_partial_views, 14490 skips=( 14491 # Note: This xfail is fine -- it's inherent to how as_strided works 14492 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), 14493 # RuntimeError: This operator is not Composite Compliant: the 14494 # storage_offset of the tensor was modified directly without 14495 # going through the PyTorch dispatcher. 14496 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), 14497 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), 14498 14499 # These fail because the test changes the input's in-memory layout 14500 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'), 14501 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 14502 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), 14503 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14504 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', 14505 dtypes=(torch.complex64, torch.complex128)), 14506 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), 14507 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), 14508 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'), 14509 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'), 14510 DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo', 14511 'test_make_fx_symbolic_exhaustive_inplace'), 14512 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), 14513 # Fail but are also flaky 14514 DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'), 14515 DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon', 14516 'test_non_standard_bool_values'), 14517 # RuntimeError: setStorage: sizes [2, 2], strides [1, 2], storage offset 10, and itemsize 2 requiring a 14518 # storage size of 28 are out of bounds for storage of size 20 14519 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace'), 14520 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace'), 14521 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace'), 14522 DecorateInfo(unittest.expectedFailure, 'TestMeta', 
'test_dispatch_symbolic_meta_inplace_all_strides'), 14523 )), 14524 OpInfo('as_strided_copy', 14525 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 14526 supports_out=True, 14527 supports_forward_ad=True, 14528 supports_fwgrad_bwgrad=True, 14529 # vmap does not support inplace views 14530 check_inplace_batched_forward_grad=False, 14531 sample_inputs_func=sample_inputs_as_strided, 14532 skips=( 14533 # Note: This xfail is fine -- it's inherent to how as_strided works 14534 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), 14535 # AssertionError: False is not true : Scalars failed to compare as equal! 14536 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 14537 'TestCommon', 'test_variant_consistency_eager'), 14538 # Not close 14539 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 14540 'TestCommon', 'test_complex_half_reference_testing'), 14541 # Not close 14542 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), 14543 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), 14544 DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'), 14545 DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'), 14546 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), 14547 )), 14548 OpInfo('as_strided_scatter', 14549 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 14550 supports_out=False, 14551 supports_forward_ad=True, 14552 supports_fwgrad_bwgrad=True, 14553 # vmap does not support inplace views 14554 check_inplace_batched_forward_grad=False, 14555 sample_inputs_func=sample_inputs_as_strided_scatter, 14556 error_inputs_func=error_inputs_as_strided_scatter, 14557 skips=( 14558 DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950 14559 DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950 14560 DecorateInfo(unittest.skip('Fails on cuda + rocm'), 'TestCommon', 'test_complex_half_reference_testing'), 14561 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), 14562 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), 14563 DecorateInfo(unittest.skip('Passes on complex128 and float64 only'), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 14564 # AssertionError: Tensor-likes are not close! 
(new_empty_strided.default) 14565 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'),)), 14566 OpInfo('native_layer_norm', 14567 aten_name='native_layer_norm', 14568 ref=reference_native_layer_norm, 14569 dtypes=floating_types_and(torch.half, torch.bfloat16), 14570 supports_out=False, 14571 assert_jit_shape_analysis=True, 14572 supports_fwgrad_bwgrad=True, 14573 sample_inputs_func=sample_inputs_native_layer_norm, 14574 error_inputs_func=error_inputs_native_layer_norm, 14575 skips=( 14576 # IndexError: tuple index out of range 14577 DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients', 'test_forward_mode_AD'), 14578 # Tests fail when weight=None and bias is defined 14579 # https://github.com/pytorch/pytorch/issues/79705 14580 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), 14581 # JIT test also tries to compute double backward, which fails 14582 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14583 DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), 14584 DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}), 14585 "TestDecomp", "test_comprehensive", device_type="cpu"), 14586 )), 14587 OpInfo('native_batch_norm', 14588 aten_name='native_batch_norm', 14589 dtypes=floating_types_and(torch.float16, torch.bfloat16), 14590 supports_forward_ad=True, 14591 supports_fwgrad_bwgrad=True, 14592 assert_jit_shape_analysis=True, 14593 allow_cow_input_materialize_forward=[3, 4], 14594 allow_cow_input_materialize_backward=[3, 4], 14595 sample_inputs_func=sample_inputs_native_batch_norm, 14596 skips=( 14597 # NotImplementedError: Could not run 14598 # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. 14599 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), 14600 # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] 14601 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), 14602 # Problem with _get_numerical_jacobian 14603 # IndexError: tuple index out of range 14604 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 14605 # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED 14606 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14607 # https://github.com/pytorch/pytorch/issues/85960 14608 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), 14609 # AssertionError: Booleans mismatch: True is not False 14610 DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), 14611 DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), 14612 DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), 14613 "TestCompositeCompliance", "test_forward_ad"), 14614 ) 14615 ), 14616 OpInfo('_native_batch_norm_legit', 14617 aten_name='_native_batch_norm_legit', 14618 dtypes=floating_types_and(torch.float16, torch.bfloat16), 14619 supports_forward_ad=True, 14620 supports_fwgrad_bwgrad=True, 14621 assert_jit_shape_analysis=True, 14622 allow_cow_input_materialize_forward=[3, 4], 14623 allow_cow_input_materialize_backward=[3, 4], 14624 sample_inputs_func=sample_inputs__native_batch_norm_legit, 14625 skips=( 14626 # NotImplementedError: Could not run 14627 # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. 
14628 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), 14629 # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] 14630 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), 14631 # Problem with _get_numerical_jacobian 14632 # IndexError: tuple index out of range 14633 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 14634 # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED 14635 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14636 # https://github.com/pytorch/pytorch/issues/85960 14637 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), 14638 DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), 14639 "TestCompositeCompliance", "test_forward_ad"), 14640 ) 14641 ), 14642 OpInfo('_batch_norm_with_update', 14643 op=torch.ops.aten._batch_norm_with_update, 14644 aten_name='_batch_norm_with_update', 14645 dtypes=floating_types_and(torch.float16, torch.bfloat16), 14646 supports_forward_ad=True, 14647 supports_fwgrad_bwgrad=True, 14648 assert_jit_shape_analysis=True, 14649 allow_cow_input_materialize_forward=[3, 4], 14650 allow_cow_input_materialize_backward=[3, 4], 14651 sample_inputs_func=sample_inputs__batch_norm_with_update, 14652 skips=( 14653 # NotImplementedError: Could not run 14654 # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. 14655 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), 14656 # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] 14657 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), 14658 # Problem with _get_numerical_jacobian 14659 # IndexError: tuple index out of range 14660 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 14661 # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED 14662 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14663 DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-5)}), 14664 "TestCompositeCompliance", "test_forward_ad"), 14665 # _batch_norm_with_update expects contiguous inputs for cudnn and miopen 14666 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type="cuda"), 14667 DecorateInfo(unittest.expectedFailure, 14668 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides', device_type="cuda"), 14669 # _batch_norm_with_update does not have python bindings 14670 DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 14671 # aten out variants do not accept out= kwarg, only python out variants 14672 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 14673 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 14674 ) 14675 ), 14676 OpInfo('nn.functional.cosine_similarity', 14677 aten_name="cosine_similarity", 14678 dtypes=floating_types_and(torch.half, torch.bfloat16), 14679 supports_out=False, 14680 supports_forward_ad=True, 14681 supports_fwgrad_bwgrad=True, 14682 decorators=[ 14683 DecorateInfo( 14684 toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}), 14685 "TestInductorOpInfo", 14686 "test_comprehensive", 14687 device_type="cuda" 14688 ), 14689 ], 14690 
sample_inputs_func=sample_inputs_cosine_similarity), 14691 OpInfo('nn.functional.adaptive_avg_pool1d', 14692 dtypes=floating_types_and(torch.half, torch.bfloat16), 14693 supports_out=False, 14694 supports_forward_ad=True, 14695 supports_fwgrad_bwgrad=True, 14696 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14697 error_inputs_func=error_inputs_adaptive_avg_pool1d, 14698 sample_inputs_func=sample_inputs_adaptive_avg_pool1d), 14699 OpInfo('nn.functional.adaptive_avg_pool2d', 14700 dtypes=floating_types_and(torch.half, torch.bfloat16), 14701 decorators=( 14702 # RuntimeError: 14703 # adaptive_avg_pool2d(Tensor input, int[2] output_size) -> (Tensor): 14704 # Expected a value of type 'List[int]' for argument 'output_size' but 14705 # instead found type 'Tuple[NoneType, int]'. : 14706 # File "<string>", line 3 14707 # def the_method(i0): 14708 # return torch.nn.functional.adaptive_avg_pool2d(i0, (None, 7)) 14709 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE 14710 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14711 ), 14712 supports_out=False, 14713 supports_forward_ad=True, 14714 supports_fwgrad_bwgrad=True, 14715 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14716 error_inputs_func=error_inputs_adaptive_avg_pool2d, 14717 sample_inputs_func=sample_inputs_adaptive_avg_pool2d), 14718 OpInfo('nn.functional.adaptive_avg_pool3d', 14719 dtypes=floating_types_and(torch.half, torch.bfloat16), 14720 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 14721 decorators=( 14722 # RuntimeError: 14723 # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): 14724 # Expected a value of type 'List[int]' for argument 'output_size' but 14725 # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : 14726 # File "<string>", line 3 14727 # 14728 # def the_method(i0): 14729 # return torch.nn.functional.adaptive_avg_pool3d(i0, (None, None, None)) 14730 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE 14731 # 14732 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14733 ), 14734 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 14735 gradcheck_fast_mode=True, 14736 supports_out=False, 14737 supports_forward_ad=True, 14738 supports_fwgrad_bwgrad=True, 14739 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14740 error_inputs_func=error_inputs_adaptive_avg_pool3d, 14741 sample_inputs_func=sample_inputs_adaptive_avg_pool3d), 14742 OpInfo('nn.functional.adaptive_max_pool1d', 14743 dtypes=floating_types_and(torch.half, torch.bfloat16), 14744 supports_out=False, 14745 supports_forward_ad=True, 14746 supports_fwgrad_bwgrad=True, 14747 # got: Batching rule not implemented for aten::flatten.using_ints 14748 check_batched_forward_grad=False, 14749 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14750 error_inputs_func=error_inputs_adaptive_max_pool1d, 14751 sample_inputs_func=sample_inputs_adaptive_max_pool1d), 14752 OpInfo('nn.functional.adaptive_max_pool2d', 14753 dtypes=floating_types_and(torch.half, torch.bfloat16), 14754 decorators=( 14755 # RuntimeError: 14756 # adaptive_max_pool2d(Tensor input, int[2] output_size) -> (Tensor): 14757 # Expected a value of type 'List[int]' for argument 'output_size' but 14758 # instead found type 'Tuple[NoneType, int]'. 
: 14759 # File "<string>", line 3 14760 # def the_method(i0): 14761 # return torch.nn.functional.adaptive_max_pool2d(i0, (None, 7)) 14762 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE 14763 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14764 ), 14765 supports_out=False, 14766 supports_forward_ad=True, 14767 supports_fwgrad_bwgrad=True, 14768 # got: Batching rule not implemented for aten::flatten.using_ints 14769 check_batched_forward_grad=False, 14770 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14771 error_inputs_func=error_inputs_adaptive_max_pool2d, 14772 sample_inputs_func=sample_inputs_adaptive_max_pool2d), 14773 OpInfo('nn.functional.adaptive_max_pool3d', 14774 dtypes=floating_types_and(torch.bfloat16), 14775 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 14776 decorators=( 14777 # RuntimeError: 14778 # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): 14779 # Expected a value of type 'List[int]' for argument 'output_size' but 14780 # instead found type 'Tuple[NoneType, NoneType, NoneType]'. : 14781 # File "<string>", line 3 14782 # 14783 # def the_method(i0): 14784 # return torch.nn.functional.adaptive_max_pool3d(i0, (None, None, None)) 14785 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE 14786 # 14787 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 14788 ), 14789 supports_out=False, 14790 supports_forward_ad=True, 14791 supports_fwgrad_bwgrad=True, 14792 # got: Batching rule not implemented for aten::flatten.using_ints 14793 check_batched_forward_grad=False, 14794 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14795 error_inputs_func=error_inputs_adaptive_max_pool3d, 14796 sample_inputs_func=sample_inputs_adaptive_max_pool3d), 14797 OpInfo('nn.functional.avg_pool1d', 14798 aten_name='avg_pool1d', 14799 supports_autograd=True, 14800 supports_out=False, 14801 supports_forward_ad=True, 14802 supports_fwgrad_bwgrad=True, 14803 dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), 14804 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 14805 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14806 error_inputs_func=error_inputs_avg_pool1d, 14807 sample_inputs_func=sample_inputs_avgpool1d), 14808 OpInfo('nn.functional.avg_pool3d', 14809 aten_name='avg_pool3d', 14810 supports_autograd=True, 14811 supports_forward_ad=True, 14812 supports_fwgrad_bwgrad=True, 14813 dtypes=floating_types_and(torch.int64), 14814 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 14815 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14816 error_inputs_func=error_inputs_avg_pool3d, 14817 sample_inputs_func=sample_inputs_avgpool3d, 14818 skips=( 14819 # AssertionError: Tensor-likes are not close! 
14820 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'), 14821 )), 14822 OpInfo( 14823 "nn.functional.binary_cross_entropy_with_logits", 14824 aten_name="binary_cross_entropy_with_logits", 14825 supports_autograd=True, 14826 supports_forward_ad=True, 14827 supports_fwgrad_bwgrad=True, 14828 supports_out=False, 14829 dtypes=floating_types_and(torch.half, torch.bfloat16), 14830 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14831 sample_inputs_func=sample_inputs_binary_cross_entropy_with_logits, 14832 skips=( 14833 DecorateInfo( 14834 unittest.skip("Skipped!"), 14835 'TestJit', 14836 'test_variant_consistency_jit', 14837 dtypes=(torch.float32,) 14838 ), 14839 ), 14840 ), 14841 UnaryUfuncInfo( 14842 'nn.functional.relu', 14843 aten_name="relu", 14844 ref=lambda a: np.where(a <= 0, 0, a), 14845 supports_autograd=True, 14846 supports_sparse=True, 14847 supports_sparse_csr=True, 14848 supports_sparse_csc=True, 14849 supports_sparse_bsr=True, 14850 supports_sparse_bsc=True, 14851 dtypes=all_types_and(torch.half, torch.bfloat16), 14852 sample_inputs_func=sample_inputs_nn_activation_relu, 14853 supports_out=False, 14854 supports_fwgrad_bwgrad=True, 14855 supports_forward_ad=True), 14856 OpInfo('nn.functional.conv_transpose1d', 14857 # `ref` for this function is backward of 14858 # corresponding `conv*d` 14859 ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose1d), 14860 aten_name='conv_transpose1d', 14861 aliases=('conv_transpose1d',), 14862 dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), 14863 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, 14864 torch.bfloat16), 14865 sample_inputs_func=sample_inputs_conv_transpose1d, 14866 supports_forward_ad=True, 14867 supports_fwgrad_bwgrad=True, 14868 assert_jit_shape_analysis=True, 14869 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14870 decorators=( 14871 DecorateInfo( 14872 toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), 14873 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), 14874 DecorateInfo( 14875 toleranceOverride({torch.chalf: tol(atol=5e-2, rtol=5e-2), }), 14876 'TestCommon', 'test_complex_half_reference_testing'), 14877 DecorateInfo( 14878 toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }), 14879 'TestCommon', 'test_numpy_ref_mps'), 14880 DecorateInfo( 14881 toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }), 14882 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), 14883 ), 14884 skips=( 14885 # Reason for Skip: https://github.com/pytorch/pytorch/pull/79694#issuecomment-1186949486 14886 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 14887 dtypes=(torch.complex64,)), 14888 # RuntimeError: UNSUPPORTED DTYPE: complex 14889 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', 14890 dtypes=(torch.complex64, torch.complex128)), 14891 # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at 14892 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. 
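# Note on `ref=conv_transpose_ref` above: a transposed convolution is mathematically the
# gradient of the matching conv*d with respect to its input, which is presumably why the
# reference is built from the corresponding forward convolution's backward rather than
# from a dedicated numpy implementation.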
14893 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 14894 dtypes=(torch.float,)), 14895 # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' 14896 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', 14897 dtypes=(torch.int64,)), 14898 ), 14899 supports_out=False,), 14900 OpInfo('nn.functional.conv_transpose2d', 14901 aten_name='conv_transpose2d', 14902 aliases=('conv_transpose2d',), 14903 # `ref` for this function is backward of 14904 # corresponding `conv*d` 14905 ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose2d), 14906 dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), 14907 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, 14908 torch.bfloat16), 14909 sample_inputs_func=sample_inputs_conv_transpose2d, 14910 # Runs very slowly on slow-gradcheck for complex. 14911 gradcheck_fast_mode=True, 14912 supports_forward_ad=True, 14913 supports_fwgrad_bwgrad=True, 14914 assert_jit_shape_analysis=True, 14915 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14916 decorators=[ 14917 DecorateInfo( 14918 toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), }), 14919 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), 14920 DecorateInfo( 14921 toleranceOverride({torch.float32: tol(atol=2e-05, rtol=5e-05), }), 14922 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), 14923 DecorateInfo( 14924 toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }), 14925 'TestCommon', 'test_complex_half_reference_testing'), 14926 DecorateInfo( 14927 toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }), 14928 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], 14929 skips=( 14930 # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at 14931 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. 
14932 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 14933 # RuntimeError: UNSUPPORTED DTYPE: complex 14934 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', 14935 dtypes=(torch.complex64, torch.complex128)), 14936 # RuntimeError: "slow_conv2d_cpu_grad_input" not implemented for 'Long' 14937 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', 14938 dtypes=(torch.int64,)), 14939 # Reference: https://github.com/pytorch/pytorch/issues/86356 14940 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', 14941 dtypes=(torch.double, torch.cdouble)), 14942 DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), 14943 # AssertionError: None mismatch: torch.complex64 is not None 14944 DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 'test_custom_rules', 14945 dtypes=(torch.complex64, torch.complex128)), 14946 ), 14947 supports_out=False,), 14948 OpInfo('nn.functional.conv_transpose3d', 14949 aten_name='conv_transpose3d', 14950 aliases=('conv_transpose3d',), 14951 # `ref` for this function is backward of 14952 # corresponding `conv*d` 14953 ref=partial(conv_transpose_ref, fn=torch.nn.functional.conv_transpose3d), 14954 dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), 14955 dtypesIfCUDA=floating_and_complex_types_and( 14956 torch.float16, torch.chalf, torch.bfloat16), 14957 sample_inputs_func=sample_inputs_conv_transpose3d, 14958 supports_forward_ad=True, 14959 supports_fwgrad_bwgrad=True, 14960 assert_jit_shape_analysis=True, 14961 # Runs very slowly on slow-gradcheck - alternatively reduce input sizes 14962 gradcheck_fast_mode=True, 14963 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 14964 decorators=[ 14965 DecorateInfo( 14966 toleranceOverride({torch.float16: tol(atol=5e-2, rtol=5e-2), }), 14967 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), 14968 DecorateInfo( 14969 toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1.3e-06), 14970 torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), 14971 'TestCommon', 'test_variant_consistency_eager', device_type='cuda'), 14972 DecorateInfo( 14973 toleranceOverride({torch.float32: tol(atol=2e-04, rtol=2e-04), }), 14974 'TestCompositeCompliance', 'test_operator', device_type='cuda'), 14975 DecorateInfo( 14976 toleranceOverride({torch.float32: tol(atol=1.3e-04, rtol=1.3e-06), 14977 torch.complex64: tol(atol=1.3e-04, rtol=1.3e-05)}), 14978 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), 14979 DecorateInfo( 14980 toleranceOverride({torch.float32: tol(atol=1e-04, rtol=2e-05), }), 14981 'TestCompositeCompliance', 'test_forward_ad', device_type='cuda', 14982 active_if=TEST_CUDNN), 14983 DecorateInfo( 14984 toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1e-4)}), 14985 "TestMathBits", "test_conj_view", device_type='cuda'), 14986 DecorateInfo( 14987 toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }), 14988 'TestCommon', 'test_complex_half_reference_testing'), 14989 DecorateInfo( 14990 toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }), 14991 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')], 14992 skips=( 14993 # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at 14994 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":104, please report a bug to PyTorch. 
14995 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 14996 # RuntimeError: "slow_conv3d_cpu_grad_input" not implemented for 'Long' 14997 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', 14998 dtypes=(torch.int64,)), 14999 # Reference: https://github.com/pytorch/pytorch/issues/86356 15000 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', 15001 dtypes=(torch.double, torch.cdouble)), 15002 DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), 15003 # RuntimeError: UNSUPPORTED DTYPE: complex 15004 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', 15005 dtypes=(torch.complex64, torch.complex128)), 15006 DecorateInfo(unittest.skip('Skipped for ROCm!'), 'TestCommon', 'test_complex_half_reference_testing', 15007 dtypes=[torch.complex32], active_if=TEST_WITH_ROCM), 15008 ), 15009 supports_out=False,), 15010 OpInfo('nn.functional.conv1d', 15011 aliases=('conv1d',), 15012 aten_name='conv1d', 15013 dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), 15014 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, 15015 torch.bfloat16), 15016 sample_inputs_func=sample_inputs_conv1d, 15017 error_inputs_func=error_inputs_conv1d, 15018 supports_forward_ad=True, 15019 supports_fwgrad_bwgrad=True, 15020 assert_jit_shape_analysis=True, 15021 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15022 decorators=( 15023 DecorateInfo( 15024 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=5e-2)}), 15025 'TestCommon', 'test_complex_half_reference_testing' 15026 ), 15027 DecorateInfo( 15028 toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), 15029 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', 15030 ), 15031 ), 15032 skips=( 15033 # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at 15034 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. 
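# Note: supports_expanded_weight=True (set near the end of this entry) presumably opts
# conv1d into the ExpandedWeights / per-sample-gradient tests.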
15035 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 15036 # Ref: https://github.com/pytorch/pytorch/issues/75309 15037 # AssertionError: None mismatch: torch.complex128 is not None 15038 DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 15039 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), 15040 # Ref: https://github.com/pytorch/pytorch/issues/75309 15041 # RuntimeError: UNSUPPORTED DTYPE: complex 15042 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 15043 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), 15044 ), 15045 supports_expanded_weight=True, 15046 supports_out=False,), 15047 OpInfo('nn.functional.conv2d', 15048 aliases=('conv2d',), 15049 aten_name='conv2d', 15050 dtypes=floating_and_complex_types_and(torch.int64, torch.float16, torch.bfloat16), 15051 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, 15052 torch.bfloat16), 15053 sample_inputs_func=partial(sample_inputs_conv2d), 15054 error_inputs_func=error_inputs_conv2d, 15055 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15056 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15057 gradcheck_fast_mode=True, 15058 supports_forward_ad=True, 15059 supports_fwgrad_bwgrad=True, 15060 assert_jit_shape_analysis=True, 15061 decorators=( 15062 DecorateInfo( 15063 toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), 15064 'TestCommon', 'test_complex_half_reference_testing', 15065 ), 15066 ), 15067 skips=( 15068 # RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at 15069 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. 15070 DecorateInfo(unittest.skip("Works on some configs!"), 'TestJit', 'test_variant_consistency_jit'), 15071 # Ref: https://github.com/pytorch/pytorch/issues/75309 15072 # AssertionError: None mismatch: torch.complex128 is not None 15073 DecorateInfo(unittest.expectedFailure, 'TestDtypeCustomRules', 15074 'test_custom_rules', dtypes=(torch.complex64, torch.complex128)), 15075 # RuntimeError: UNSUPPORTED DTYPE: complex 15076 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 15077 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), 15078 ), 15079 supports_expanded_weight=True, 15080 supports_out=False,), 15081 OpInfo('nn.functional.conv3d', 15082 aliases=('conv3d',), 15083 aten_name='conv3d', 15084 dtypes=floating_and_complex_types_and(torch.int64, torch.bfloat16, torch.float16), 15085 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.chalf, torch.bfloat16), 15086 sample_inputs_func=sample_inputs_conv3d, 15087 error_inputs_func=error_inputs_conv3d, 15088 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15089 gradcheck_fast_mode=True, 15090 supports_forward_ad=True, 15091 supports_fwgrad_bwgrad=True, 15092 decorators=( 15093 DecorateInfo( 15094 toleranceOverride({torch.chalf: tol(atol=6e-2, rtol=5e-2)}), 15095 'TestCommon', 'test_complex_half_reference_testing', 15096 ), 15097 # TF32 15098 DecorateInfo( 15099 toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), 15100 torch.complex64: tol(atol=5e-3, rtol=1e-3)}), 15101 'TestCommon', 'test_noncontiguous_samples', 15102 ), 15103 DecorateInfo( 15104 toleranceOverride({torch.complex64: tol(atol=5e-5, rtol=5e-6)}), 15105 'TestMathBits', 'test_conj_view', 15106 ), 15107 DecorateInfo( 15108 toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-6)}), 15109 'TestOperators', 'test_vjpvmap', 15110 ), 15111 ), 15112 skips=( 15113 # RuntimeError: 
!lhs.isAliasOf(rhs) INTERNAL ASSERT FAILED at 15114 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":103, please report a bug to PyTorch. 15115 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 15116 # RuntimeError: UNSUPPORTED DTYPE: complex 15117 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 15118 'test_nnc_correctness', dtypes=(torch.complex64, torch.complex128)), 15119 # AssertionError: Tensor-likes are not close! 15120 # break slow tests 15121 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), 15122 ), 15123 supports_expanded_weight=True, 15124 supports_out=False,), 15125 OpInfo('nn.functional.group_norm', 15126 aten_name='group_norm', 15127 aliases=('group_norm',), 15128 ref=reference_group_norm, 15129 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15130 supports_out=False, 15131 supports_forward_ad=True, 15132 supports_fwgrad_bwgrad=True, 15133 error_inputs_func=error_inputs_group_norm, 15134 decorators=[ 15135 # RuntimeError: Cannot insert a Tensor that requires grad as a constant. 15136 # Consider making it a parameter or input, or detaching the gradient 15137 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15138 DecorateInfo( 15139 toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}), 15140 "TestDecomp", 15141 "test_comprehensive", 15142 device_type="cpu" 15143 ), 15144 ], 15145 sample_inputs_func=sample_inputs_group_norm, 15146 reference_inputs_func=reference_inputs_group_norm, 15147 supports_expanded_weight=True,), 15148 OpInfo('nn.functional.instance_norm', 15149 # no ref because instance_norm will often have numerical instability (large numbers or nan) 15150 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15151 supports_out=False, 15152 supports_forward_ad=True, 15153 supports_fwgrad_bwgrad=True, 15154 allow_cow_input_materialize_forward=['running_mean', 'running_var'], 15155 decorators=[ 15156 # RuntimeError: Cannot insert a Tensor that requires grad as a constant. 
15157 # Consider making it a parameter or input, or detaching the gradient 15158 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15159 ], 15160 sample_inputs_func=sample_inputs_instance_norm, 15161 supports_expanded_weight=True,), 15162 OpInfo('nn.functional.layer_norm', 15163 aten_name='layer_norm', 15164 aten_backward_name='layer_norm_backward', 15165 aliases=('layer_norm',), 15166 ref=reference_layer_norm, 15167 dtypes=floating_types_and(torch.half, torch.bfloat16), 15168 supports_out=False, 15169 supports_forward_ad=True, 15170 supports_fwgrad_bwgrad=True, 15171 assert_jit_shape_analysis=True, 15172 decorators=[ 15173 DecorateInfo( 15174 toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), 15175 'TestCommon', 'test_numpy_refs' 15176 ), 15177 DecorateInfo(unittest.skip("Bug in MPS backend!"), 'TestCommon', 'test_numpy_ref_mps'), 15178 ], 15179 sample_inputs_func=sample_inputs_layer_norm, 15180 supports_expanded_weight=True,), 15181 OpInfo('nn.functional.rms_norm', 15182 aten_name='rms_norm', 15183 aliases=('rms_norm',), 15184 ref=reference_rms_norm, 15185 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 15186 supports_out=False, 15187 supports_forward_ad=True, 15188 supports_fwgrad_bwgrad=True, 15189 sample_inputs_func=sample_inputs_rms_norm, 15190 error_inputs_func=error_inputs_rms_norm,), 15191 OpInfo('nn.functional.local_response_norm', 15192 dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), 15193 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15194 supports_out=False, 15195 supports_forward_ad=True, 15196 supports_fwgrad_bwgrad=True, 15197 decorators=[ 15198 # RuntimeError: falseINTERNAL ASSERT FAILED at 15199 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. 15200 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15201 ], 15202 sample_inputs_func=sample_inputs_local_response_norm,), 15203 OpInfo('constant_pad_nd', 15204 supports_forward_ad=True, 15205 supports_fwgrad_bwgrad=True, 15206 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), 15207 sample_inputs_func=sample_inputs_constant_pad_nd, 15208 supports_out=False, 15209 skips=( 15210 # bool can't be passed to Scalar arguments in JIT tracer because 15211 # BoolType is not a subtype of ScalarType. 15212 DecorateInfo( 15213 unittest.expectedFailure, 'TestNNCOpInfo', 15214 'test_nnc_correctness', dtypes=(torch.bool,)), 15215 )), 15216 OpInfo('nn.functional.pad', 15217 variant_test_name='constant', 15218 aten_name='constant_pad_nd', 15219 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15220 gradcheck_fast_mode=True, 15221 supports_forward_ad=True, 15222 supports_fwgrad_bwgrad=True, 15223 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), 15224 sample_inputs_func=partial(sample_inputs_nn_pad, mode='constant'), 15225 supports_out=False), 15226 OpInfo('nn.functional.pad', 15227 variant_test_name='reflect', 15228 supports_forward_ad=True, 15229 supports_fwgrad_bwgrad=True, 15230 dtypes=all_types_and_complex_and(torch.bfloat16), 15231 dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), 15232 sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), 15233 skips=( 15234 # Doesn't have a corresponding aten operator. 
15235 # RuntimeError: falseINTERNAL ASSERT FAILED at 15236 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. 15237 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15238 ), 15239 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15240 supports_out=False), 15241 OpInfo('nn.functional.pad', 15242 variant_test_name='replicate', 15243 supports_forward_ad=True, 15244 supports_fwgrad_bwgrad=True, 15245 dtypes=all_types_and_complex_and(torch.bfloat16), 15246 dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), 15247 sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), 15248 skips=( 15249 # Doesn't have a corresponding aten operator. 15250 # RuntimeError: falseINTERNAL ASSERT FAILED at 15251 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. 15252 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15253 ), 15254 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15255 supports_out=False), 15256 OpInfo('nn.functional.pad', 15257 variant_test_name='replicate_negative', 15258 supports_forward_ad=True, 15259 supports_fwgrad_bwgrad=True, 15260 dtypes=all_types_and_complex_and(torch.bfloat16), 15261 dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), 15262 sample_inputs_func=sample_inputs_nn_pad_replicate_negative, 15263 skips=( 15264 # Doesn't have a corresponding aten operator. 15265 # RuntimeError: falseINTERNAL ASSERT FAILED at 15266 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. 15267 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15268 # Some negative padding cases cause a segfault on MPS 15269 DecorateInfo(unittest.skip("Not fully supported on MPS"), 'TestConsistency'), 15270 ), 15271 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15272 supports_out=False), 15273 OpInfo('nn.functional.pad', 15274 variant_test_name='circular', 15275 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), 15276 sample_inputs_func=partial(sample_inputs_nn_pad, mode='circular'), 15277 supports_forward_ad=True, 15278 supports_fwgrad_bwgrad=True, 15279 check_batched_grad=False, 15280 # https://github.com/pytorch/pytorch/issues/66357 15281 check_batched_forward_grad=False, 15282 skips=( 15283 # Doesn't have a corresponding aten operator. 15284 # RuntimeError: falseINTERNAL ASSERT FAILED at 15285 # "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, please report a bug to PyTorch. 
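# For reference, circular padding wraps values around the padded dimension, roughly:
#   x = torch.arange(3.).reshape(1, 1, 3)                    # [[[0., 1., 2.]]]
#   torch.nn.functional.pad(x, (1, 1), mode='circular')      # [[[2., 0., 1., 2., 0.]]]
# (illustrative only; the tested inputs come from sample_inputs_nn_pad.)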
15286 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), 15287 # Difference from <type> is larger with decomposition new_empty_strided.default than original on output 0 15288 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 'TestDecomp', 'test_comprehensive'), 15289 ), 15290 supports_out=False), 15291 OpInfo('nn.functional.hardswish', 15292 aten_name="hardswish", 15293 aten_backward_name='hardswish_backward', 15294 supports_autograd=True, 15295 assert_autodiffed=True, 15296 sample_inputs_func=sample_inputs_hardswish, 15297 dtypes=floating_types_and(torch.bfloat16, torch.half), 15298 supports_gradgrad=True, 15299 supports_forward_ad=True, 15300 supports_fwgrad_bwgrad=True, 15301 supports_out=False, 15302 autodiff_nonfusible_nodes=["aten::hardswish"]), 15303 OpInfo('nn.functional.unfold', 15304 aten_name='im2col', 15305 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), 15306 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), 15307 sample_inputs_func=sample_inputs_nn_unfold, 15308 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15309 gradcheck_fast_mode=True, 15310 supports_forward_ad=True, 15311 supports_fwgrad_bwgrad=True, 15312 supports_out=False, 15313 skips=( 15314 # NOTE: this failure may not reproduce consistently on different systems 15315 # false INTERNAL ASSERT FAILED at "...torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185 15316 DecorateInfo(unittest.skip("Internal assert failed!"), 'TestJit', 'test_variant_consistency_jit'), 15317 )), 15318 OpInfo('nn.functional.interpolate', 15319 aten_name="interpolate", 15320 variant_test_name='nearest', 15321 supports_autograd=True, 15322 supports_fwgrad_bwgrad=True, 15323 supports_forward_ad=True, 15324 dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), 15325 sample_inputs_func=partial(sample_inputs_interpolate, 'nearest'), 15326 skips=( 15327 # RuntimeError: false 15328 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15329 # please report a bug to PyTorch. 15330 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15331 ), 15332 supports_out=False), 15333 OpInfo('nn.functional.interpolate', 15334 aten_name="interpolate", 15335 variant_test_name='nearest-exact', 15336 supports_autograd=True, 15337 supports_fwgrad_bwgrad=True, 15338 supports_forward_ad=True, 15339 dtypes=floating_types_and(torch.half, torch.bfloat16, torch.uint8), 15340 sample_inputs_func=partial(sample_inputs_interpolate, 'nearest-exact'), 15341 skips=( 15342 # RuntimeError: false 15343 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15344 # please report a bug to PyTorch. 15345 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15346 # RuntimeError: aten::_upsample_nearest_exact*d hit the vmap fallback which is currently disabled 15347 DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), 15348 DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), 15349 DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), 15350 # NotImplementedError: The operator 'aten::_upsample_nearest_exact3d.out' is not currently implemented 15351 # for the MPS device. 
15352 DecorateInfo(unittest.expectedFailure, 'TestConsistency'), 15353 ), 15354 supports_out=False), 15355 OpInfo('nn.functional.interpolate', 15356 aten_name="interpolate", 15357 variant_test_name='linear', 15358 supports_autograd=True, 15359 supports_fwgrad_bwgrad=True, 15360 supports_forward_ad=True, 15361 dtypes=floating_types_and(torch.half, torch.bfloat16), 15362 sample_inputs_func=partial(sample_inputs_interpolate, 'linear'), 15363 skips=( 15364 # RuntimeError: false 15365 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15366 # please report a bug to PyTorch. 15367 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15368 ), 15369 supports_out=False), 15370 OpInfo('nn.functional.interpolate', 15371 aten_name="interpolate", 15372 variant_test_name='bilinear', 15373 supports_fwgrad_bwgrad=True, 15374 supports_autograd=True, 15375 supports_forward_ad=True, 15376 dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), 15377 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 15378 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15379 sample_inputs_func=partial(sample_inputs_interpolate, 'bilinear'), 15380 reference_inputs_func=partial(reference_inputs_interpolate, 'bilinear'), 15381 skips=( 15382 # RuntimeError: false 15383 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15384 # please report a bug to PyTorch. 15385 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15386 ), 15387 supports_out=False), 15388 OpInfo('nn.functional.interpolate', 15389 aten_name="interpolate", 15390 variant_test_name='bicubic', 15391 supports_autograd=True, 15392 supports_forward_ad=True, 15393 supports_fwgrad_bwgrad=True, 15394 dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), 15395 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 15396 sample_inputs_func=partial(sample_inputs_interpolate, 'bicubic'), 15397 reference_inputs_func=partial(reference_inputs_interpolate, 'bicubic'), 15398 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15399 skips=( 15400 # RuntimeError: false 15401 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15402 # please report a bug to PyTorch. 15403 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15404 ), 15405 supports_out=False), 15406 OpInfo('nn.functional.interpolate', 15407 aten_name="interpolate", 15408 variant_test_name='trilinear', 15409 supports_autograd=True, 15410 supports_forward_ad=True, 15411 supports_fwgrad_bwgrad=True, 15412 dtypes=floating_types_and(torch.half, torch.bfloat16), 15413 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15414 sample_inputs_func=partial(sample_inputs_interpolate, 'trilinear'), 15415 skips=( 15416 # RuntimeError: false 15417 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15418 # please report a bug to PyTorch. 
15419 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15420 ), 15421 supports_out=False), 15422 OpInfo('nn.functional.interpolate', 15423 aten_name="interpolate", 15424 variant_test_name='area', 15425 supports_autograd=True, 15426 supports_forward_ad=True, 15427 supports_fwgrad_bwgrad=True, 15428 dtypes=floating_types_and(torch.half, torch.bfloat16), 15429 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 15430 sample_inputs_func=partial(sample_inputs_interpolate, 'area'), 15431 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15432 skips=( 15433 # RuntimeError: false 15434 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15435 # please report a bug to PyTorch. 15436 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15437 ), 15438 supports_out=False), 15439 OpInfo('nn.functional.upsample_bilinear', 15440 supports_autograd=True, 15441 supports_forward_ad=True, 15442 supports_fwgrad_bwgrad=True, 15443 dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), 15444 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 15445 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15446 sample_inputs_func=partial(sample_inputs_upsample, 'bilinear'), 15447 reference_inputs_func=partial(reference_inputs_upsample, 'bilinear'), 15448 skips=( 15449 # RuntimeError: false 15450 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15451 # please report a bug to PyTorch. 15452 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15453 ), 15454 supports_out=False), 15455 OpInfo('_upsample_bilinear2d_aa', 15456 op=torch.ops.aten._upsample_bilinear2d_aa, 15457 aten_name='_upsample_bilinear2d_aa', 15458 supports_autograd=True, 15459 supports_forward_ad=True, 15460 supports_fwgrad_bwgrad=True, 15461 dtypes=floating_types_and(torch.uint8), 15462 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 15463 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15464 sample_inputs_func=partial(sample_inputs_upsample_aa, 'bilinear'), 15465 supports_out=False, 15466 skips=( 15467 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15468 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), 15469 DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'), 15470 DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), 15471 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 15472 )), 15473 OpInfo( 15474 "nn.functional.soft_margin_loss", 15475 dtypes=floating_types_and(torch.half, torch.bfloat16), 15476 supports_out=False, 15477 supports_forward_ad=True, 15478 # doesn't support grad on target 15479 sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False), 15480 error_inputs_func=error_inputs_soft_margin_loss, 15481 ), 15482 OpInfo('nn.functional.upsample_nearest', 15483 supports_autograd=True, 15484 supports_forward_ad=True, 15485 supports_fwgrad_bwgrad=True, 15486 dtypes=floating_types_and(torch.uint8, torch.half, torch.bfloat16), 15487 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15488 sample_inputs_func=partial(sample_inputs_upsample, 'nearest'), 15489 skips=( 15490 # RuntimeError: false 15491 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":185, 15492 # please report a bug to PyTorch. 
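# NOTE: context only (a sketch, not used by the tests): upsample_nearest here and
# upsample_bilinear above are deprecated wrappers around interpolate, roughly
# equivalent to:
#   torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
#   torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)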
15493 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15494 ), 15495 supports_out=False), 15496 OpInfo( 15497 "nn.functional.margin_ranking_loss", 15498 dtypes=all_types_and(torch.half, torch.bfloat16), 15499 supports_out=False, 15500 sample_inputs_func=sample_inputs_margin_ranking_loss, 15501 error_inputs_func=error_inputs_margin_ranking_loss, 15502 reference_inputs_func=reference_inputs_margin_ranking_loss, 15503 supports_forward_ad=True, 15504 supports_fwgrad_bwgrad=True), 15505 OpInfo( 15506 "nn.functional.multi_margin_loss", 15507 dtypes=floating_types(), 15508 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 15509 supports_out=False, 15510 supports_gradgrad=False, 15511 sample_inputs_func=sample_inputs_multi_margin_loss, 15512 reference_inputs_func=reference_inputs_multi_margin_loss, 15513 error_inputs_func=error_inputs_multi_margin_loss, 15514 decorators=( 15515 DecorateInfo( 15516 toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), 15517 "TestJit", 15518 "test_variant_consistency_jit", 15519 ), 15520 ), 15521 ), 15522 OpInfo( 15523 "nn.functional.multilabel_margin_loss", 15524 dtypes=floating_types(), 15525 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 15526 supports_out=False, 15527 supports_gradgrad=False, 15528 sample_inputs_func=sample_inputs_multilabel_margin_loss, 15529 reference_inputs_func=reference_inputs_multilabel_margin_loss, 15530 error_inputs_func=error_inputs_multilabel_margin_loss, 15531 ), 15532 OpInfo('nn.functional.leaky_relu', 15533 aliases=None, 15534 aten_name="leaky_relu", 15535 aten_backward_name='leaky_relu_backward', 15536 sample_inputs_func=sample_inputs_leaky_relu, 15537 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15538 inplace_variant=lambda x, negative_slope=0.01: 15539 torch.nn.functional.leaky_relu(x, negative_slope, inplace=True), 15540 supports_autograd=True, 15541 assert_autodiffed=True, 15542 supports_gradgrad=True, 15543 supports_out=False, 15544 supports_forward_ad=True, 15545 supports_fwgrad_bwgrad=True, 15546 autodiff_nonfusible_nodes=["aten::leaky_relu"]), 15547 OpInfo( 15548 "nn.functional.multilabel_soft_margin_loss", 15549 supports_out=False, 15550 dtypes=floating_types_and(torch.half, torch.bfloat16), 15551 sample_inputs_func=sample_inputs_multilabel_soft_margin_loss, 15552 supports_forward_ad=True, 15553 supports_fwgrad_bwgrad=True, 15554 decorators=( 15555 DecorateInfo( 15556 toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), 15557 "TestJit", 15558 "test_variant_consistency_jit", 15559 ), 15560 DecorateInfo( 15561 toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}), 15562 "TestInductorOpInfo", 15563 "test_comprehensive", 15564 device_type="cuda" 15565 ), 15566 ), 15567 skips=( 15568 # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096 15569 # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32 15570 # leaked 4096 bytes CUDA memory on device 0 15571 DecorateInfo( 15572 # Skip instead of expectedFailure because this fails 15573 # locally for me but passes in CI. 
15574 unittest.skip("Skipped!"), 15575 "TestJit", 15576 "test_variant_consistency_jit", 15577 device_type="cuda", 15578 ), 15579 ), 15580 ), 15581 OpInfo('nn.functional.avg_pool2d', 15582 aten_name='avg_pool2d', 15583 supports_autograd=True, 15584 supports_forward_ad=True, 15585 supports_fwgrad_bwgrad=True, 15586 dtypes=floating_types_and(torch.int64, torch.float16, torch.bfloat16), 15587 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15588 error_inputs_func=error_inputs_avg_pool2d, 15589 sample_inputs_func=sample_inputs_avgpool2d, 15590 skips=( 15591 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), 15592 )), 15593 OpInfo('nn.functional.fractional_max_pool2d', 15594 supports_autograd=True, 15595 supports_out=False, 15596 supports_forward_ad=True, 15597 supports_fwgrad_bwgrad=True, 15598 op=lambda input, *args, **kwargs: 15599 wrapper_set_seed(torch.nn.functional.fractional_max_pool2d, input, *args, **kwargs), 15600 # vmap does not support random operations 15601 check_batched_forward_grad=False, 15602 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15603 test_neg_view=False, 15604 sample_inputs_func=sample_inputs_fractional_max_pool2d, 15605 decorators=( 15606 # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! 15607 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 15608 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 15609 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 15610 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), 15611 skips=( 15612 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), 15613 OpInfo('nn.functional.fractional_max_pool3d', 15614 supports_autograd=True, 15615 supports_out=False, 15616 supports_forward_ad=True, 15617 supports_fwgrad_bwgrad=True, 15618 op=lambda input, *args, **kwargs: 15619 wrapper_set_seed(torch.nn.functional.fractional_max_pool3d, input, *args, **kwargs), 15620 # vmap does not support random operations 15621 check_batched_forward_grad=False, 15622 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15623 test_neg_view=False, 15624 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15625 sample_inputs_func=sample_inputs_fractional_max_pool3d, 15626 decorators=( 15627 # FIXME: both derivatives are implemented incorrectly 15628 # https://github.com/pytorch/pytorch/issues/69322 15629 # FIXME: AssertionError: False is not true : Tensors failed to compare as equal! 
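# NOTE: standalone context, unrelated to the FIXME above. The fractional_max_pool*
# entries wrap the op in wrapper_set_seed because the pooling regions are sampled
# randomly; the wrapper reseeds the RNG before each call so the repeated invocations
# a test compares see identical randomness. A rough sketch of the idea (hypothetical
# stand-in, not the real helper; the seed value is an assumption):
#   def _reseed_then_call(op, *args, **kwargs):
#       torch.manual_seed(42)
#       return op(*args, **kwargs)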
15630 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 15631 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 15632 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270 15633 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit')), 15634 skips=( 15635 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),)), 15636 OpInfo('nn.functional.max_pool1d', 15637 aten_name='max_pool1d', 15638 supports_autograd=True, 15639 supports_out=False, 15640 supports_forward_ad=True, 15641 supports_fwgrad_bwgrad=True, 15642 # got: Batching rule not implemented for aten::flatten.using_ints 15643 check_batched_forward_grad=False, 15644 # TODO: add shape checks 15645 assert_jit_shape_analysis=False, 15646 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15647 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15648 skips=( 15649 # Pre-existing condition; Needs to be fixed 15650 DecorateInfo(unittest.skip("Works on some configs"), 'TestNNCOpInfo', 15651 'test_nnc_correctness', dtypes=(torch.bfloat16,)), 15652 # RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. 15653 # Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() 15654 # to actually allocate memory 15655 DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'), 15656 ), 15657 error_inputs_func=error_inputs_max_pool1d, 15658 sample_inputs_func=sample_inputs_max_pool), 15659 OpInfo('nn.functional.max_pool2d', 15660 aten_name='max_pool2d', 15661 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15662 gradcheck_fast_mode=True, 15663 # Vmap is not happy with non-contiguous (channels_last) inputs 15664 check_batched_gradgrad=False, 15665 supports_out=False, 15666 supports_forward_ad=True, 15667 supports_fwgrad_bwgrad=True, 15668 # got: Batching rule not implemented for aten::flatten.using_ints 15669 check_batched_forward_grad=False, 15670 assert_jit_shape_analysis=True, 15671 dtypes=all_types_and(torch.float16, torch.bfloat16), 15672 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15673 error_inputs_func=error_inputs_max_pool2d, 15674 sample_inputs_func=sample_inputs_max_pool), 15675 OpInfo('max_pool2d_with_indices_backward', 15676 op=max_pool2d_backward, 15677 # We've defined a custom op, so there's no corresponding aten op 15678 aten_name=None, 15679 method_variant=None, 15680 inplace_variant=None, 15681 operator_variant=None, 15682 inplace_operator_variant=None, 15683 check_batched_gradgrad=False, 15684 supports_out=False, 15685 supports_forward_ad=True, 15686 supports_fwgrad_bwgrad=True, 15687 check_batched_forward_grad=False, 15688 assert_jit_shape_analysis=False, 15689 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15690 sample_inputs_func=sample_inputs_max_pool, 15691 skips=( 15692 # We've defined a custom op here, and we don't handle the case where we receive an out kwarg 15693 DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), 15694 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 15695 # FX failed to normalize op - add the op to the op_skip list. 
15696 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 15697 # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) 15698 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') 15699 )), 15700 OpInfo('nn.functional.max_pool3d', 15701 aten_name='max_pool3d', 15702 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15703 gradcheck_fast_mode=True, 15704 supports_out=False, 15705 supports_forward_ad=True, 15706 supports_fwgrad_bwgrad=True, 15707 # got: Batching rule not implemented for aten::flatten.using_ints 15708 check_batched_forward_grad=False, 15709 # TODO: add shape checks 15710 assert_jit_shape_analysis=False, 15711 dtypes=all_types_and(torch.bfloat16, torch.float16), 15712 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15713 # TODO: investigate nondeterminism 15714 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15715 error_inputs_func=error_inputs_max_pool3d, 15716 sample_inputs_func=sample_inputs_max_pool), 15717 OpInfo('nn.functional.max_unpool1d', 15718 aten_name='max_unpool1d', 15719 supports_autograd=True, 15720 supports_forward_ad=True, 15721 supports_fwgrad_bwgrad=True, 15722 supports_out=False, 15723 assert_jit_shape_analysis=False, 15724 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15725 sample_inputs_func=sample_inputs_max_unpool, 15726 skips=( 15727 # Gradients are tested in `variant_test_name=grad` below. 15728 # We skip tests here because there is non-determinism in backward 15729 # with gather, when there are writes into the same memory location, 15730 # and if there are several indices pointing to the same memory, 15731 # gradcheck is oblivious about that and cannot perturb them all at once 15732 # (see sample_inputs_max_unpool_grad to find out more). 15733 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), 15734 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), 15735 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', 15736 active_if=(not IS_MACOS)), 15737 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad', 15738 device_type='cpu'), 15739 )), 15740 OpInfo('nn.functional.max_unpool1d', 15741 variant_test_name='grad', 15742 aten_name='max_unpool1d', 15743 supports_autograd=True, 15744 supports_forward_ad=True, 15745 supports_fwgrad_bwgrad=True, 15746 supports_out=False, 15747 assert_jit_shape_analysis=False, 15748 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15749 sample_inputs_func=sample_inputs_max_unpool_grad), 15750 OpInfo('nn.functional.max_unpool2d', 15751 aten_name='max_unpool2d', 15752 supports_autograd=True, 15753 supports_forward_ad=True, 15754 supports_fwgrad_bwgrad=True, 15755 supports_out=False, 15756 assert_jit_shape_analysis=False, 15757 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15758 sample_inputs_func=sample_inputs_max_unpool, 15759 skips=( 15760 # Gradients are tested in `variant_test_name=grad` below. 15761 # We skip tests here because there is non-determinism in backward 15762 # with gather, when there are writes into the same memory location, 15763 # and if there are several indices pointing to the same memory, 15764 # gradcheck is oblivious about that and cannot perturb them all at once 15765 # (see sample_inputs_max_unpool_grad to find out more). 
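# A minimal sketch of the duplicate-index problem described above (illustrative only,
# not used by the tests): if two pooled values carry the same index they scatter into
# the same output slot, so gradcheck's one-element-at-a-time perturbation cannot
# attribute that slot's gradient to either input element individually:
#   x = torch.randn(1, 1, 2, dtype=torch.double, requires_grad=True)
#   idx = torch.tensor([[[0, 0]]])    # both values map to output position 0
#   y = torch.nn.functional.max_unpool1d(x, idx, kernel_size=2)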
15766 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', 15767 active_if=(not IS_MACOS)), 15768 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), 15769 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), 15770 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), 15771 )), 15772 OpInfo('nn.functional.max_unpool2d', 15773 variant_test_name='grad', 15774 aten_name='max_unpool2d', 15775 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15776 gradcheck_fast_mode=True, 15777 supports_forward_ad=True, 15778 supports_fwgrad_bwgrad=True, 15779 # Vmap is not happy with non-contiguous (channels_last) inputs 15780 check_batched_grad=False, 15781 supports_out=False, 15782 assert_jit_shape_analysis=False, 15783 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15784 sample_inputs_func=sample_inputs_max_unpool_grad), 15785 OpInfo('nn.functional.max_unpool3d', 15786 aten_name='max_unpool3d', 15787 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15788 gradcheck_fast_mode=True, 15789 supports_forward_ad=True, 15790 supports_fwgrad_bwgrad=True, 15791 supports_out=False, 15792 assert_jit_shape_analysis=False, 15793 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15794 sample_inputs_func=sample_inputs_max_unpool, 15795 skips=( 15796 # Gradients are tested in `variant_test_name=grad` below. 15797 # We skip tests here because there is non-determinism in backward 15798 # with gather, when there are writes into the same memory location, 15799 # and if there are several indices pointing to the same memory, 15800 # gradcheck is oblivious about that and cannot perturb them all at once 15801 # (see sample_inputs_max_unpool_grad to find out more). 
15802 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', 15803 active_if=(not IS_MACOS)), 15804 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), 15805 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), 15806 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), 15807 )), 15808 OpInfo('nn.functional.max_unpool3d', 15809 variant_test_name='grad', 15810 aten_name='max_unpool3d', 15811 supports_autograd=True, 15812 supports_forward_ad=True, 15813 supports_fwgrad_bwgrad=True, 15814 supports_out=False, 15815 assert_jit_shape_analysis=False, 15816 dtypes=floating_types_and(torch.float16, torch.bfloat16), 15817 sample_inputs_func=sample_inputs_max_unpool_grad), 15818 OpInfo('nn.functional.linear', 15819 aten_name='linear', 15820 supports_autograd=True, 15821 supports_gradgrad=True, 15822 sample_inputs_func=sample_inputs_linear, 15823 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 15824 dtypesIfROCM=floating_and_complex_types_and(torch.float16, torch.bfloat16), 15825 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 15826 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 15827 # linear calls mm under the hood which is nondeterministic on CUDA 15828 # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms 15829 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 15830 supports_forward_ad=True, 15831 supports_fwgrad_bwgrad=True, 15832 # See https://github.com/pytorch/pytorch/issues/66357 15833 check_batched_forward_grad=False, 15834 supports_expanded_weight=True, 15835 decorators=( 15836 # Strides are not the same! 
15837 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 15838 )), 15839 OpInfo('nn.functional.bilinear', 15840 aten_name='bilinear', 15841 supports_autograd=True, 15842 sample_inputs_func=sample_inputs_bilinear, 15843 dtypes=all_types_and(torch.float16, torch.bfloat16), 15844 dtypesIfCUDA=floating_types_and(torch.float16, 15845 *[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []), 15846 decorators=( 15847 DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}), 15848 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'), 15849 ), 15850 skips=( 15851 # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 15852 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), 15853 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bfloat16,)), 15854 ), 15855 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15856 gradcheck_fast_mode=True, 15857 supports_forward_ad=True, 15858 supports_fwgrad_bwgrad=True, 15859 supports_out=False), 15860 OpInfo('nn.functional.glu', 15861 aten_name='glu', 15862 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 15863 gradcheck_fast_mode=True, 15864 sample_inputs_func=sample_inputs_glu, 15865 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15866 supports_forward_ad=True, 15867 supports_fwgrad_bwgrad=True, 15868 supports_out=False), 15869 UnaryUfuncInfo( 15870 'nn.functional.elu', 15871 aten_backward_name='elu_backward', 15872 ref=lambda x, alpha=1.0, inplace=False: 15873 np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x) - 1)), 15874 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15875 supports_forward_ad=True, 15876 supports_fwgrad_bwgrad=True, 15877 supports_autograd=True, 15878 assert_autodiffed=False, 15879 supports_gradgrad=True, 15880 supports_out=False, 15881 sample_kwargs=lambda device, dtype, input: 15882 ({'alpha': 0.8}, {'alpha': 0.8}), 15883 inplace_variant=lambda x, alpha=1.0: 15884 torch.nn.functional.elu(x, alpha, inplace=True), 15885 decorators=[ 15886 DecorateInfo( 15887 toleranceOverride({ 15888 torch.float16: tol(atol=1e-03, rtol=1.2e-03), 15889 torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) 15890 }), 15891 'TestUnaryUfuncs', device_type='cuda', 15892 ), ], 15893 ), 15894 # Marked as a Unary function because it has some rather odd broadcasting semantics in its 15895 # second argument 15896 UnaryUfuncInfo( 15897 'nn.functional.prelu', 15898 aten_backward_name='_prelu_kernel_backward', 15899 ref=lambda x, weight: 15900 np.maximum(0., x) + np.minimum(0., x) * 15901 (weight if x.ndim == 1 else weight.reshape([weight.size if i == 1 else 1 for i in range(0, x.ndim)])), 15902 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15903 supports_forward_ad=True, 15904 supports_fwgrad_bwgrad=True, 15905 supports_autograd=True, 15906 assert_autodiffed=False, 15907 supports_gradgrad=True, 15908 supports_out=False, 15909 # test_reference_numerics only tests the case when the weight tensor is a scalar 15910 sample_kwargs=sample_kwargs_prelu_scalar_weight, 15911 error_inputs_func=error_inputs_prelu, 15912 sample_inputs_func=sample_inputs_prelu, 15913 reference_inputs_func=reference_inputs_prelu, 15914 decorators=[ 15915 # RuntimeError: Cannot insert a Tensor that requires grad as a constant. 
15916 # Consider making it a parameter or input, or detaching the gradient 15917 # https://github.com/pytorch/pytorch/issues/68752 15918 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), ], 15919 ), 15920 UnaryUfuncInfo( 15921 'nn.functional.celu', 15922 ref=lambda x, alpha=1.0, inplace=False: 15923 np.maximum(0., x) + np.minimum(0., alpha * (np.exp(x / alpha) - 1)), 15924 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15925 supports_forward_ad=True, 15926 supports_fwgrad_bwgrad=True, 15927 supports_autograd=True, 15928 assert_autodiffed=False, 15929 supports_gradgrad=True, 15930 supports_out=False, 15931 sample_kwargs=lambda device, dtype, input: 15932 ({'alpha': 0.8}, {'alpha': 0.8}), 15933 inplace_variant=lambda x, alpha=1.0: 15934 torch.nn.functional.celu(x, alpha, inplace=True), 15935 decorators=[ 15936 DecorateInfo( 15937 toleranceOverride({ 15938 torch.float16: tol(atol=1e-03, rtol=1.2e-03), 15939 torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) 15940 }), 15941 'TestUnaryUfuncs', device_type='cuda', 15942 ), ], 15943 ), 15944 UnaryUfuncInfo( 15945 'nn.functional.rrelu', 15946 aten_backward_name='rrelu_with_noise_backward', 15947 op=lambda input, *args, **kwargs: 15948 wrapper_set_seed(torch.nn.functional.rrelu, input, *args, **kwargs), 15949 inplace_variant=lambda input, *args, **kwargs: 15950 wrapper_set_seed(torch.nn.functional.rrelu, input, *args, inplace=True, **kwargs), 15951 dtypes=floating_types_and(torch.bfloat16), 15952 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 15953 gradcheck_wrapper=wrapper_set_seed, 15954 supports_forward_ad=True, 15955 supports_fwgrad_bwgrad=True, 15956 supports_out=False, 15957 sample_kwargs=lambda device, dtype, input: 15958 (dict(lower=0., upper=1., training=True), dict(lower=0., upper=1., training=True)), 15959 sample_inputs_func=sample_inputs_rrelu, 15960 error_inputs_func=error_inputs_rrelu, 15961 decorators=( 15962 DecorateInfo( 15963 toleranceOverride({ 15964 torch.float16: tol(atol=1e-03, rtol=1.2e-03), 15965 torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) 15966 }), 15967 'TestUnaryUfuncs', device_type='cuda', 15968 ),), 15969 skips=( 15970 # lambda impl 15971 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 15972 # lambda impl 15973 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 15974 # In-place operations do not play well with forward AD 15975 # https://github.com/pytorch/pytorch/issues/77447 15976 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 15977 'test_inplace_forward_mode_AD'), 15978 # The noise vector that's generated in these tests is not the same elementwise 15979 DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_batch_vs_slicing'), 15980 DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'), 15981 DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'), 15982 DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'), 15983 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), 15984 UnaryUfuncInfo( 15985 'nn.functional.selu', 15986 ref=lambda x, inplace=False: 15987 1.0507009873554804934193349852946 * ( 15988 np.maximum(0., x) + np.minimum(0., 1.6732632423543772848170429916717 * (np.exp(x) - 1)) 15989 ), 15990 dtypes=floating_types_and(torch.bfloat16, torch.float16), 15991 supports_forward_ad=True, 
# depends on 'elu' 15992 supports_fwgrad_bwgrad=True, 15993 supports_autograd=True, 15994 assert_autodiffed=False, 15995 supports_gradgrad=True, 15996 supports_out=False, 15997 inplace_variant=lambda x: torch.nn.functional.selu(x, inplace=True), 15998 decorators=[ 15999 DecorateInfo( 16000 toleranceOverride({ 16001 torch.float16: tol(atol=1e-2, rtol=1.8e-2), 16002 torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) 16003 }), 16004 'TestUnaryUfuncs', device_type='cuda', 16005 ), ], 16006 ), 16007 OpInfo( 16008 'torch._scaled_mm', 16009 sample_inputs_func=sample_inputs_scaled_mm, 16010 dtypes=empty_types(), 16011 dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), 16012 supports_out=True, 16013 supports_forward_ad=False, 16014 supports_autograd=False, 16015 decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')], 16016 skips=( 16017 # Sample inputs isn't really parametrized on dtype 16018 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', 16019 device_type='cuda'), 16020 # "mul_cuda" not implemented for float8_e4m3fn 16021 # https://github.com/pytorch/pytorch/issues/107256 16022 DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', 16023 dtypes=(torch.float8_e4m3fn,)), 16024 ) 16025 ), 16026 OpInfo( 16027 'torch.ops.aten._safe_softmax.default', 16028 dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), 16029 sample_inputs_func=sample_inputs_safe_softmax, 16030 assert_jit_shape_analysis=True, 16031 assert_autodiffed=True, 16032 supports_forward_ad=True, 16033 supports_fwgrad_bwgrad=True, 16034 supports_out=False, 16035 supports_cow_input_no_materialize_backward=False, 16036 decorators=[], 16037 skips=( 16038 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 16039 ), 16040 ), 16041 OpInfo( 16042 'nn.functional.scaled_dot_product_attention', 16043 op=lambda *args, **kwargs: 16044 wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs), 16045 sample_inputs_func=sample_inputs_scaled_dot_product_attention, 16046 dtypes=floating_types_and(torch.float16, torch.bfloat16), 16047 supports_out=False, 16048 supports_forward_ad=False, 16049 supports_fwgrad_bwgrad=True, 16050 check_batched_forward_grad=False, 16051 decorators=[DecorateInfo(toleranceOverride( 16052 {torch.float32: tol(atol=5e-05, rtol=5e-6)}), 'TestCommon',), ], 16053 skips=( 16054 # When attn mask is a composite tensor this fails backward by returning a none 16055 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', device_type='cuda'), 16056 # This is only failing on Linux Bionic 3.10 Cuda 11.6 16057 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', 16058 device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), 16059 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', 16060 dtypes=(torch.float32,)), 16061 # AssertionError: JIT Test does not execute any logic 16062 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 16063 # Forward works for dtype=float64 which is the math path 16064 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 16065 # Not implemented for Forward AD 16066 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', 16067 device_type='cpu'), 16068 # Not implemented for backward derivative 16069 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad', 16070 
device_type='cpu'), 16071 # CPU and CUDA have inconsistencies for intermediate outputs 16072 DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', 16073 device_type='cpu'), 16074 DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', 16075 device_type='cpu'), 16076 # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false 16077 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward', 16078 device_type='cpu'), 16079 # OpInfo was implemented with a lambda 16080 DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 16081 # TODO Need to understand what this is testing and why it doesn't work 16082 DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), 16083 DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'), 16084 # TODO skip this for now since we can't skip on runtime arch support 16085 DecorateInfo(unittest.skip('This is '), 'TestInductorOpInfo', 'test_comprehensive'), 16086 # skip for sm < 80 16087 DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', 16088 device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), 16089 # FIXME 16090 DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), 16091 'TestCompositeCompliance', 'test_cow_input', 16092 device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), 16093 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), 16094 DecorateInfo(unittest.skip('test_fake_crossref_backward_amp does not work with efficient attention on ROCM'), 16095 'TestFakeTensor', 'test_fake_crossref_backward_amp', 16096 device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), 16097 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), 16098 DecorateInfo(unittest.skip('test_fake_crossref_backward_no_amp does not work with efficient attention on ROCM'), 16099 'TestFakeTensor', 'test_fake_crossref_backward_no_amp', 16100 device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), 16101 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), 16102 # for element 1, was torch.Size([4, 4, 0]) but real shape was torch.Size([16, 3, 0]) 16103 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", 16104 dtypes=[torch.float16, torch.bfloat16, torch.float32], 16105 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16106 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", 16107 dtypes=[torch.float16, torch.bfloat16, torch.float32], 16108 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16109 # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) 16110 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", 16111 device_type="cuda", dtypes=[torch.float32], 16112 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),), 16113 ), 16114 OpInfo( 16115 'torch.ops.aten._flash_attention_forward', 16116 sample_inputs_func=sample_inputs_flash_attention_forward, 16117 dtypes=empty_types(), 16118 dtypesIfCUDA=custom_types(torch.float16) 16119 if not SM80OrLater 16120 else custom_types(torch.float16, torch.bfloat16), 16121 
supports_out=False, 16122 supports_autograd=True, 16123 supports_fwgrad_bwgrad=False, 16124 supports_forward_ad=False, 16125 check_batched_forward_grad=False, 16126 decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], 16127 skips=( 16128 # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) 16129 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", 16130 dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16131 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", 16132 dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16133 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", 16134 dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16135 # Checking the scalar value of the philox seed and offset 16137 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), 16138 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), 16139 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), 16140 # None Mismatch Tensor 16141 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), 16142 ) 16143 ), 16144 OpInfo( 16145 'torch.ops.aten._efficient_attention_forward', 16146 sample_inputs_func=sample_inputs_efficient_attention_forward, 16147 dtypes=empty_types(), 16148 dtypesIfCUDA=custom_types(torch.float16, torch.float32) 16149 if not SM80OrLater 16150 else custom_types(torch.float16, torch.float32, torch.bfloat16), 16151 supports_out=False, 16152 supports_autograd=True, 16153 supports_fwgrad_bwgrad=False, 16154 supports_forward_ad=False, 16155 check_batched_forward_grad=False, 16156 # TODO: Skip because it produces a CUDA illegal memory access for some reason 16157 skip_cow_input_backward=True, 16158 # FIXME: mask_type == 2 (LowerRight) 16159 decorators=[ 16160 skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), 16161 skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], 16162 skips=( 16163 # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) 16164 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", 16165 dtypes=[torch.float16, torch.bfloat16, torch.float32], 16166 active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16167 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", 16168 dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16169 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", 16170 dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), 16171 # Checking the scalar value of the philox seed and offset 16172 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), 16173 DecorateInfo(unittest.expectedFailure, 'TestCommon',
'test_noncontiguous_samples', device_type='cuda'), 16174 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', device_type='cuda'), 16175 # None Mismatch Tensor 16176 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward', device_type='cuda'), 16177 ) 16178 ), 16179 UnaryUfuncInfo( 16180 'nn.functional.silu', 16181 aten_backward_name='silu_backward', 16182 ref=lambda x, inplace=False: x / (1 + np.exp(-x)), 16183 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16184 supports_forward_ad=True, 16185 supports_autograd=True, 16186 supports_fwgrad_bwgrad=True, 16187 assert_autodiffed=True, 16188 supports_out=False, 16189 inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), 16190 decorators=[ 16191 DecorateInfo( 16192 toleranceOverride({ 16193 torch.float16: tol(atol=1e-3, rtol=1e-3), 16194 torch.bfloat16: tol(atol=1e-4, rtol=1e-4) 16195 }), 16196 'TestUnaryUfuncs', device_type='cuda', 16197 ), ], 16198 skips=( 16199 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 16200 dtypes=(torch.cfloat,), device_type='cpu'), 16201 ), 16202 autodiff_nonfusible_nodes=["aten::silu"], 16203 ), 16204 # TODO: combine this with the nn.functional.silu OpInfo when 16205 # complex autodiff for silu is supported or when 16206 # the forward bug is fixed 16207 # Note: silu errors when given inputs that require grad 16208 # but it doesn't support grad in their dtype 16209 # This is why the dtypes list above passes test_dtypes, 16210 # because it's getting lucky and failing in forward 16211 # because test_dtypes sets requires_grad to True 16212 # THIS IS A BUG 16213 UnaryUfuncInfo( 16214 'nn.functional.silu', 16215 variant_test_name='complex', 16216 ref=lambda x, inplace=False: 16217 x / (1 + np.exp(-x)), 16218 dtypes=complex_types(), 16219 dtypesIfCUDA=complex_types(), 16220 supports_forward_ad=False, 16221 supports_autograd=False, 16222 assert_autodiffed=False, 16223 supports_out=False, 16224 inplace_variant=lambda x: torch.nn.functional.silu(x, inplace=True), 16225 decorators=[ 16226 DecorateInfo( 16227 toleranceOverride({ 16228 torch.float16: tol(atol=1e-3, rtol=1e-3), 16229 torch.bfloat16: tol(atol=1e-4, rtol=1e-4) 16230 }), 16231 'TestUnaryUfuncs', device_type='cuda', 16232 ), ], 16233 skips=( 16234 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 16235 dtypes=(torch.cfloat,)), 16236 # FIXME: intentionally misreports dtypes 16237 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), 16238 # FIXME: numpy reference diverges: Comparing (nan+nanj) and (-0+0j) 16239 DecorateInfo(unittest.skip("Skipped!"), 16240 'TestUnaryUfuncs', 'test_reference_numerics_large', 16241 dtypes=(torch.complex64, torch.cdouble)), 16242 DecorateInfo(unittest.skip("Skipped!"), 16243 'TestUnaryUfuncs', 'test_reference_numerics_small', 16244 dtypes=(torch.complex64,)), 16245 DecorateInfo(unittest.skip("Skipped!"), 16246 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 16247 dtypes=(torch.complex64,)))), 16248 UnaryUfuncInfo( 16249 'nn.functional.hardsigmoid', 16250 aten_backward_name='hardsigmoid_backward', 16251 ref=reference_hardsigmoid, 16252 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16253 supports_autograd=True, 16254 assert_autodiffed=False, 16255 supports_gradgrad=False, 16256 supports_forward_ad=True, 16257 supports_out=False, 16258 inplace_variant=partial(torch.nn.functional.hardsigmoid, inplace=True), 16259 decorators=[ 16260 
DecorateInfo( 16261 toleranceOverride({torch.float16: tol(atol=1e-04, rtol=0.001)}), 'TestUnaryUfuncs', device_type='cuda',), ], 16262 skips=[ 16263 # still want to test that first derivative works though second derivative isn't supported 16264 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', "test_inplace_gradgrad"), 16265 # produces 0 instead of nan on ROCM 16266 DecorateInfo(unittest.expectedFailure, 16267 'TestUnaryUfuncs', "test_reference_numerics_extremal", 16268 device_type='cuda', 16269 active_if=(TEST_WITH_ROCM)), ] 16270 ), 16271 UnaryUfuncInfo( 16272 'nn.functional.logsigmoid', 16273 aten_name="log_sigmoid", 16274 aten_backward_name='log_sigmoid_backward', 16275 ref=reference_logsigmoid, 16276 dtypes=floating_types_and(torch.half, torch.bfloat16), 16277 supports_autograd=True, 16278 assert_autodiffed=False, 16279 supports_forward_ad=True, 16280 supports_fwgrad_bwgrad=True, 16281 supports_gradgrad=True, 16282 # autodiff_nonfusible_nodes=["aten::log_sigmoid"], 16283 decorators=[ 16284 DecorateInfo( 16285 precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), 16286 'TestUnaryUfuncs', 'test_reference_numerics_small'), 16287 DecorateInfo( 16288 precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), 16289 'TestUnaryUfuncs', 'test_reference_numerics_large'), 16290 DecorateInfo( 16291 precisionOverride({torch.float16: 1e-2, torch.bfloat16: 5e-3}), 16292 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), 16293 ], 16294 skips=( 16295 # Resized a non-empty tensor but did not warn about it. 16296 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cpu'), 16297 ), 16298 ), 16299 UnaryUfuncInfo( 16300 'nn.functional.mish', 16301 aten_backward_name='mish_backward', 16302 ref=lambda x: x * np.tanh(reference_softplus(x)), 16303 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16304 supports_forward_ad=True, 16305 supports_fwgrad_bwgrad=True, 16306 supports_autograd=True, 16307 assert_autodiffed=False, 16308 supports_gradgrad=True, 16309 supports_out=False, 16310 inplace_variant=partial(torch.nn.functional.mish, inplace=True), 16311 decorators=[ 16312 DecorateInfo( 16313 toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 'TestUnaryUfuncs',), ], 16314 ), 16315 UnaryUfuncInfo( 16316 'nn.functional.softsign', 16317 ref=lambda x: x / (np.abs(x) + 1), 16318 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 16319 dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), 16320 supports_forward_ad=True, 16321 supports_fwgrad_bwgrad=True, 16322 supports_autograd=True, 16323 assert_autodiffed=False, 16324 supports_gradgrad=True, 16325 supports_out=False, 16326 decorators=[ 16327 DecorateInfo( 16328 toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1.3e-04)}), 'TestUnaryUfuncs',), ], 16329 skips=( 16330 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 16331 dtypes=(torch.int, torch.int8)),), 16332 ), 16333 UnaryUfuncInfo( 16334 'nn.functional.tanhshrink', 16335 ref=lambda x: x - np.tanh(x), 16336 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 16337 supports_forward_ad=True, 16338 supports_fwgrad_bwgrad=True, 16339 supports_autograd=True, 16340 assert_autodiffed=False, 16341 supports_gradgrad=True, 16342 supports_out=False, 16343 decorators=[ 16344 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 16345 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 16346 
DecorateInfo( 16347 toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 'TestUnaryUfuncs',), 16348 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=6e-04, rtol=1e-05), 16349 torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02)}), 16350 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), 16351 ], 16352 skips=( 16353 # in each case, pytorch will produce a nan while numpy will not 16354 DecorateInfo(unittest.skip("Fails on some jobs works on others!"), 16355 'TestUnaryUfuncs', "test_reference_numerics_large", 16356 dtypes=(torch.complex64, torch.complex128), active_if=(IS_MACOS)), 16357 DecorateInfo(unittest.skip("Fails on some jobs works on others!"), 16358 'TestUnaryUfuncs', "test_reference_numerics_extremal", 16359 dtypes=(torch.complex64, torch.complex128), device_type='cpu', 16360 active_if=(IS_MACOS or IS_WINDOWS)), 16361 ), 16362 # tan(j * pi/2 * odd_number) is nan which also make tanhshrink nan. 16363 reference_numerics_filter=NumericsFilter( 16364 condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) 16365 if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), 16366 safe_val=0) 16367 ), 16368 UnaryUfuncInfo( 16369 'nn.functional.threshold', 16370 ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype), 16371 dtypes=all_types_and(torch.half, torch.bfloat16), 16372 inplace_variant=lambda x, threshold, value: 16373 torch.nn.functional.threshold(x, threshold, value, inplace=True), 16374 supports_forward_ad=True, 16375 supports_fwgrad_bwgrad=True, 16376 assert_autodiffed=False, 16377 supports_gradgrad=True, 16378 supports_out=False, 16379 sample_kwargs=lambda device, dtype, input: ({'threshold': float.fromhex('0x1.3ap-3'), 16380 'value': -9}, 16381 {'threshold': float.fromhex('0x1.3ap-3'), 16382 'value': -9}), 16383 # TODO(whc) should not need sample_inputs_func, but without it 16384 # kwargs aren't being hooked up properly 16385 sample_inputs_func=sample_inputs_threshold, 16386 ), 16387 OpInfo( 16388 "nn.functional.triplet_margin_loss", 16389 sample_inputs_func=sample_inputs_triplet_margin_loss, 16390 error_inputs_func=error_inputs_triplet_margin_loss, 16391 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 16392 supports_out=False, 16393 supports_forward_ad=True, 16394 supports_fwgrad_bwgrad=True, 16395 ), 16396 OpInfo( 16397 "nn.functional.triplet_margin_with_distance_loss", 16398 sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True), 16399 error_inputs_func=error_inputs_triplet_margin_loss, 16400 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 16401 supports_out=False, 16402 supports_forward_ad=True, 16403 supports_fwgrad_bwgrad=True, 16404 skips=( 16405 # This test cannot handle a callable passed to `distance_function`. If we would use 16406 # `distance_function=None`, the test would pass fine. 
16407 DecorateInfo( 16408 unittest.expectedFailure, 16409 "TestJit", 16410 "test_variant_consistency_jit", 16411 ), 16412 DecorateInfo( 16413 unittest.expectedFailure, 16414 "TestNormalizeOperators", 16415 "test_normalize_operator_exhaustive", 16416 ), 16417 ), 16418 ), 16419 BinaryUfuncInfo('nextafter', 16420 dtypes=floating_types_and(torch.bfloat16, torch.half), 16421 dtypesIfCUDA=floating_types_and(torch.bfloat16), 16422 supports_autograd=False, 16423 supports_rhs_python_scalar=False), 16424 OpInfo( 16425 "to", 16426 op=lambda x, *args, **kwargs: x.to(*args, **kwargs), 16427 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 16428 supports_forward_ad=True, 16429 supports_fwgrad_bwgrad=True, 16430 supports_out=False, 16431 sample_inputs_func=sample_inputs_to, 16432 skips=( 16433 # RuntimeError: undefined value cpu 16434 DecorateInfo( 16435 unittest.skip("Skipped!"), 16436 "TestJit", 16437 "test_variant_consistency_jit", 16438 device_type="cpu", 16439 ), 16440 # NotImplementedError: Cannot copy out of meta tensor; no data! 16441 DecorateInfo( 16442 unittest.skip("Skipped!"), 16443 "TestMeta", 16444 "test_meta_outplace", 16445 ), 16446 # https://github.com/pytorch/pytorch/issues/84335 16447 DecorateInfo( 16448 unittest.skip("Skipped!"), 16449 "TestProxyTensorOpInfo", 16450 "test_make_fx_symbolic_exhaustive", 16451 ), 16452 DecorateInfo( 16453 unittest.skip("Skipped!"), 16454 "TestNormalizeOperators", 16455 "test_normalize_operator_exhaustive", 16456 ), 16457 ), 16458 ), 16459 OpInfo('topk', 16460 dtypes=all_types_and(torch.bfloat16, torch.float16), 16461 supports_forward_ad=True, 16462 supports_fwgrad_bwgrad=True, 16463 assert_jit_shape_analysis=True, 16464 sample_inputs_func=sample_inputs_topk), 16465 # Multiple variants for batch_norm to test with and without cuDNN disabled 16466 # See https://github.com/pytorch/pytorch/pull/63218#discussion_r688549391 for more details 16467 OpInfo('nn.functional.batch_norm', 16468 aten_name='batch_norm', 16469 dtypes=floating_types_and(torch.float16, torch.bfloat16), 16470 supports_out=False, 16471 supports_forward_ad=True, 16472 supports_fwgrad_bwgrad=True, 16473 assert_jit_shape_analysis=True, 16474 allow_cow_input_materialize_forward=[1, 2], 16475 allow_cow_input_materialize_backward=[1, 2], 16476 sample_inputs_func=sample_inputs_batch_norm, 16477 skips=( 16478 # see https://github.com/pytorch/pytorch/issues/71286 16479 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), 16480 DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', 16481 device_type='cpu', dtypes=(torch.bfloat16, torch.float16)), 16482 DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-05, rtol=1e-05)}), 16483 'TestCompositeCompliance', 'test_forward_ad', device_type="cpu"), 16484 )), 16485 # This variant tests batch_norm with cuDNN disabled only on CUDA devices 16486 OpInfo('nn.functional.batch_norm', 16487 variant_test_name='without_cudnn', 16488 aten_name='batch_norm', 16489 dtypes=empty_types(), 16490 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 16491 supports_out=False, 16492 supports_forward_ad=True, 16493 supports_fwgrad_bwgrad=True, 16494 allow_cow_input_materialize_forward=[1, 2], 16495 allow_cow_input_materialize_backward=[1, 2], 16496 decorators=[onlyCUDA, disablecuDNN], 16497 skips=( 16498 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-04)}), 16499 'TestJit', 'test_variant_consistency_jit'), 16500 ), 16501 
sample_inputs_func=sample_inputs_batch_norm), 16502 OpInfo( 16503 "nn.functional.binary_cross_entropy", 16504 aten_backward_name='binary_cross_entropy_backward', 16505 sample_inputs_func=sample_inputs_binary_cross_entropy, 16506 dtypes=floating_types_and(torch.float16, torch.bfloat16), 16507 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 16508 supports_out=False, 16509 gradcheck_fast_mode=False, 16510 supports_autograd=True, 16511 supports_forward_ad=True, 16512 supports_fwgrad_bwgrad=True, 16513 decorators=( 16514 # RuntimeError: expected int at position 0, but got: Tensor 16515 DecorateInfo( 16516 unittest.skip("Skipped!"), 16517 "TestCudaFuserOpInfo", 16518 ), 16519 # RuntimeError: expected int at position 0, but got: Tensor 16520 DecorateInfo( 16521 unittest.skip("Skipped!"), 16522 "TestNNCOpInfo", 16523 "test_nnc_correctness", 16524 ), 16525 # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120783 16526 DecorateInfo( 16527 unittest.skip("Skipped!"), 16528 "TestCompositeCompliance", 16529 "test_cow_input", 16530 device_type='cuda', 16531 ), 16532 DecorateInfo( 16533 toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}), 16534 "TestJit", 16535 "test_variant_consistency_jit", 16536 ), 16537 # RuntimeError: output with shape [] doesn't match the broadcast shape [5, 5] 16538 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), 16539 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), 16540 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), 16541 ), 16542 skips=( 16543 # RuntimeError: expected int at position 0, but got: Tensor 16544 DecorateInfo( 16545 unittest.expectedFailure, 16546 "TestJit", 16547 "test_variant_consistency_jit", 16548 ), 16549 ), 16550 ), 16551 # We have to add two OpInfo entries for `igamma` and `igammac`. The first is the 16552 # standard entry; the second is to run gradcheck tests on the second argument. 16553 BinaryUfuncInfo('igamma', 16554 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16555 aliases=('torch.special.gammainc',), 16556 dtypesIfCUDA=floating_types(), 16557 # TODO: FIXME 16558 supports_rhs_python_scalar=False, 16559 supports_autograd=False, 16560 skips=( 16561 # FIXME: incorrectly tries to pass a rhs scalar 16562 DecorateInfo(unittest.expectedFailure, 'TestJit', 16563 'test_jit_alias_remapping'), 16564 )), 16565 # TODO: FIXME, ideally by implementing grad for both inputs 16566 # BinaryUfuncInfo('igamma', 16567 # variant_test_name='grad_other', 16568 # # Since autograd formula is implemented only for other and 16569 # # gradcheck test verifies the formula for input in SampleInput, 16570 # # we permute the arguments.
16571 # op=lambda self, other, **kwargs: torch.igamma(other, self, **kwargs), 16572 # inplace_variant=None, 16573 # method_variant=None, 16574 # supports_rhs_python_scalar=False, 16575 # rhs_make_tensor_kwargs=dict(requires_grad=False), 16576 # dtypes=floating_types_and(torch.bfloat16, torch.float16), 16577 # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), 16578 # dtypesIfCUDA=floating_types(), 16579 # backward_dtypesIfCUDA=floating_types(), 16580 # supports_inplace_autograd=False, 16581 # skips=( 16582 # # Derivative wrt first tensor not implemented 16583 # DecorateInfo(unittest.expectedFailure, "TestCommon", 16584 # "test_floating_inputs_are_differentiable"), 16585 # # test does not work with passing lambda for op 16586 # # AssertionError: False is not true : Tensors failed to compare as equal! 16587 # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 16588 # # test fails as we permute the arguments for the function variant 16589 # # but not for the inplace or method variants. 16590 # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 16591 # # TypeError: igamma(): argument 'input' (position 1) must be Tensor, not float 16592 # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), 16593 # )), 16594 BinaryUfuncInfo('igammac', 16595 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16596 aliases=('torch.special.gammaincc',), 16597 dtypesIfCUDA=floating_types(), 16598 supports_autograd=False, 16599 supports_rhs_python_scalar=False, 16600 skips=( 16601 # FIXME: incorrectly tries to pass a rhs scalar 16602 DecorateInfo(unittest.expectedFailure, 'TestJit', 16603 'test_jit_alias_remapping'), 16604 )), 16605 # TODO: FIXME, ideally by implementing grad for both inputs 16606 # BinaryUfuncInfo('igammac', 16607 # variant_test_name='grad_other', 16608 # # Since autograd formula is implemented only for other and 16609 # # gradcheck test verifies the formula for input in SampleInput, 16610 # # we permute the arguments 16611 # op=lambda self, other, **kwargs: torch.igammac(other, self, **kwargs), 16612 # inplace_variant=None, 16613 # method_variant=None, 16614 # supports_rhs_python_scalar=False, 16615 # rhs_make_tensor_kwargs=dict(requires_grad=False), 16616 # dtypes=floating_types_and(torch.bfloat16, torch.float16), 16617 # backward_dtypesIfCPU=floating_types_and(torch.bfloat16), 16618 # dtypesIfCUDA=floating_types(), 16619 # backward_dtypesIfCUDA=floating_types(), 16620 # supports_inplace_autograd=False, 16621 # decorators=[ 16622 # # Derivative wrt first tensor not implemented 16623 # DecorateInfo(unittest.expectedFailure, "TestCommon", 16624 # "test_floating_inputs_are_differentiable"), 16625 # ], 16626 # skips=( 16627 # # test does not work with passing lambda for op 16628 # # AssertionError: False is not true : Tensors failed to compare as equal! 16629 # DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 16630 # # test fails as we permute the arguments for the function variant 16631 # # but not for the inplace or method variants.
16632 # DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 16633 # # TypeError: igammac(): argument 'input' (position 1) must be Tensor, not float 16634 # DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs'), 16635 # )), 16636 UnaryUfuncInfo('nn.functional.softshrink', 16637 aten_name="softshrink", 16638 aten_backward_name='softshrink_backward', 16639 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16640 supports_forward_ad=True, 16641 supports_fwgrad_bwgrad=True, 16642 assert_autodiffed=False, 16643 sample_inputs_func=sample_inputs_softshrink, 16644 error_inputs_func=error_inputs_softshrink), 16645 UnaryUfuncInfo('nn.functional.hardshrink', 16646 aten_name="hardshrink", 16647 aten_backward_name='hardshrink_backward', 16648 dtypes=floating_types_and(torch.bfloat16, torch.float16), 16649 assert_autodiffed=True, 16650 sample_inputs_func=sample_inputs_hardshrink, 16651 supports_forward_ad=True, 16652 supports_fwgrad_bwgrad=True, 16653 autodiff_nonfusible_nodes=["aten::hardshrink"]), 16654 UnaryUfuncInfo('nn.functional.hardtanh', 16655 aten_name="hardtanh", 16656 aten_backward_name='hardtanh_backward', 16657 dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.half, torch.bfloat16), 16658 backward_dtypes=all_types_and(torch.half, torch.bfloat16), 16659 backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 16660 assert_autodiffed=True, 16661 sample_inputs_func=sample_inputs_hardtanh, 16662 error_inputs_func=error_inputs_hardtanh, 16663 supports_out=False, 16664 supports_forward_ad=True, 16665 supports_fwgrad_bwgrad=True, 16666 autodiff_nonfusible_nodes=["aten::hardtanh"]), 16667 OpInfo('nn.functional.gelu', 16668 aten_name="gelu", 16669 aten_backward_name='gelu_backward', 16670 ref=reference_gelu if TEST_SCIPY else None, 16671 error_inputs_func=error_inputs_gelu, 16672 supports_autograd=True, 16673 assert_autodiffed=True, 16674 sample_inputs_func=sample_inputs_gelu, 16675 dtypes=floating_types_and(torch.bfloat16, torch.half), 16676 supports_gradgrad=True, 16677 supports_forward_ad=True, 16678 supports_fwgrad_bwgrad=True, 16679 autodiff_nonfusible_nodes=["aten::gelu"], 16680 skips=( 16681 # AssertionError: Tensor-likes are not close! 
16682 # May not replicate in CI 16683 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), 16684 DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), 16685 )), 16686 UnaryUfuncInfo('nn.functional.relu6', 16687 aten_name="relu6", 16688 dtypes=all_types_and(torch.half, torch.bfloat16), 16689 backward_dtypes=floating_types_and(torch.half, torch.bfloat16), 16690 assert_autodiffed=True, 16691 supports_out=False, 16692 supports_forward_ad=True, 16693 supports_fwgrad_bwgrad=True, 16694 autodiff_nonfusible_nodes=["aten::relu6"]), 16695 OpInfo('mm', 16696 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 16697 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 16698 assert_autodiffed=True, 16699 supports_forward_ad=True, 16700 supports_fwgrad_bwgrad=True, 16701 sample_inputs_func=sample_inputs_mm, 16702 skips=( 16703 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 16704 DecorateInfo( 16705 unittest.skip("Skipped!"), 16706 'TestSchemaCheckModeOpInfo', 16707 'test_schema_correctness', 16708 dtypes=(torch.complex64, torch.complex128)), 16709 )), 16710 OpInfo('mode', 16711 op=torch.mode, 16712 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 16713 supports_forward_ad=True, 16714 supports_fwgrad_bwgrad=True, 16715 skips=( 16716 # Resized a non-empty tensor but did not warn about it 16717 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 16718 # FIXME: 16719 # Expected 2114 but got 1123. 16720 # Absolute difference: 991 (up to 0.001 allowed) 16721 # Relative difference: 0.46877956480605487 (up to 0.001 allowed) 16722 DecorateInfo( 16723 unittest.skip("Skipped!"), 16724 "TestCommon", 16725 "test_compare_cpu", 16726 dtypes=(torch.float32,), 16727 device_type="cuda", 16728 ), 16729 ), 16730 sample_inputs_func=sample_inputs_mode,), 16731 make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', 16732 domain=(1, None), 16733 skips=skips_mvlgamma(), 16734 sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), 16735 make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', 16736 domain=(2, None), 16737 skips=skips_mvlgamma(), 16738 sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), 16739 make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', 16740 domain=(3, None), 16741 skips=skips_mvlgamma(), 16742 sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), 16743 BinaryUfuncInfo('ne', 16744 ref=np.not_equal, 16745 aliases=('not_equal',), 16746 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 16747 always_returns_bool=True, 16748 supports_autograd=False, 16749 skips=( 16750 )), 16751 OpInfo('narrow', 16752 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 16753 supports_out=False, 16754 supports_forward_ad=True, 16755 supports_fwgrad_bwgrad=True, 16756 sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), 16757 reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), 16758 error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), 16759 skips=( 16760 # Use of .item() 16761 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), 16762 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), 16763 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), 16764 
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 16765 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 16766 )), 16767 OpInfo('narrow_copy', 16768 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 16769 supports_out=True, 16770 supports_forward_ad=False, 16771 supports_fwgrad_bwgrad=False, 16772 supports_autograd=False, 16773 # https://github.com/pytorch/pytorch/issues/86931 16774 sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), 16775 reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), 16776 error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), 16777 skips=( 16778 # https://github.com/pytorch/pytorch/issues/84577 16779 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 16780 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 16781 # Could not run 'aten::narrow_copy.out' with arguments from the 'CUDA' backend 16782 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', 16783 device_type='cuda'), 16784 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', 16785 device_type='cuda'), 16786 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace', 16787 device_type='cuda'), 16788 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), 16789 )), 16790 OpInfo('view_copy', 16791 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 16792 ref=lambda x, newshape: np.reshape(x, newshape).copy(), 16793 supports_out=True, 16794 supports_forward_ad=True, 16795 supports_fwgrad_bwgrad=True, 16796 supports_autograd=True, 16797 sample_inputs_func=sample_inputs_view_reshape, 16798 error_inputs_func=error_inputs_view_reshape, 16799 skips=( 16800 # RuntimeError: view size is not compatible with input tensor's size and stride 16801 # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. 16802 DecorateInfo( 16803 unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides" 16804 ), 16805 )), 16806 UnaryUfuncInfo('neg', 16807 aliases=('negative', ), 16808 ref=np.negative, 16809 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), 16810 error_inputs_func=error_inputs_neg, 16811 supports_forward_ad=True, 16812 supports_fwgrad_bwgrad=True, 16813 supports_sparse=True, 16814 supports_sparse_csr=True, 16815 supports_sparse_csc=True, 16816 supports_sparse_bsr=True, 16817 supports_sparse_bsc=True, 16818 assert_autodiffed=True), 16819 OpInfo('dist', 16820 op=torch.dist, 16821 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 16822 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 16823 gradcheck_fast_mode=True, 16824 supports_out=False, 16825 supports_forward_ad=True, 16826 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: 16827 # Could not allocate memory to change Tensor SizesAndStrides! 
16828 check_batched_forward_grad=False, 16829 supports_fwgrad_bwgrad=True, 16830 sample_inputs_func=sample_inputs_dist), 16831 OpInfo('outer', 16832 op=torch.outer, 16833 aliases=('ger', ), 16834 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 16835 supports_forward_ad=True, 16836 supports_fwgrad_bwgrad=True, 16837 # See https://github.com/pytorch/pytorch/pull/78358 16838 check_batched_forward_grad=False, 16839 sample_inputs_func=sample_inputs_outer,), 16840 OpInfo('ormqr', 16841 op=torch.ormqr, 16842 dtypes=floating_and_complex_types(), 16843 # https://github.com/pytorch/pytorch/issues/80411 16844 gradcheck_fast_mode=True, 16845 supports_forward_ad=False, 16846 supports_fwgrad_bwgrad=False, 16847 sample_inputs_func=sample_inputs_ormqr, 16848 error_inputs_func=error_inputs_ormqr, 16849 decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack], 16850 skips=( 16851 # Strides are not the same! 16852 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 16853 )), 16854 OpInfo('permute', 16855 ref=np.transpose, 16856 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 16857 supports_out=False, 16858 assert_autodiffed=True, 16859 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 16860 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 16861 assert_jit_shape_analysis=True, 16862 supports_forward_ad=True, 16863 supports_fwgrad_bwgrad=True, 16864 supports_varargs=True, 16865 sample_inputs_func=sample_inputs_permute, 16866 reference_inputs_func=reference_inputs_permute), 16867 BinaryUfuncInfo('pow', 16868 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 16869 dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), 16870 ref=np.power, 16871 # Due to AVX2 currently not being fully supported for Float16, log_vml_cpu can't be enabled 16872 # for Float16, causing this test to fail. pow's autograd for Float16 is thus currently 16873 # unsupported on CPU. 
16874 backward_dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 16875 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), 16876 # https://github.com/pytorch/pytorch/issues/80411 16877 gradcheck_fast_mode=True, 16878 supports_inplace_autograd=False, 16879 supports_forward_ad=True, 16880 supports_fwgrad_bwgrad=True, 16881 assert_autodiffed=True, 16882 supports_one_python_scalar=True, 16883 # Integer types do not support negative exponents 16884 rhs_make_tensor_kwargs=dict(low=0), 16885 # Raising negative real numbers to fractional powers is not supported 16886 lhs_make_tensor_kwargs=dict(low=0), 16887 decorators=( 16888 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), 16889 'TestBinaryUfuncs', 'test_reference_numerics'), 16890 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), 16891 torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), 16892 'TestBinaryUfuncs', 'test_scalar_support'), 16893 ), 16894 skips=( 16895 # Skipping integers because they are being raised to negative powers, causing an error 16896 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_small_values', 16897 dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), 16898 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_reference_numerics_large_values', 16899 dtypes=[torch.int16, torch.int32, torch.int64]), 16900 # FIXME Complex values error with: Greatest absolute difference: nan at index 16901 # Ref: https://github.com/pytorch/pytorch/issues/76853 16902 # For `chalf`, reference computation in `numpy` is computed in `cfloat`. 16903 # Output of `chalf` saturates to `inf` quicker than the reference due to its small range, 16904 # which leads to failure of this test.
16905 DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick', 16906 dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), 16907 # FIXME: 16908 # Mismatched elements: 1 / 500 (0.2%) 16909 # Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed) 16910 # Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed) 16911 DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive', 16912 dtypes=(torch.complex32,)), 16913 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing', 16914 dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM), 16915 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing', 16916 dtypes=(torch.complex32,)), 16917 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_non_contig', 16918 dtypes=(torch.complex32,)), 16919 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics', 16920 dtypes=(torch.complex32,)), 16921 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', 16922 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 16923 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', 16924 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 16925 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', 16926 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 16927 )), 16928 BinaryUfuncInfo('float_power', 16929 ref=np.float_power, 16930 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 16931 promotes_int_to_float=True, 16932 # https://github.com/pytorch/pytorch/issues/80411 16933 gradcheck_fast_mode=True, 16934 supports_forward_ad=True, 16935 supports_fwgrad_bwgrad=True, 16936 supports_one_python_scalar=True, 16937 # Integer types do not support negative exponents 16938 rhs_make_tensor_kwargs=dict(low=0), 16939 # Raising negative real numbers to fractional powers is not supported 16940 lhs_make_tensor_kwargs=dict(low=0), 16941 decorators=( 16942 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), 16943 torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), 16944 'TestBinaryUfuncs', 'test_scalar_support'), 16945 ), 16946 skips=( 16947 # FIXME 16948 # AssertionError: Object comparison failed: torch.float64 != torch.float32 16949 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 16950 # -3.43399e+38 is outside the range of representable values of type 'float' 16951 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 16952 # Complex values error with: Greatest absolute difference: nan at index 16953 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_small_values', 16954 dtypes=[torch.complex64, torch.complex128]), 16955 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_large_values', 16956 dtypes=[torch.complex64, torch.complex128]), 16957 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', 16958 dtypes=[torch.complex64, torch.complex128]), 16959 # Inplace always promotes to double and thus other floating dtypes are not supported 16960 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', 16961 dtypes=[torch.bfloat16, torch.float16, torch.float32]), 16962 )), 16963 OpInfo('qr', 16964 op=torch.qr, 16965
dtypes=floating_and_complex_types(), 16966 sample_inputs_func=sample_inputs_linalg_qr_geqrf, 16967 supports_forward_ad=True, 16968 supports_fwgrad_bwgrad=True, 16969 # In-place ops 16970 check_batched_gradgrad=False, 16971 decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack]), 16972 UnaryUfuncInfo('rad2deg', 16973 ref=np.degrees, 16974 decorators=(precisionOverride({torch.bfloat16: 7e-1, 16975 torch.float16: 7e-1}),), 16976 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 16977 supports_forward_ad=True, 16978 supports_fwgrad_bwgrad=True, 16979 supports_sparse=True, 16980 supports_sparse_csr=True, 16981 supports_sparse_csc=True, 16982 supports_sparse_bsr=True, 16983 supports_sparse_bsc=True, 16984 promotes_int_to_float=True), 16985 UnaryUfuncInfo('real', 16986 ref=np.real, 16987 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 16988 supports_out=False, 16989 supports_forward_ad=True, 16990 supports_fwgrad_bwgrad=True, 16991 # See https://github.com/pytorch/pytorch/issues/66357 16992 check_batched_forward_grad=False, 16993 skips=( 16994 # Skip since real and imag don't have out variants. 16995 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_out_arg_all_dtypes'), 16996 )), 16997 OpInfo( 16998 "roll", 16999 ref=np.roll, 17000 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 17001 error_inputs_func=error_inputs_roll, 17002 supports_out=False, 17003 supports_forward_ad=True, 17004 supports_fwgrad_bwgrad=True, 17005 sample_inputs_func=sample_inputs_roll, 17006 decorators=(onlyNativeDeviceTypes,), 17007 ), 17008 OpInfo( 17009 "rot90", 17010 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), 17011 error_inputs_func=error_inputs_rot90, 17012 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 17013 gradcheck_fast_mode=True, 17014 supports_out=False, 17015 supports_forward_ad=True, 17016 supports_fwgrad_bwgrad=True, 17017 sample_inputs_func=sample_inputs_rot90, 17018 ), 17019 # To test reference numerics against multiple values of argument `decimals`, 17020 # we make multiple OpInfo entries with each entry corresponding to different value of decimals. 
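# Rough illustration of the pattern described above (a hedged sketch, not an entry in this
# database; the concrete tensors and printed values are made up for exposition):
# `sample_kwargs` returns a pair of kwargs dicts, the first forwarded to the torch operator
# and the second to the NumPy reference, so the `decimals_3` variant below effectively compares
#     torch.round(torch.tensor([1.23456, 2.71828]), decimals=3)   # ~ tensor([1.2350, 2.7180])
# against
#     np.round(np.array([1.23456, 2.71828]), 3)                   # ~ array([1.235, 2.718])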
17021 UnaryUfuncInfo('round', 17022 ref=np.round, 17023 aliases=('special.round',), 17024 dtypes=all_types_and(torch.half, torch.bfloat16), 17025 supports_forward_ad=True, 17026 supports_fwgrad_bwgrad=True, 17027 skips=( 17028 DecorateInfo(unittest.expectedFailure, 17029 'TestNNCOpInfo', 17030 'test_nnc_correctness', 17031 dtypes=tuple(t for t in integral_types() if t != torch.uint8)), 17032 DecorateInfo(unittest.skip("Skipped!"), 17033 'TestNNCOpInfo', 17034 'test_nnc_correctness', 17035 dtypes=(torch.bfloat16,)), 17036 ), 17037 supports_sparse=True, 17038 supports_sparse_csr=True, 17039 supports_sparse_csc=True, 17040 supports_sparse_bsr=True, 17041 supports_sparse_bsc=True, 17042 assert_autodiffed=True, 17043 ), 17044 UnaryUfuncInfo('round', 17045 ref=np.round, 17046 variant_test_name='decimals_0', 17047 aliases=('special.round',), 17048 dtypes=floating_types_and(torch.half, torch.bfloat16), 17049 sample_kwargs=lambda device, dtype, input: ({'decimals': 0}, {'decimals': 0}), 17050 sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 0}), 17051 supports_forward_ad=True, 17052 supports_fwgrad_bwgrad=True, 17053 assert_autodiffed=False, 17054 supports_sparse_csr=False), 17055 UnaryUfuncInfo('round', 17056 ref=np.round, 17057 variant_test_name='decimals_3', 17058 aliases=('special.round',), 17059 dtypes=floating_types_and(torch.bfloat16), 17060 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 17061 sample_kwargs=lambda device, dtype, input: ({'decimals': 3}, {'decimals': 3}), 17062 sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': 3}), 17063 skips=( 17064 # test_ops already tested for this overload with `decimals_0` opinfo entry 17065 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), 17066 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), 17067 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), 17068 DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), 17069 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), 17070 DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), 17071 "TestUnaryUfuncs", "test_reference_numerics_extremal", 17072 device_type="cuda"), 17073 DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), 17074 "TestUnaryUfuncs", "test_reference_numerics_normal", 17075 device_type="cuda"), 17076 ), 17077 supports_forward_ad=True, 17078 supports_fwgrad_bwgrad=True, 17079 assert_autodiffed=False, 17080 supports_sparse_csr=False), 17081 UnaryUfuncInfo('round', 17082 ref=np.round, 17083 variant_test_name='decimals_neg_3', 17084 aliases=('special.round',), 17085 dtypes=floating_types_and(torch.bfloat16), 17086 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 17087 sample_kwargs=lambda device, dtype, input: ({'decimals': -3}, {'decimals': -3}), 17088 sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'decimals': -3}), 17089 skips=( 17090 # test_ops already tested for this overload with `decimals_0` opinfo entry 17091 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), 17092 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), 17093 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), 17094 DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), 17095 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits'), 17096 ), 17097 supports_forward_ad=True, 17098 supports_fwgrad_bwgrad=True, 17099 assert_autodiffed=False, 17100 supports_sparse_csr=False), 17101 UnaryUfuncInfo('sin', 17102 ref=np.sin, 
17103 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17104 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17105 assert_autodiffed=True, 17106 handles_large_floats=False, 17107 supports_sparse=True, 17108 supports_sparse_csr=True, 17109 supports_sparse_csc=True, 17110 supports_sparse_bsr=True, 17111 supports_sparse_bsc=True, 17112 supports_forward_ad=True, 17113 supports_fwgrad_bwgrad=True, 17114 promotes_int_to_float=True, 17115 skips=( 17116 # Fails on CUDA but passes on ROCm 17117 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17118 dtypes=(torch.cdouble,), device_type='cuda'), 17119 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17120 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), 17121 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17122 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), 17123 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 17124 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17125 ), 17126 decorators=(precisionOverride({torch.bfloat16: 1e-2}),)), 17127 UnaryUfuncInfo('sinc', 17128 ref=np_sinc_with_fp16_as_fp32, 17129 aliases=('special.sinc',), 17130 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17131 handles_large_floats=False, 17132 supports_forward_ad=True, 17133 supports_fwgrad_bwgrad=True, 17134 promotes_int_to_float=True), 17135 UnaryUfuncInfo('sinh', 17136 ref=np_unary_ufunc_integer_promotion_wrapper(np.sinh), 17137 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17138 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17139 assert_autodiffed=True, 17140 supports_forward_ad=True, 17141 supports_fwgrad_bwgrad=True, 17142 supports_sparse=True, 17143 supports_sparse_csr=True, 17144 supports_sparse_csc=True, 17145 supports_sparse_bsr=True, 17146 supports_sparse_bsc=True, 17147 promotes_int_to_float=True, 17148 decorators=(precisionOverride({torch.float16: 1e-2}),), 17149 skips=( 17150 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17151 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17152 active_if=(IS_MACOS or IS_WINDOWS)), 17153 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17154 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17155 active_if=(IS_MACOS or IS_WINDOWS)), 17156 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17157 dtypes=(torch.cdouble,)), 17158 # Reference: https://github.com/pytorch/pytorch/issues/48641 17159 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17160 device_type='cpu', dtypes=[torch.int8]), 17161 DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 17162 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17163 )), 17164 UnaryUfuncInfo('sign', 17165 ref=reference_sign, 17166 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), 17167 dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.half), 17168 supports_forward_ad=True, 17169 supports_fwgrad_bwgrad=True, 17170 supports_sparse=True, 17171 supports_sparse_csr=True, 17172 supports_sparse_csc=True, 17173 supports_sparse_bsr=True, 17174 supports_sparse_bsc=True, 17175 skips=( 17176 # Reference: https://github.com/pytorch/pytorch/issues/41245 17177 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17178 dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), 17179 )), 17180 UnaryUfuncInfo('sgn', 17181 ref=reference_sgn, 17182 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 17183 backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), 17184 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16, torch.half, torch.chalf), 17185 supports_forward_ad=True, 17186 supports_fwgrad_bwgrad=True, 17187 supports_sparse=True, 17188 supports_sparse_csr=True, 17189 supports_sparse_csc=True, 17190 supports_sparse_bsr=True, 17191 supports_sparse_bsc=True, 17192 skips=( 17193 # Reference: https://github.com/pytorch/pytorch/issues/41245 17194 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17195 dtypes=[torch.bfloat16, torch.float16, torch.float32, torch.float64]), 17196 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 17197 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17198 )), 17199 OpInfo('split', 17200 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17201 sample_inputs_func=partial(sample_inputs_split, list_args=False), 17202 supports_forward_ad=True, 17203 supports_fwgrad_bwgrad=True, 17204 supports_out=False, 17205 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 17206 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 17207 assert_autodiffed=True), 17208 OpInfo('split', 17209 # Cannot declare this aten_name because of 17210 # test_variant_consistency_jit_split_list_args_cpu_float32 17211 decomp_aten_name='split_with_sizes', 17212 variant_test_name='list_args', 17213 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), 17214 sample_inputs_func=partial(sample_inputs_split, list_args=True), 17215 supports_forward_ad=True, 17216 supports_fwgrad_bwgrad=True, 17217 supports_out=False), 17218 # `unsafe_split` supports only `int` for split_size argument 17219 OpInfo('unsafe_split', 17220 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17221 sample_inputs_func=partial(sample_inputs_split, list_args=False), 17222 supports_forward_ad=True, 17223 supports_fwgrad_bwgrad=True, 17224 supports_out=False, 17225 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 17226 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 17227 assert_autodiffed=True, 17228 check_batched_forward_grad=False), 17229 OpInfo('split_with_sizes', 17230 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17231 sample_inputs_func=sample_inputs_split_with_sizes, 17232 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 17233 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 17234 
supports_out=False, 17235 supports_forward_ad=True, 17236 supports_fwgrad_bwgrad=True, 17237 assert_autodiffed=True), 17238 OpInfo('split_with_sizes_copy', 17239 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17240 sample_inputs_func=sample_inputs_split_with_sizes, 17241 supports_out=True, 17242 supports_forward_ad=True, 17243 supports_fwgrad_bwgrad=True, 17244 skips=( 17245 # No error raised 17246 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_out_requires_grad_error"), 17247 )), 17248 BinaryUfuncInfo('__radd__', 17249 op=torch.Tensor.__radd__, 17250 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), 17251 supports_out=False, 17252 skips=( 17253 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 17254 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17255 17256 ), 17257 assert_autodiffed=True, 17258 supports_forward_ad=True, 17259 supports_fwgrad_bwgrad=True, 17260 autodiff_nonfusible_nodes=['aten::add'],), 17261 BinaryUfuncInfo('__rdiv__', 17262 op=torch.Tensor.__rdiv__, 17263 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), 17264 promotes_int_to_float=True, 17265 lhs_make_tensor_kwargs={'exclude_zero': True}, 17266 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 17267 gradcheck_fast_mode=True, 17268 supports_out=False, 17269 skips=( 17270 # https://github.com/pytorch/pytorch/issues/76806 17271 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 17272 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17273 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17274 ), 17275 supports_forward_ad=True, 17276 supports_fwgrad_bwgrad=True, 17277 assert_autodiffed=True, 17278 autodiff_nonfusible_nodes=['aten::mul', 'aten::reciprocal'],), 17279 BinaryUfuncInfo('__rmul__', 17280 op=torch.Tensor.__rmul__, 17281 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool), 17282 supports_out=False, 17283 skips=( 17284 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17285 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17286 ), 17287 assert_autodiffed=True, 17288 supports_forward_ad=True, 17289 supports_fwgrad_bwgrad=True, 17290 autodiff_nonfusible_nodes=['aten::mul'],), 17291 BinaryUfuncInfo('__rand__', 17292 op=torch.Tensor.__rand__, 17293 dtypes=integral_types_and(torch.bool), 17294 supports_out=False, 17295 supports_autograd=False, 17296 supports_forward_ad=True, 17297 skips=( 17298 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17299 )), 17300 BinaryUfuncInfo('__ror__', 17301 op=torch.Tensor.__ror__, 17302 dtypes=integral_types_and(torch.bool), 17303 supports_out=False, 17304 supports_autograd=False, 17305 supports_forward_ad=True, 17306 skips=( 17307 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17308 )), 17309 BinaryUfuncInfo('__rxor__', 17310 op=torch.Tensor.__rxor__, 17311 dtypes=integral_types_and(torch.bool), 17312 supports_out=False, 17313 supports_autograd=False, 17314 supports_forward_ad=True, 17315 skips=( 17316 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17317 )), 17318 
OpInfo('__rmatmul__', 17319 op=torch.Tensor.__rmatmul__, 17320 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 17321 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, 17322 *[torch.bfloat16] 17323 if SM53OrLater or TEST_WITH_ROCM else []), 17324 assert_autodiffed=True, 17325 sample_inputs_func=partial(sample_inputs_matmul, is_rmatmul=True), 17326 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 17327 gradcheck_fast_mode=True, 17328 supports_out=False, 17329 supports_forward_ad=True, 17330 supports_fwgrad_bwgrad=True, 17331 check_batched_forward_grad=False, 17332 decorators=( 17333 # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 17334 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=not SM53OrLater), 17335 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1.2e-03)}), 17336 'TestMathBits', 'test_conj_view'), 17337 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1.2e-03)}), 17338 'TestCommon', 'test_noncontiguous_samples'), 17339 DecorateInfo(toleranceOverride({torch.complex64: tol(atol=1e-05, rtol=1e-05)}), 17340 "TestDecomp", "test_comprehensive", device_type="cuda", 17341 active_if=TEST_WITH_ROCM), 17342 ), 17343 skips=( 17344 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17345 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17346 # https://github.com/pytorch/pytorch/issues/67470 17347 DecorateInfo(unittest.skip("67470!"), 17348 'TestCommon', 'test_noncontiguous_samples', 17349 device_type='cpu', dtypes=(torch.long,)), 17350 # Fails on XLA. 17351 # AssertionError: False is not true : Tensors failed to compare as equal 17352 DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)), 17353 # https://github.com/pytorch/pytorch/issues/71774 17354 DecorateInfo(unittest.skip('Skipped!'), 'TestNNCOpInfo', 'test_nnc_correctness', 17355 device_type='cpu', dtypes=(torch.long,)), 17356 )), 17357 BinaryUfuncInfo('__rmod__', 17358 op=torch.Tensor.__rmod__, 17359 dtypes=floating_types_and(torch.bfloat16, torch.half,), 17360 dtypesIfCUDA=all_types_and(torch.bfloat16, torch.half), 17361 # https://github.com/pytorch/pytorch/issues/80411 17362 gradcheck_fast_mode=True, 17363 supports_out=False, 17364 supports_forward_ad=True, 17365 supports_fwgrad_bwgrad=True, 17366 supports_one_python_scalar=True, 17367 skips=( 17368 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17369 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17370 ), 17371 # Support autograd after torch.remainder(Tensor, Tensor) supports 17372 # autograd of the second argument. 
17373 # https://github.com/pytorch/pytorch/pull/58476/files#r637167630 17374 # supports_autograd=False, 17375 assert_autodiffed=True, 17376 autodiff_nonfusible_nodes=['aten::remainder'],), 17377 BinaryUfuncInfo('__rpow__', 17378 op=torch.Tensor.__rpow__, 17379 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), 17380 # Reference: https://github.com/pytorch/pytorch/issues/54774 17381 # "log2" "_vml_cpu" not implemented for Half 17382 backward_dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), 17383 supports_out=False, 17384 supports_forward_ad=True, 17385 supports_fwgrad_bwgrad=True, 17386 supports_one_python_scalar=True, 17387 skips=( 17388 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17389 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17390 # TODO: FIXME tolerance is too high 17391 DecorateInfo(unittest.skip('Skipped!'), 'TestFwdGradients'), 17392 DecorateInfo(unittest.skip('Skipped!'), 'TestBwdGradients'), 17393 ), 17394 assert_autodiffed=True, 17395 autodiff_nonfusible_nodes=['aten::pow'],), 17396 BinaryUfuncInfo('__rsub__', 17397 op=torch.Tensor.__rsub__, 17398 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), 17399 supports_forward_ad=True, 17400 supports_fwgrad_bwgrad=True, 17401 supports_out=False, 17402 supports_one_python_scalar=True, 17403 skips=( 17404 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17405 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',), 17406 ), 17407 assert_autodiffed=True, 17408 autodiff_nonfusible_nodes=['aten::rsub'],), 17409 BinaryUfuncInfo('rsub', 17410 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), 17411 supports_forward_ad=True, 17412 supports_fwgrad_bwgrad=True, 17413 supports_out=False, 17414 supports_inplace_autograd=False, 17415 assert_autodiffed=None, 17416 sample_inputs_func=sample_inputs_add_sub), 17417 OpInfo('select', 17418 aten_backward_name='select_backward', 17419 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17420 sample_inputs_func=sample_inputs_select, 17421 assert_jit_shape_analysis=True, 17422 supports_forward_ad=True, 17423 supports_fwgrad_bwgrad=True, 17424 supports_out=False), 17425 OpInfo('select_scatter', 17426 dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), 17427 sample_inputs_func=sample_inputs_select_scatter, 17428 supports_forward_ad=True, 17429 supports_fwgrad_bwgrad=True, 17430 supports_out=False), 17431 OpInfo('slice', 17432 op=torch.ops.aten.slice.Tensor, 17433 dtypes=all_types_and_complex_and(torch.bfloat16, torch.half, torch.bool, torch.chalf), 17434 sample_inputs_func=sample_inputs_slice, 17435 gradcheck_fast_mode=True, 17436 supports_forward_ad=True, 17437 supports_fwgrad_bwgrad=True, 17438 supports_scripting=False, 17439 supports_inplace_autograd=False, 17440 supports_out=False), 17441 OpInfo('slice_scatter', 17442 dtypes=all_types_and(torch.bfloat16, torch.half, torch.bool), 17443 sample_inputs_func=sample_inputs_slice_scatter, 17444 # https://github.com/pytorch/pytorch/issues/80411 17445 gradcheck_fast_mode=True, 17446 supports_forward_ad=True, 17447 supports_fwgrad_bwgrad=True, 17448 supports_out=True), 17449 UnaryUfuncInfo('signbit', 17450 ref=np.signbit, 17451 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.half), 17452 supports_sparse=True, 17453 supports_sparse_csr=True, 17454 supports_sparse_csc=True, 17455 
supports_sparse_bsr=True, 17456 supports_sparse_bsc=True, 17457 supports_autograd=False,), 17458 UnaryUfuncInfo('tan', 17459 ref=np.tan, 17460 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17461 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17462 decorators=(DecorateInfo( 17463 toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), 17464 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17465 device_type='cuda'),), 17466 assert_autodiffed=True, 17467 supports_forward_ad=True, 17468 supports_fwgrad_bwgrad=True, 17469 supports_sparse=True, 17470 supports_sparse_csr=True, 17471 supports_sparse_csc=True, 17472 supports_sparse_bsr=True, 17473 supports_sparse_bsc=True, 17474 promotes_int_to_float=True, 17475 skips=( 17476 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17477 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17478 active_if=(IS_MACOS or IS_WINDOWS)), 17479 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17480 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17481 active_if=(IS_MACOS or IS_WINDOWS)), 17482 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 17483 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17484 # FIXME: 17485 # Mismatched elements: 2 / 400 (0.5%) 17486 # Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed) 17487 # Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed) 17488 DecorateInfo( 17489 unittest.skip("Skipped!"), 17490 "TestInductorOpInfo", 17491 "test_comprehensive", 17492 dtypes=(torch.float16,), 17493 device_type="cuda", 17494 ), 17495 ), 17496 # tan(pi/2 * odd_number) is nan 17497 reference_numerics_filter=NumericsFilter( 17498 condition=lambda x: close_to_int(x / (math.pi * 0.5)), safe_val=math.pi)), 17499 UnaryUfuncInfo('tanh', 17500 ref=np.tanh, 17501 aten_backward_name='tanh_backward', 17502 aliases=('nn.functional.tanh',), 17503 decorators=(precisionOverride({torch.bfloat16: 1e-2}), 17504 DecorateInfo( 17505 toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), 17506 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17507 device_type='cuda'),), 17508 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17509 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17510 assert_autodiffed=True, 17511 assert_jit_shape_analysis=True, 17512 supports_forward_ad=True, 17513 supports_fwgrad_bwgrad=True, 17514 supports_sparse=True, 17515 supports_sparse_csr=True, 17516 supports_sparse_csc=True, 17517 supports_sparse_bsr=True, 17518 supports_sparse_bsc=True, 17519 promotes_int_to_float=True, 17520 skips=( 17521 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17522 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17523 active_if=(IS_MACOS or IS_WINDOWS)), 17524 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17525 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 17526 active_if=(IS_MACOS or IS_WINDOWS)), 17527 DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 17528 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17529 ), 17530 # tan(j * pi/2 * odd_number) is nan 17531 reference_numerics_filter=NumericsFilter( 17532 condition=lambda x: (close_to_int(x / (math.pi * 0.5j)) 17533 if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), 17534 safe_val=0)), 17535 OpInfo('tensor_split', 17536 ref=np.array_split, 17537 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 17538 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 17539 supports_out=False, 17540 supports_forward_ad=True, 17541 supports_fwgrad_bwgrad=True, 17542 skips=( 17543 # Pre-existing condition; Needs to be fixed 17544 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), 17545 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), 17546 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), 17547 ), 17548 sample_inputs_func=sample_inputs_tensor_split,), 17549 OpInfo('hsplit', 17550 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), 17551 supports_out=False, 17552 supports_forward_ad=True, 17553 supports_fwgrad_bwgrad=True, 17554 # See https://github.com/pytorch/pytorch/pull/78358 17555 check_batched_forward_grad=False, 17556 sample_inputs_func=sample_inputs_hsplit, 17557 error_inputs_func=error_inputs_hsplit,), 17558 OpInfo('vsplit', 17559 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), 17560 supports_out=False, 17561 supports_forward_ad=True, 17562 supports_fwgrad_bwgrad=True, 17563 # See https://github.com/pytorch/pytorch/pull/78358 17564 check_batched_forward_grad=False, 17565 sample_inputs_func=sample_inputs_vsplit, 17566 error_inputs_func=error_inputs_vsplit,), 17567 OpInfo('dsplit', 17568 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.bfloat16, torch.float16), 17569 supports_out=False, 17570 supports_forward_ad=True, 17571 supports_fwgrad_bwgrad=True, 17572 # See https://github.com/pytorch/pytorch/pull/78358 17573 check_batched_forward_grad=False, 17574 sample_inputs_func=sample_inputs_dsplit, 17575 error_inputs_func=error_inputs_dsplit,), 17576 OpInfo('triangular_solve', 17577 op=torch.triangular_solve, 17578 dtypes=floating_and_complex_types(), 17579 sample_inputs_func=sample_inputs_legacy_solve, 17580 check_batched_gradgrad=False, 17581 supports_forward_ad=True, 17582 supports_fwgrad_bwgrad=True, 17583 gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs), 17584 decorators=[ 17585 skipCUDAIfNoMagma, 17586 skipCPUIfNoLapack, 17587 DecorateInfo( 17588 toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}), 17589 'TestConsistency', 'test_output_match', device_type='cpu', 17590 ), 17591 ], 17592 skips=( 17593 # AssertionError: Scalars are not equal! 
17594 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 17595 # Gradcheck fails 17596 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', 17597 dtypes=floating_and_complex_types()), 17598 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', 17599 device_type='mps', dtypes=[torch.float32]), 17600 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 17601 device_type='mps', dtypes=[torch.float32]), 17602 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 17603 device_type='mps', dtypes=[torch.float32]), 17604 )), 17605 UnaryUfuncInfo('trunc', 17606 aliases=('fix', ), 17607 ref=np.trunc, 17608 dtypes=all_types_and(torch.half, torch.bfloat16), 17609 supports_forward_ad=True, 17610 supports_fwgrad_bwgrad=True, 17611 supports_sparse=True, 17612 skips=( 17613 DecorateInfo(unittest.expectedFailure, 17614 'TestNNCOpInfo', 17615 'test_nnc_correctness', 17616 dtypes=tuple(t for t in integral_types() if t != torch.uint8)), 17617 ), 17618 supports_sparse_csr=True, 17619 supports_sparse_csc=True, 17620 supports_sparse_bsr=True, 17621 supports_sparse_bsc=True, 17622 assert_autodiffed=True), 17623 UnaryUfuncInfo('exp2', 17624 aliases=('special.exp2', ), 17625 ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2), 17626 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17627 supports_forward_ad=True, 17628 supports_fwgrad_bwgrad=True, 17629 promotes_int_to_float=True, 17630 skips=( 17631 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17632 dtypes=[torch.cdouble]), 17633 # Reference: https://github.com/pytorch/pytorch/issues/48010 17634 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17635 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 17636 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17637 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 17638 )), 17639 UnaryUfuncInfo('expm1', 17640 aliases=('special.expm1', ), 17641 ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1), 17642 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17643 supports_forward_ad=True, 17644 supports_fwgrad_bwgrad=True, 17645 supports_sparse=True, 17646 supports_sparse_csr=True, 17647 supports_sparse_csc=True, 17648 supports_sparse_bsr=True, 17649 supports_sparse_bsc=True, 17650 promotes_int_to_float=True, 17651 assert_autodiffed=True, 17652 skips=( 17653 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17654 device_type='cuda', dtypes=[torch.complex128]), 17655 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 17656 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17657 )), 17658 UnaryUfuncInfo('nan_to_num', 17659 ref=np.nan_to_num, 17660 dtypes=all_types_and(torch.half, torch.bool, torch.bfloat16), 17661 dtypesIfCUDA=all_types_and(torch.half, torch.bool, torch.bfloat16), 17662 supports_forward_ad=True, 17663 supports_fwgrad_bwgrad=True, 17664 supports_sparse=True, 17665 skips=( 17666 DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 17667 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17668 ), 17669 # Passing numpy_kwargs via sample_kwargs, as numpy does comparison 17670 # with BFloat16 in float, since it currently doesn't support BFloat16. 
17671 # Ref: https://github.com/pytorch/pytorch/issues/57982#issuecomment-839150556 17672 sample_kwargs=lambda device, dtype, input: ({}, 17673 {'posinf': torch.finfo(torch.bfloat16).max, 17674 'neginf': torch.finfo(torch.bfloat16).min}) 17675 if dtype is torch.bfloat16 else ({}, {})), 17676 UnaryUfuncInfo('reciprocal', 17677 ref=np_unary_ufunc_integer_promotion_wrapper(np.reciprocal), 17678 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17679 assert_autodiffed=True, 17680 supports_forward_ad=True, 17681 supports_fwgrad_bwgrad=True, 17682 promotes_int_to_float=True, 17683 skips=( 17684 # Reference: https://github.com/pytorch/pytorch/issues/45690 17685 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17686 dtypes=[torch.cfloat, torch.cdouble]), 17687 )), 17688 UnaryUfuncInfo('rsqrt', 17689 ref=lambda x: np.reciprocal(np.sqrt(x)), 17690 domain=(0, None), 17691 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17692 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17693 decorators=(precisionOverride({torch.half: 5e-2}),), 17694 assert_autodiffed=True, 17695 supports_forward_ad=True, 17696 supports_fwgrad_bwgrad=True, 17697 promotes_int_to_float=True, 17698 skips=( 17699 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17700 dtypes=(torch.cfloat, torch.cdouble)), 17701 # AssertionError: Tensor-likes are not close! 17702 # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) 17703 # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) 17704 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', 17705 dtypes=(torch.chalf,)), 17706 )), 17707 UnaryUfuncInfo('sqrt', 17708 ref=np.sqrt, 17709 supports_sparse=True, 17710 domain=(0, None), 17711 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 17712 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 17713 assert_autodiffed=True, 17714 supports_forward_ad=True, 17715 supports_sparse_csr=True, 17716 supports_sparse_csc=True, 17717 supports_sparse_bsr=True, 17718 supports_sparse_bsc=True, 17719 supports_fwgrad_bwgrad=True, 17720 promotes_int_to_float=True, 17721 decorators=( 17722 precisionOverride({torch.bfloat16: 7e-2}), 17723 DecorateInfo( 17724 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 17725 'TestUnaryUfuncs', 'test_reference_numerics_large'), 17726 ), 17727 skips=( 17728 # Reference: https://github.com/pytorch/pytorch/issues/47358 17729 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17730 device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), 17731 active_if=IS_MACOS), 17732 DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 17733 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 17734 )), 17735 UnaryUfuncInfo('square', 17736 ref=np.square, 17737 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 17738 decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), 17739 supports_forward_ad=True, 17740 supports_fwgrad_bwgrad=True, 17741 skips=( 17742 # Reference: https://github.com/pytorch/pytorch/issues/52549 17743 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 17744 dtypes=[torch.cfloat, torch.cdouble]), 17745 # >>> t = torch.tensor(complex(-0.01, float("inf"))) 17746 # >>> np.square(t.numpy()) 17747 # (-inf-infj) 17748 # >>> t.square() 17749 # tensor(-inf-infj) 17750 # >>> t.cuda().square() 17751 # tensor(inf+nanj, device='cuda:0') 17752 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 17753 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), 17754 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', 17755 dtypes=[torch.bool]), 17756 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', 17757 dtypes=[torch.bool]), 17758 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_inplace', 17759 dtypes=[torch.bool]), 17760 ),), 17761 OpInfo('lerp', 17762 dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), 17763 dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16), 17764 sample_inputs_func=sample_inputs_lerp, 17765 supports_forward_ad=True, 17766 supports_fwgrad_bwgrad=True, 17767 assert_autodiffed=True), 17768 UnaryUfuncInfo('angle', 17769 ref=np.angle, 17770 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 17771 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool), 17772 decorators=(precisionOverride({torch.float16: 1e-2, 17773 torch.bfloat16: 1e-2}),), 17774 backward_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), 17775 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.chalf), 17776 supports_forward_ad=True, 17777 supports_fwgrad_bwgrad=True, 17778 supports_sparse_csr=True, 17779 supports_sparse_csc=True, 17780 supports_sparse_bsr=True, 17781 supports_sparse_bsc=True, 17782 supports_complex_to_float=True, 17783 skips=( 17784 # Ref: https://github.com/pytorch/pytorch/issues/78413 17785 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_small', 17786 dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64),), 17787 )), 17788 UnaryUfuncInfo('isfinite', 17789 ref=np.isfinite, 17790 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 17791 supports_out=False, 17792 supports_autograd=False), 17793 UnaryUfuncInfo('isinf', 17794 ref=np.isinf, 17795 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 17796 supports_out=False, 17797 supports_sparse=True, 17798 supports_sparse_csr=True, 17799 supports_sparse_csc=True, 17800 supports_sparse_bsr=True, 17801 supports_sparse_bsc=True, 17802 supports_autograd=False), 17803 UnaryUfuncInfo('isposinf', 17804 ref=np.isposinf, 17805 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), 17806 supports_sparse=True, 17807 supports_sparse_csr=True, 17808 supports_sparse_csc=True, 17809 supports_sparse_bsr=True, 17810 supports_sparse_bsc=True, 17811 supports_autograd=False), 17812 
UnaryUfuncInfo('isneginf', 17813 ref=np.isneginf, 17814 dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), 17815 supports_sparse=True, 17816 supports_sparse_csr=True, 17817 supports_sparse_csc=True, 17818 supports_sparse_bsr=True, 17819 supports_sparse_bsc=True, 17820 supports_autograd=False), 17821 UnaryUfuncInfo('isreal', 17822 ref=np.isreal, 17823 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 17824 supports_out=False, 17825 supports_autograd=False), 17826 UnaryUfuncInfo('isnan', 17827 ref=np.isnan, 17828 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 17829 supports_out=False, 17830 supports_sparse=True, 17831 supports_sparse_csr=True, 17832 supports_sparse_csc=True, 17833 supports_sparse_bsr=True, 17834 supports_sparse_bsc=True, 17835 supports_autograd=False), 17836 OpInfo('einsum', 17837 # we need this lambda because SampleInput expects tensor input as the first argument 17838 # TODO(@heitorschueroff) update SampleInput to handle such cases 17839 op=lambda tensors, equation: torch.einsum(equation, tensors), 17840 dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), 17841 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 17842 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 17843 supports_out=False, 17844 supports_forward_ad=True, 17845 supports_fwgrad_bwgrad=True, 17846 check_batched_forward_grad=False, 17847 # See https://github.com/pytorch/pytorch/issues/66357 17848 sample_inputs_func=sample_inputs_einsum, 17849 skips=( 17850 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 17851 # test does not work with passing lambda for op 17852 # there's a test `test_einsum` in `test_jit.py` to handle this case 17853 # AssertionError: JIT Test does not execute any logic 17854 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 17855 )), 17856 OpInfo('svd', 17857 op=torch.svd, 17858 dtypes=floating_and_complex_types(), 17859 sample_inputs_func=sample_inputs_svd, 17860 # Runs very slowly on slow-gradcheck - alternatively reduce input sizes 17861 gradcheck_fast_mode=True, 17862 supports_forward_ad=True, 17863 supports_fwgrad_bwgrad=True, 17864 check_batched_forward_grad=False, 17865 # We're using at::allclose, which does not have a batching rule 17866 check_batched_grad=False, 17867 check_batched_gradgrad=False, 17868 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], 17869 skips=( 17870 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 17871 DecorateInfo( 17872 unittest.skip("Skipped!"), 17873 'TestSchemaCheckModeOpInfo', 17874 'test_schema_correctness', 17875 dtypes=(torch.complex64, torch.complex128)), 17876 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', 17877 device_type='mps', dtypes=[torch.float32]), 17878 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 17879 device_type='mps', dtypes=[torch.float32]), 17880 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 17881 device_type='mps', dtypes=[torch.float32]), 17882 )), 17883 OpInfo('svd_lowrank', 17884 op=lambda *args, **kwargs: wrapper_set_seed( 17885 lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs), 17886 *args, **kwargs 17887 ), 17888 dtypes=floating_and_complex_types(), 17889 # Runs very slowly on slow gradcheck - alternatively reduce 
input sizes 17890 gradcheck_fast_mode=True, 17891 supports_out=False, 17892 # Due to the use of randomness 17893 check_batched_grad=False, 17894 check_batched_gradgrad=False, 17895 check_batched_forward_grad=False, 17896 supports_fwgrad_bwgrad=True, 17897 supports_forward_ad=True, 17898 sample_inputs_func=sample_inputs_svd_lowrank, 17899 decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, 17900 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), 17901 torch.complex64: tol(atol=1e-02, rtol=1e-02)}), 17902 'TestCommon', 'test_noncontiguous_samples'), 17903 # FIXME This should be the following, but the toleranceOverride does not seem to do anything! 17904 # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), 17905 # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 17906 DecorateInfo(unittest.skip("See comment above"), 17907 'TestFwdGradients', 17908 'test_fn_fwgrad_bwgrad', 17909 dtypes=[torch.complex128]), 17910 ], 17911 skips=( 17912 # test does not work with passing lambda for op 17913 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 17914 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 17915 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 17916 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 17917 DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', 17918 dtypes=(torch.complex64, torch.complex128)), 17919 DecorateInfo(slowTest, 'TestCompositeCompliance', 'test_forward_ad'), 17920 )), 17921 OpInfo('pca_lowrank', 17922 op=lambda *args, **kwargs: wrapper_set_seed( 17923 lambda a, b, **kwargs: torch.pca_lowrank(a @ b.mT, **kwargs), 17924 *args, **kwargs 17925 ), 17926 dtypes=floating_and_complex_types(), 17927 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 17928 gradcheck_fast_mode=True, 17929 supports_out=False, 17930 check_batched_forward_grad=False, 17931 check_batched_grad=False, 17932 check_batched_gradgrad=False, 17933 supports_forward_ad=True, 17934 supports_fwgrad_bwgrad=True, 17935 sample_inputs_func=sample_inputs_pca_lowrank, 17936 decorators=[skipCUDAIfNoCusolver, skipCPUIfNoLapack, with_tf32_off, 17937 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03), 17938 torch.complex64: tol(atol=4e-02, rtol=4e-02)}), 17939 'TestCommon', 'test_noncontiguous_samples'), 17940 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-05, rtol=5e-05)}), 17941 'TestOperators', 'test_grad'), 17942 # FIXME This should be the following, but the toleranceOverride does not seem to do anything! 
17943 # DecorateInfo(toleranceOverride({torch.complex128: tol(atol=1e-04, rtol=1e-04)}), 17944 # 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 17945 DecorateInfo(unittest.skip("See comment above"), 17946 'TestFwdGradients', 17947 'test_fn_fwgrad_bwgrad', 17948 dtypes=[torch.complex128]), 17949 DecorateInfo( 17950 toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}), 17951 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'), 17952 ], 17953 skips=( 17954 # test does not work with passing lambda for op 17955 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 17956 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 17957 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 17958 DecorateInfo(unittest.expectedFailure, 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', 17959 dtypes=(torch.complex64, torch.complex128)), 17960 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 17961 )), 17962 BinaryUfuncInfo('polar', 17963 dtypes=floating_types(), 17964 # this function is undefined if 'abs' values are <0 17965 supports_forward_ad=True, 17966 lhs_make_tensor_kwargs=dict(low=0), 17967 supports_rhs_python_scalar=False, 17968 skips=( 17969 # RuntimeError: Expected object of scalar type Float but got scalar type Double for second argument 17970 DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 'test_type_promotion'), 17971 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 17972 # GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0 17973 # Numerical: 17974 # tensor([[0.]], dtype=torch.float64) 17975 # Analytical: 17976 # tensor([[-0.0047]], dtype=torch.float64, grad_fn=<CopySlices>) 17977 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 17978 )), 17979 # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries. 17980 # To test reference numerics against multiple values of argument `n`, 17981 # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4). 17982 # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing. 
17983 UnaryUfuncInfo('polygamma', 17984 op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), 17985 variant_test_name='polygamma_n_0', 17986 ref=reference_polygamma if TEST_SCIPY else None, 17987 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 17988 dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), 17989 supports_forward_ad=True, 17990 supports_fwgrad_bwgrad=True, 17991 promotes_int_to_float=True, 17992 sample_inputs_func=sample_inputs_polygamma, 17993 skips=( 17994 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 17995 ), 17996 sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}), 17997 # polygamma functions have multiple singularities at x having non-positive integer value 17998 reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), 17999 safe_val=1)), 18000 *(UnaryUfuncInfo('polygamma', 18001 op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs), 18002 variant_test_name=f'polygamma_n_{n_}', 18003 ref=reference_polygamma if TEST_SCIPY else None, 18004 dtypes=all_types_and(torch.bool, torch.bfloat16), 18005 dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), 18006 supports_forward_ad=True, 18007 supports_fwgrad_bwgrad=True, 18008 promotes_int_to_float=True, 18009 sample_inputs_func=sample_inputs_polygamma, 18010 decorators=( 18011 DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'), 18012 DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1), 18013 torch.float32: tol(atol=1e-4, rtol=1e-2)}), 18014 'TestUnaryUfuncs', 'test_reference_numerics_normal', 18015 active_if=IS_WINDOWS), 18016 ), 18017 skips=( 18018 # Redundant tests 18019 DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'), 18020 DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'), 18021 DecorateInfo(unittest.skip("Skipped!"), 'TestJit'), 18022 DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), 18023 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), 18024 # Mismatch: https://github.com/pytorch/pytorch/issues/55357 18025 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), 18026 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'), 18027 ), 18028 sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}), 18029 # polygamma functions have multiple singularities at x having non-positive integer value 18030 reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), 18031 safe_val=1)) 18032 for n_ in (1, 2, 3, 4)), 18033 OpInfo('ravel', 18034 ref=np.ravel, 18035 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18036 supports_out=False, 18037 supports_forward_ad=True, 18038 supports_fwgrad_bwgrad=True, 18039 # See https://github.com/pytorch/pytorch/pull/78358 18040 check_batched_forward_grad=False, 18041 sample_inputs_func=sample_inputs_ravel, 18042 ), 18043 OpInfo('unravel_index', 18044 ref=np.unravel_index, 18045 dtypes=integral_types_and(), 18046 supports_out=False, 18047 supports_autograd=False, 18048 sample_inputs_func=sample_inputs_unravel_index, 18049 ), 18050 OpInfo('reshape', 18051 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18052 sample_inputs_func=sample_inputs_view_reshape, 18053 reference_inputs_func=reference_inputs_view_reshape, 18054 
error_inputs_func=error_inputs_view_reshape, 18055 supports_out=False, 18056 supports_forward_ad=True, 18057 supports_fwgrad_bwgrad=True, 18058 ), 18059 OpInfo('reshape_as', 18060 op=lambda x, other: x.reshape_as(other), 18061 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18062 sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), 18063 reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), 18064 error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), 18065 supports_out=False, 18066 supports_forward_ad=True, 18067 supports_fwgrad_bwgrad=True, 18068 skips=( 18069 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18070 )), 18071 OpInfo('view', 18072 op=lambda x, shape: x.view(shape), 18073 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 18074 supports_out=False, 18075 supports_forward_ad=True, 18076 supports_fwgrad_bwgrad=True, 18077 assert_jit_shape_analysis=True, 18078 sample_inputs_func=sample_inputs_view_reshape, 18079 reference_inputs_func=reference_inputs_view_reshape, 18080 error_inputs_func=error_inputs_view_reshape, 18081 skips=( 18082 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18083 # RuntimeError: view size is not compatible with input tensor's size and stride 18084 # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. 18085 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 18086 )), 18087 OpInfo('view_as', 18088 op=lambda x, other: x.view_as(other), 18089 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 18090 supports_out=False, 18091 supports_forward_ad=True, 18092 supports_fwgrad_bwgrad=True, 18093 sample_inputs_func=partial(sample_inputs_view_reshape, tensor_arg=True), 18094 reference_inputs_func=partial(reference_inputs_view_reshape, tensor_arg=True), 18095 error_inputs_func=partial(error_inputs_view_reshape, tensor_arg=True), 18096 skips=( 18097 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18098 # RuntimeError: view size is not compatible with input tensor's size and stride 18099 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides") 18100 )), 18101 OpInfo('atleast_1d', 18102 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 18103 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 18104 gradcheck_fast_mode=True, 18105 supports_out=False, 18106 supports_forward_ad=True, 18107 supports_fwgrad_bwgrad=True, 18108 # See https://github.com/pytorch/pytorch/pull/78358 18109 check_batched_forward_grad=False, 18110 sample_inputs_func=sample_inputs_atleast1d2d3d, 18111 skips=( 18112 # JIT does not support variadic tensors. 18113 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 18114 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, 18115 # please report a bug to PyTorch. 
18116 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18117 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), 18118 ), 18119 ), 18120 OpInfo('atleast_2d', 18121 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 18122 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 18123 gradcheck_fast_mode=True, 18124 supports_out=False, 18125 supports_forward_ad=True, 18126 supports_fwgrad_bwgrad=True, 18127 # See https://github.com/pytorch/pytorch/pull/78358 18128 check_batched_forward_grad=False, 18129 skips=( 18130 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18131 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), 18132 ), 18133 sample_inputs_func=sample_inputs_atleast1d2d3d, 18134 ), 18135 OpInfo('atleast_3d', 18136 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 18137 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 18138 gradcheck_fast_mode=True, 18139 supports_out=False, 18140 supports_forward_ad=True, 18141 supports_fwgrad_bwgrad=True, 18142 # See https://github.com/pytorch/pytorch/pull/78358 18143 check_batched_forward_grad=False, 18144 skips=( 18145 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18146 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]), 18147 ), 18148 sample_inputs_func=sample_inputs_atleast1d2d3d, 18149 ), 18150 OpInfo('flatten', 18151 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18152 ref=reference_flatten, 18153 supports_out=False, 18154 supports_forward_ad=True, 18155 supports_fwgrad_bwgrad=True, 18156 # See https://github.com/pytorch/pytorch/pull/78358 18157 check_batched_forward_grad=False, 18158 sample_inputs_func=sample_inputs_flatten, 18159 reference_inputs_func=reference_inputs_flatten, 18160 ), 18161 OpInfo('unflatten', 18162 op=torch.unflatten, 18163 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18164 supports_out=False, 18165 supports_forward_ad=True, 18166 supports_fwgrad_bwgrad=True, 18167 sample_inputs_func=sample_inputs_unflatten, 18168 ), 18169 OpInfo('column_stack', 18170 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18171 supports_forward_ad=True, 18172 supports_fwgrad_bwgrad=True, 18173 # See https://github.com/pytorch/pytorch/pull/78358 18174 check_batched_forward_grad=False, 18175 sample_inputs_func=sample_inputs_column_stack,), 18176 OpInfo('pinverse', 18177 op=torch.pinverse, 18178 dtypes=floating_and_complex_types(), 18179 check_batched_grad=False, 18180 check_batched_gradgrad=False, 18181 supports_forward_ad=True, 18182 supports_fwgrad_bwgrad=True, 18183 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 18184 supports_out=False, 18185 sample_inputs_func=sample_inputs_linalg_invertible, 18186 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 18187 skips=( 18188 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', 18189 device_type='mps', dtypes=[torch.float32]), 18190 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', 18191 device_type='mps', dtypes=[torch.float32]), 18192 )), 18193 
OpInfo('gather', 18194 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 18195 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 18196 sample_inputs_func=sample_inputs_gather, 18197 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 18198 supports_forward_ad=True, 18199 supports_fwgrad_bwgrad=True, 18200 error_inputs_func=error_inputs_gather, 18201 ), 18202 OpInfo('index_fill', 18203 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), 18204 supports_out=False, 18205 supports_forward_ad=True, 18206 supports_fwgrad_bwgrad=True, 18207 # https://github.com/pytorch/pytorch/issues/66357 18208 check_batched_forward_grad=False, 18209 skips=( 18210 # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! 18211 DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'), 18212 # RuntimeError: Mismatch on aten._unique.default: Shapes torch.Size([2]) and torch.Size([1]) are not equal! 18213 DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp'), 18214 ), 18215 sample_inputs_func=sample_inputs_index, 18216 reference_inputs_func=partial(sample_inputs_index, reference=True)), 18217 OpInfo('index_copy', 18218 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), 18219 supports_out=True, 18220 supports_forward_ad=True, 18221 supports_fwgrad_bwgrad=True, 18222 # https://github.com/pytorch/pytorch/issues/66357 18223 check_batched_forward_grad=False, 18224 sample_inputs_func=sample_inputs_index, 18225 reference_inputs_func=partial(sample_inputs_index, reference=True), 18226 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 18227 OpInfo('index_select', 18228 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18229 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), 18230 sample_inputs_func=sample_inputs_index, 18231 reference_inputs_func=partial(sample_inputs_index, reference=True), 18232 error_inputs_func=error_inputs_index_select, 18233 supports_forward_ad=True, 18234 supports_fwgrad_bwgrad=True, 18235 assert_jit_shape_analysis=True, 18236 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 18237 OpInfo('index_add', 18238 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18239 supports_out=True, 18240 supports_forward_ad=True, 18241 supports_fwgrad_bwgrad=True, 18242 # https://github.com/pytorch/pytorch/issues/66357 18243 check_batched_forward_grad=False, 18244 sample_inputs_func=sample_inputs_index, 18245 reference_inputs_func=partial(sample_inputs_index, reference=True), 18246 error_inputs_func=error_inputs_index_add, 18247 skips=( 18248 # boolean alpha not handled properly 18249 DecorateInfo(unittest.expectedFailure, 18250 'TestNNCOpInfo', 18251 'test_nnc_correctness', 18252 dtypes=(torch.bool,)), 18253 ), 18254 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), 18255 *(OpInfo('index_reduce', 18256 variant_test_name=reduction_type, 18257 dtypes=all_types_and(torch.float16, torch.bfloat16), 18258 skips=( 18259 DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-3)}), 18260 'TestInductorOpInfo', 'test_comprehensive'), 18261 ), 18262 supports_out=True, 18263 sample_inputs_func=sample_inputs_index_reduce, 18264 ) for reduction_type in ('mean', 'prod', 'amin', 'amax')), 18265 OpInfo('_unsafe_masked_index', 18266 
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), 18267 supports_out=False, 18268 supports_inplace_autograd=False, 18269 supports_scripting=False, 18270 supports_forward_ad=True, 18271 supports_fwgrad_bwgrad=True, 18272 sample_inputs_func=sample_inputs__unsafe_masked_index, 18273 skips=( 18274 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18275 DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', 18276 dtypes=(torch.float64,), active_if=IS_WINDOWS), 18277 ),), 18278 OpInfo('_unsafe_masked_index_put_accumulate', 18279 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool), 18280 supports_out=False, 18281 supports_inplace_autograd=False, 18282 supports_scripting=False, 18283 supports_forward_ad=True, 18284 supports_fwgrad_bwgrad=True, 18285 decorators=( 18286 DecorateInfo( 18287 toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}), 18288 'TestInductorOpInfo', 'test_comprehensive', device_type='cpu' 18289 ), 18290 ), 18291 sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate, 18292 skips=( 18293 DecorateInfo(slowTest, 'TestDecomp', 'test_quick_core_backward', 18294 dtypes=(torch.float64,), active_if=IS_WINDOWS), 18295 ),), 18296 OpInfo('__getitem__', 18297 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18298 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 18299 gradcheck_fast_mode=True, 18300 supports_out=False, 18301 supports_forward_ad=True, 18302 supports_fwgrad_bwgrad=True, 18303 supports_inplace_autograd=False, 18304 supports_scripting=False, 18305 op=torch.Tensor.__getitem__, 18306 skips=( 18307 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18308 # AssertionError: False is not true : Scalars failed to compare as equal! 
0 != 104448 18309 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', device_type='cuda'),), 18310 sample_inputs_func=sample_inputs_getitem), 18311 OpInfo('index_put', 18312 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 18313 supports_out=False, 18314 supports_inplace_autograd=True, 18315 supports_forward_ad=True, 18316 supports_fwgrad_bwgrad=True, 18317 # https://github.com/pytorch/pytorch/issues/66357 18318 check_batched_forward_grad=False, 18319 test_neg_view=False, 18320 sample_inputs_func=sample_inputs_index_put, 18321 skips=( 18322 DecorateInfo(unittest.skip("Skipped"), 'TestBwdGradients', 'test_fn_grad', dtypes=[torch.float64], 18323 device_type='cuda', active_if=(TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)), 18324 )), 18325 OpInfo('sort', 18326 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 18327 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 18328 sample_inputs_func=sample_inputs_sort, 18329 supports_forward_ad=True, 18330 supports_fwgrad_bwgrad=True, 18331 skips=( 18332 )), 18333 OpInfo('unique', 18334 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), 18335 dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.uint16, torch.uint32, torch.uint64), 18336 sample_inputs_func=sample_inputs_unique, 18337 supports_out=False, 18338 supports_autograd=False, 18339 skips=( 18340 # lambda impl 18341 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18342 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18343 DecorateInfo(unittest.skip('Output order is undefined when sorted=False'), 'TestCommon', 'test_compare_cpu'), 18344 )), 18345 OpInfo('unique_consecutive', 18346 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 18347 dtypesIfCUDA=all_types_and(torch.bool, torch.float16), 18348 sample_inputs_func=sample_inputs_unique_consecutive, 18349 supports_out=False, 18350 supports_autograd=False, 18351 skips=( 18352 # lambda impl 18353 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18354 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18355 )), 18356 OpInfo('put', 18357 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 18358 supports_out=False, 18359 supports_forward_ad=True, 18360 supports_fwgrad_bwgrad=True, 18361 check_batched_forward_grad=False, 18362 check_batched_gradgrad=False, # vmap complains of the sizes 18363 sample_inputs_func=sample_inputs_put), 18364 OpInfo('take', 18365 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 18366 check_batched_grad=False, # vmap complains of the sizes 18367 supports_forward_ad=True, 18368 supports_fwgrad_bwgrad=True, 18369 sample_inputs_func=sample_inputs_take, 18370 error_inputs_func=error_inputs_take), 18371 OpInfo('scatter', 18372 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18373 supports_forward_ad=True, 18374 supports_fwgrad_bwgrad=True, 18375 sample_inputs_func=sample_inputs_scatter, 18376 error_inputs_func=error_inputs_scatter_and_scatter_add), 18377 UnaryUfuncInfo( 18378 'bfloat16', 18379 op=lambda x, *args, **kwargs: x.bfloat16(*args, **kwargs), 18380 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18381 supports_out=False, 18382 
sample_inputs_func=sample_inputs_conversion, 18383 skips=( 18384 # autograd tests don't handle operators that change dtype 18385 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), 18386 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), 18387 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18388 # RuntimeError: attribute lookup is not defined on builtin 18389 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18390 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18391 )), 18392 UnaryUfuncInfo( 18393 'bool', 18394 op=lambda x, *args, **kwargs: x.bool(*args, **kwargs), 18395 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18396 supports_out=False, 18397 sample_inputs_func=sample_inputs_conversion, 18398 supports_autograd=False, 18399 skips=( 18400 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18401 # RuntimeError: attribute lookup is not defined on builtin 18402 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18403 )), 18404 UnaryUfuncInfo( 18405 'byte', 18406 op=lambda x, *args, **kwargs: x.byte(*args, **kwargs), 18407 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18408 supports_out=False, 18409 sample_inputs_func=sample_inputs_byte, 18410 # The autograd test runner cannot handle functions that change dtype 18411 supports_autograd=False, 18412 skips=( 18413 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18414 # RuntimeError: attribute lookup is not defined on builtin 18415 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18416 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 18417 )), 18418 UnaryUfuncInfo( 18419 'char', 18420 op=lambda x, *args, **kwargs: x.char(*args, **kwargs), 18421 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18422 supports_out=False, 18423 sample_inputs_func=sample_inputs_conversion, 18424 # The autograd test runner cannot handle functions that change dtype 18425 supports_autograd=False, 18426 skips=( 18427 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18428 # RuntimeError: attribute lookup is not defined on builtin 18429 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18430 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 18431 )), 18432 UnaryUfuncInfo( 18433 'double', 18434 op=lambda x, *args, **kwargs: x.double(*args, **kwargs), 18435 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18436 supports_out=False, 18437 sample_inputs_func=sample_inputs_conversion, 18438 supports_forward_ad=True, 18439 supports_fwgrad_bwgrad=True, 18440 skips=( 18441 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18442 # RuntimeError: attribute lookup is not defined on builtin 18443 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18444 )), 18445 UnaryUfuncInfo( 18446 'float', 18447 op=lambda x, *args, **kwargs: x.float(*args, **kwargs), 18448 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16,
torch.chalf), 18449 supports_out=False, 18450 sample_inputs_func=sample_inputs_conversion, 18451 skips=( 18452 # autograd tests don't handle operators that change dtype 18453 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), 18454 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), 18455 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18456 # RuntimeError: attribute lookup is not defined on builtin 18457 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18458 )), 18459 UnaryUfuncInfo( 18460 'half', 18461 op=lambda x, *args, **kwargs: x.half(*args, **kwargs), 18462 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18463 supports_out=False, 18464 sample_inputs_func=sample_inputs_conversion, 18465 supports_autograd=True, 18466 skips=( 18467 # autograd tests don't handle operators that change dtype 18468 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), 18469 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), 18470 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18471 # RuntimeError: attribute lookup is not defined on builtin 18472 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18473 )), 18474 UnaryUfuncInfo( 18475 'int', 18476 op=lambda x, *args, **kwargs: x.int(*args, **kwargs), 18477 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18478 supports_out=False, 18479 sample_inputs_func=sample_inputs_conversion, 18480 supports_autograd=False, 18481 skips=( 18482 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18483 # RuntimeError: attribute lookup is not defined on builtin 18484 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18485 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 18486 )), 18487 UnaryUfuncInfo( 18488 'long', 18489 op=lambda x, *args, **kwargs: x.long(*args, **kwargs), 18490 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18491 supports_out=False, 18492 sample_inputs_func=sample_inputs_conversion, 18493 supports_autograd=False, 18494 skips=( 18495 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18496 # RuntimeError: attribute lookup is not defined on builtin 18497 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18498 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 18499 )), 18500 UnaryUfuncInfo( 18501 'short', 18502 op=lambda x, *args, **kwargs: x.short(*args, **kwargs), 18503 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18504 supports_out=False, 18505 sample_inputs_func=sample_inputs_conversion, 18506 supports_autograd=False, 18507 skips=( 18508 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18509 # RuntimeError: attribute lookup is not defined on builtin 18510 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18511 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 18512 )), 18513 UnaryUfuncInfo( 18514 'cdouble', 18515 op=torch.Tensor.cdouble, 18516 
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18517 supports_out=False, 18518 sample_inputs_func=sample_inputs_conversion, 18519 supports_forward_ad=True, 18520 supports_fwgrad_bwgrad=True, 18521 skips=( 18522 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18523 # RuntimeError: attribute lookup is not defined on builtin 18524 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18525 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18526 )), 18527 UnaryUfuncInfo( 18528 'cfloat', 18529 op=torch.Tensor.cfloat, 18530 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18531 supports_out=False, 18532 sample_inputs_func=sample_inputs_conversion, 18533 skips=( 18534 # autograd tests don't handle operators that change dtype 18535 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), 18536 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), 18537 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18538 # RuntimeError: attribute lookup is not defined on builtin 18539 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18540 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18541 )), 18542 UnaryUfuncInfo( 18543 'chalf', 18544 op=lambda x, *args, **kwargs: x.chalf(*args, **kwargs), 18545 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18546 supports_out=False, 18547 sample_inputs_func=sample_inputs_conversion, 18548 skips=( 18549 # autograd tests don't handle operators that change dtype 18550 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), 18551 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), 18552 # use of lambda doesn't work with test_normalize_operator_exhaustive 18553 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 18554 # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 18555 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', 18556 device_type='cpu'), 18557 # TypeError: 'int' object is not iterable 18558 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18559 # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 18560 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view', 18561 device_type='cpu'), 18562 # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 18563 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view', 18564 device_type='cpu'), 18565 # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' 18566 # RuntimeError: "neg_conj_cuda" not implemented for 'ComplexHalf' 18567 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18568 ) 18569 ), 18570 OpInfo('empty_like', 18571 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18572 supports_out=False, 18573 sample_inputs_func=sample_inputs_like_fns, 18574 reference_inputs_func=reference_inputs_like_fns, 18575 supports_autograd=False, 18576 skips=( 18577 # Empty tensor data is garbage so it's hard to make comparisons with it. 18578 DecorateInfo(unittest.skip("Skipped!"), 18579 "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18580 # Empty tensor data is garbage so it's hard to make comparisons with it. 
18581 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 18582 # Empty tensor data is garbage so it's hard to make comparisons with it. 18583 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18584 # Empty tensor data is garbage so it's hard to make comparisons with it. 18585 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18586 # Empty tensor data is garbage so it's hard to make comparisons with it. 18587 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18588 # Empty tensor data is garbage so it's hard to make comparisons with it. 18589 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 18590 # Empty tensor data is garbage so it's hard to make comparisons with it. 18591 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18592 # Empty tensor data is garbage so it's hard to make comparisons with it. 18593 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 18594 # Empty tensor data is garbage so it's hard to make comparisons with it. 18595 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing'), 18596 # Empty tensor data is garbage so it's hard to make comparisons with it. 18597 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), 18598 DecorateInfo(unittest.skip("Expected: empty_like is not comparable"), 'TestCompositeCompliance', 18599 'test_operator'), 18600 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18601 )), 18602 OpInfo('zeros_like', 18603 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18604 supports_out=False, 18605 sample_inputs_func=sample_inputs_like_fns, 18606 supports_autograd=False, 18607 error_inputs_sparse_func=error_inputs_sparse_like_fns, 18608 sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), 18609 sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), 18610 sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), 18611 sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), 18612 sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), 18613 skips=( 18614 )), 18615 OpInfo('ones_like', 18616 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18617 supports_out=False, 18618 sample_inputs_func=sample_inputs_like_fns, 18619 supports_autograd=False, 18620 skips=( 18621 )), 18622 OpInfo('randn', 18623 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), 18624 op=lambda *args, **kwargs: wrapper_set_seed(torch.randn, *args, **kwargs), 18625 supports_out=True, 18626 sample_inputs_func=sample_inputs_randn, 18627 supports_autograd=False, 18628 skips=( 18629 # Tests that assume input is a tensor or sequence of tensors 18630 DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), 18631 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), 18632 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), 18633 # CPU randn generates different values based on the strides of out tensor 18634 DecorateInfo(unittest.expectedFailure, 
'TestCommon', 'test_out', device_type='cpu'), 18635 # randn fails to warn when resizing its out tensor 18636 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18637 # FX failed to normalize op - add the op to the op_skip list. 18638 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 18639 # Tests that assume input tensor has a meaningful effect on output tensor 18640 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 18641 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 18642 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 18643 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18644 # AssertionError: JIT Test does not execute any logic 18645 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18646 DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick'), 18647 )), 18648 OpInfo('randn_like', 18649 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.complex32), 18650 op=lambda inp, *args, **kwargs: 18651 wrapper_set_seed(torch.randn_like, inp, *args, **kwargs), 18652 supports_out=False, 18653 sample_inputs_func=sample_inputs_like_fns, 18654 supports_autograd=False, 18655 error_inputs_sparse_func=error_inputs_sparse_like_fns, 18656 sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_coo), 18657 sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csr), 18658 sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_csc), 18659 sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsr), 18660 sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_like_fns, layout=torch.sparse_bsc), 18661 skips=( 18662 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18663 # AssertionError: JIT Test does not execute any logic 18664 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18665 DecorateInfo(unittest.skip("Expected: randn_like is not comparable between dtypes"), 18666 'TestCommon', 'test_complex_half_reference_testing'), 18667 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18668 )), 18669 OpInfo('rand_like', 18670 dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex32, torch.complex64, torch.complex128), 18671 op=lambda inp, *args, **kwargs: 18672 wrapper_set_seed(torch.rand_like, inp, *args, **kwargs), 18673 supports_out=False, 18674 sample_inputs_func=sample_inputs_like_fns, 18675 supports_autograd=False, 18676 skips=( 18677 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18678 # AssertionError: JIT Test does not execute any logic 18679 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18680 DecorateInfo(unittest.skip("Expected: rand_like is not comparable between dtypes"), 18681 'TestCommon', 'test_complex_half_reference_testing'), 18682 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18683 )), 18684 OpInfo('randint', 18685 dtypes=all_types_and(torch.half, torch.bfloat16), 18686 op=lambda *args, **kwargs: 18687 wrapper_set_seed(torch.randint, *args, **kwargs), 18688 supports_out=False, 18689
sample_inputs_func=sample_inputs_randint, 18690 supports_autograd=False, 18691 skips=( 18692 # Tests that assume input is a tensor or sequence of tensors 18693 DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), 18694 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), 18695 DecorateInfo(unittest.skip("Test expects tensor input"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), 18696 # CPU randint generates different values based on the strides of out tensor 18697 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 18698 # randint fails to warn when resizing its out tensor 18699 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18700 # FX failed to normalize op - add the op to the op_skip list. 18701 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 18702 # Tests that assume input tensor has a meaningful effect on output tensor 18703 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 18704 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 18705 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 18706 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18707 # AssertionError: JIT Test does not execute any logic 18708 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18709 # Might need to skip until ROCm5.5 18710 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_multiple_devices', 18711 dtypes=[torch.float32, torch.int64], active_if=TEST_WITH_ROCM), 18712 )), 18713 OpInfo('randint_like', 18714 dtypes=all_types_and(torch.half, torch.bfloat16), 18715 op=lambda inp, *args, **kwargs: 18716 wrapper_set_seed(torch.randint_like, inp, *args, **kwargs), 18717 supports_out=False, 18718 sample_inputs_func=sample_inputs_randint_like, 18719 supports_autograd=False, 18720 skips=( 18721 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18722 # AssertionError: JIT Test does not execute any logic 18723 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18724 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18725 )), 18726 OpInfo('full_like', 18727 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18728 supports_out=False, 18729 sample_inputs_func=sample_inputs_full_like, 18730 supports_autograd=False, 18731 skips=( 18732 )), 18733 OpInfo('new_zeros', 18734 op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), 18735 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18736 supports_out=False, 18737 sample_inputs_func=sample_inputs_new_fns, 18738 skips=( 18739 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18740 ), 18741 supports_autograd=False), 18742 OpInfo('new_ones', 18743 op=lambda x, *args, **kwargs: x.new_ones(*args, **kwargs), 18744 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18745 supports_out=False, 18746 sample_inputs_func=sample_inputs_new_fns, 18747 skips=( 18748 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18749 ), 18750 supports_autograd=False), 18751 OpInfo('ones', 18752 op=torch.ones, 18753 
supports_autograd=False, 18754 supports_varargs=True, 18755 is_factory_function=True, 18756 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18757 supports_out=True, 18758 sample_inputs_func=sample_inputs_ones_zeros, 18759 skips=( 18760 # Tests that assume input is a tensor or sequence of tensors 18761 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 18762 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 18763 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 18764 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18765 18766 # Same failure as arange: cannot find linspace in captured graph 18767 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18768 18769 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 18770 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18771 )), 18772 OpInfo('zeros', 18773 op=torch.zeros, 18774 supports_autograd=False, 18775 is_factory_function=True, 18776 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18777 supports_out=True, 18778 sample_inputs_func=sample_inputs_ones_zeros, 18779 skips=( 18780 # Tests that assume input is a tensor or sequence of tensors 18781 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 18782 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 18783 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 18784 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18785 18786 # Same failure as arange: cannot find linspace in captured graph 18787 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18788 18789 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 18790 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18791 )), 18792 OpInfo('full', 18793 op=torch.full, 18794 supports_autograd=False, 18795 is_factory_function=True, 18796 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18797 supports_out=True, 18798 sample_inputs_func=sample_inputs_full, 18799 skips=( 18800 # Tests that assume input is a tensor or sequence of tensors 18801 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), 18802 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 18803 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 18804 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 18805 # Same failure as arange: cannot find linspace in captured graph 18806 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18807 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 
18808 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18809 # RuntimeError: UNSUPPORTED DTYPE: bool 18810 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), 18811 )), 18812 OpInfo('new_empty', 18813 op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), 18814 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18815 supports_out=False, 18816 sample_inputs_func=sample_inputs_new_fns, 18817 skips=( 18818 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18819 # Empty tensor data is garbage so it's hard to make comparisons with it. 18820 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18821 # Empty tensor data is garbage so it's hard to make comparisons with it. 18822 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 18823 # Empty tensor data is garbage so it's hard to make comparisons with it. 18824 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 18825 # Empty tensor data is garbage so it's hard to make comparisons with it. 18826 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18827 # Empty tensor data is garbage so it's hard to make comparisons with it. 18828 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18829 # Empty tensor data is garbage so it's hard to make comparisons with it. 18830 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 18831 # Empty tensor data is garbage so it's hard to make comparisons with it. 18832 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18833 # Empty tensor data is garbage so it's hard to make comparisons with it. 18834 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 18835 # Empty tensor data is garbage so it's hard to make comparisons with it. 18836 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), 18837 DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 'TestCompositeCompliance', 18838 'test_operator'), 18839 DecorateInfo(unittest.skip("Expected: new_empty is not comparable"), 18840 'TestCommon', 'test_complex_half_reference_testing'), 18841 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18842 ), 18843 supports_autograd=False), 18844 OpInfo('new_empty_strided', 18845 op=lambda x, *args, **kwargs: x.new_empty_strided(*args, **kwargs), 18846 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18847 supports_out=False, 18848 sample_inputs_func=partial(sample_inputs_new_fns, is_strided=True), 18849 supports_autograd=False, 18850 skips=( 18851 # FX failed to normalize op 18852 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18853 # Lazy tensor failures 18854 DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness'), 18855 DecorateInfo(unittest.skip("Skipped!"), 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), 18856 # Empty tensor data is garbage so it's hard to make comparisons with it. 
18857 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18858 'TestCommon', 'test_variant_consistency_eager'), 18859 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18860 'TestCommon', 'test_noncontiguous_samples'), 18861 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18862 'TestMathBits', 'test_conj_view'), 18863 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18864 'TestMathBits', 'test_neg_view'), 18865 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18866 'TestMathBits', 'test_neg_conj_view'), 18867 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18868 'TestCommon', 'test_non_standard_bool_values'), 18869 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18870 'TestCommon', 'test_complex_half_reference_testing'), 18871 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18872 'TestCompositeCompliance', 'test_operator'), 18873 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18874 'TestDecomp', 'test_comprehensive'), 18875 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18876 'TestDecomp', 'test_quick'), 18877 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18878 'TestJit', 'test_variant_consistency_jit'), 18879 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18880 'TestProxyTensorOpInfo', 'test_make_fx_exhaustive'), 18881 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18882 'TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive'), 18883 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18884 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), 18885 DecorateInfo(unittest.skip("Expected: new_empty_strided is not comparable"), 18886 'TestNNCOpInfo', 'test_nnc_correctness'), 18887 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18888 )), 18889 OpInfo('empty_strided', 18890 op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.empty_strided, inp, *args, **kwargs), 18891 dtypes=all_types_and_complex_and(torch.bfloat16, torch.bool, torch.half), 18892 supports_out=False, 18893 supports_autograd=False, 18894 sample_inputs_func=sample_inputs_empty_strided, 18895 skips=( 18896 # FX failed to normalize op - add the op to the op_skip list. 18897 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 18898 # AssertionError: JIT Test does not execute any logic 18899 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18900 # Empty tensor data is garbage so it's hard to make comparisons with it. 
18901 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 18902 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 18903 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), 18904 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 18905 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18906 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18907 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), 18908 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 'test_operator'), 18909 # Lazy tensor failures 18910 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestLazyOpInfo'), 18911 # RuntimeError: unsupported operation: more than one element of the written-to tensor refers to a single 18912 # memory location. Please clone() the tensor before performing the operation. 18913 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace'), 18914 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace'), 18915 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'), 18916 )), 18917 OpInfo('empty', 18918 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18919 sample_inputs_func=sample_inputs_empty, 18920 supports_autograd=False, 18921 skips=( 18922 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18923 # Empty tensor data is garbage so it's hard to make comparisons with it. 18924 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18925 # Empty tensor data is garbage so it's hard to make comparisons with it. 18926 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 18927 # Empty tensor data is garbage so it's hard to make comparisons with it. 18928 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 18929 # Empty tensor data is garbage so it's hard to make comparisons with it. 18930 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18931 # Empty tensor data is garbage so it's hard to make comparisons with it. 18932 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18933 # Empty tensor data is garbage so it's hard to make comparisons with it. 18934 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 18935 # Empty tensor data is garbage so it's hard to make comparisons with it. 18936 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 18937 # Empty tensor data is garbage so it's hard to make comparisons with it. 18938 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 18939 # Empty tensor data is garbage so it's hard to make comparisons with it. 
18940 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), 18941 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 'TestCompositeCompliance', 18942 'test_operator'), 18943 # requires_grad doesn't exist in the jit schema 18944 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 18945 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 18946 'TestCommon', 18947 'test_out'), 18948 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 18949 'TestCommon', 18950 'test_out_warning'), 18951 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 18952 'TestLazyOpInfo'), 18953 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 18954 'TestCommon', 'test_complex_half_reference_testing'), 18955 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 18956 )), 18957 OpInfo('eye', 18958 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 18959 sample_inputs_func=sample_inputs_eye, 18960 error_inputs_func=error_inputs_eye, 18961 supports_out=True, 18962 supports_autograd=False, 18963 skips=( 18964 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 18965 # TODO: same as this? 18966 # https://github.com/pytorch/pytorch/issues/81774 18967 # also see: arange, new_full 18968 # fails to match any schemas despite working in the interpreter 18969 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 18970 # fails to match any schemas despite working in the interpreter 18971 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 18972 # skip these tests since we have non tensor input 18973 DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), 18974 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 18975 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18976 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 18977 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18978 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 18979 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 18980 )), 18981 OpInfo('empty_permuted', 18982 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 18983 sample_inputs_func=sample_inputs_empty_permuted, 18984 error_inputs_func=error_inputs_empty_permuted, 18985 supports_out=False, 18986 supports_autograd=False, 18987 skips=( 18988 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 18989 # Empty tensor data is garbage so it's hard to make comparisons with it. 18990 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 18991 # Empty tensor data is garbage so it's hard to make comparisons with it. 18992 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'), 18993 # Empty tensor data is garbage so it's hard to make comparisons with it. 18994 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'), 18995 # Empty tensor data is garbage so it's hard to make comparisons with it. 
18996 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 18997 # Empty tensor data is garbage so it's hard to make comparisons with it. 18998 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 18999 # Empty tensor data is garbage so it's hard to make comparisons with it. 19000 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 19001 # Empty tensor data is garbage so it's hard to make comparisons with it. 19002 DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), 19003 # Empty tensor data is garbage so it's hard to make comparisons with it. 19004 DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), 19005 # Empty tensor data is garbage so it's hard to make comparisons with it. 19006 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'), 19007 DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance', 19008 'test_operator'), 19009 # requires_grad doesn't exist in the jit schema 19010 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 19011 DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 19012 'TestCommon', 19013 'test_out'), 19014 DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 19015 'TestCommon', 19016 'test_out_warning'), 19017 DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 19018 'TestLazyOpInfo'), 19019 DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 19020 'TestCommon', 'test_complex_half_reference_testing'), 19021 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 19022 )), 19023 OpInfo('scalar_tensor', 19024 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 19025 sample_inputs_func=sample_inputs_scalar_tensor, 19026 supports_autograd=False, 19027 supports_out=False, 19028 skips=( 19029 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 19030 # fails to match any schemas despite working in the interpreter 19031 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 19032 # fails to match any schemas despite working in the interpreter 19033 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19034 # skip these tests since we have non tensor input 19035 DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), 19036 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 19037 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 19038 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 19039 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 19040 )), 19041 OpInfo('new_full', 19042 op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), 19043 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 19044 supports_out=False, 19045 sample_inputs_func=sample_inputs_new_full, 19046 skips=( 19047 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19048 ), 19049 supports_autograd=False), 19050 OpInfo('multinomial', 19051 op=lambda inp, *args, **kwargs: 19052 wrapper_set_seed(torch.multinomial, inp, *args, **kwargs), 
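           # Nondeterministic ops like multinomial are routed through wrapper_set_seed (imported
           # from torch.testing._utils) so each variant runs under the same fixed RNG seed and the
           # outputs stay comparable across the function/method calls. Roughly, the helper does
           # something like the following sketch (illustrative only, not the real implementation):
           #
           #     def wrapper_set_seed(op, *args, **kwargs):
           #         torch.manual_seed(42)
           #         return op(*args, **kwargs)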
19053 method_variant=lambda inp, *args, **kwargs: 19054 wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs), 19055 dtypes=floating_types_and(torch.bfloat16, torch.half), 19056 supports_out=True, 19057 sample_inputs_func=sample_inputs_multinomial, 19058 error_inputs_func=error_inputs_multinomial, 19059 skips=( 19060 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19061 # Strides are not the same! 19062 # This may not be reproducible in CI 19063 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), 19064 # AssertionError: JIT Test does not execute any logic 19065 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19066 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 19067 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 19068 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), 19069 supports_autograd=False), 19070 OpInfo('normal', 19071 op=lambda inp, *args, **kwargs: 19072 wrapper_set_seed(torch.normal, inp, *args, **kwargs), 19073 # The inplace variant (Tensor.normal_) is different from torch.normal 19074 inplace_variant=None, 19075 dtypes=floating_types_and(torch.bfloat16, torch.half), 19076 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half), 19077 supports_out=True, 19078 sample_inputs_func=sample_inputs_normal_tensor_first, 19079 skips=( 19080 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19081 # Tensor-likes are not close! 19082 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 19083 # AssertionError: JIT Test does not execute any logic 19084 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19085 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
               # Computed gradient is incorrect -- would be an exfail but gradgrad somehow passes
               DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestFwdGradients'),
               DecorateInfo(unittest.skip("Gradients are incorrect!"), 'TestBwdGradients'),
               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
               # RuntimeError: Difference from {dtype} is larger with decomposition
               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
               # The inplace variant (Tensor.normal_) is different from torch.normal
               # inplace variant Tensor.normal_ is decomposed using randn_like()
               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'))),
    OpInfo('normal',
           # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
           variant_test_name='number_mean',
           op=lambda std, mean, *args, **kwargs:
               wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
           # The inplace variant (Tensor.normal_) is different from torch.normal
           inplace_variant=None,
           dtypes=floating_types_and(torch.bfloat16, torch.half),
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
           supports_out=True,
           sample_inputs_func=sample_inputs_normal_tensor_second,
           skips=(
               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_compare_cpu'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestEagerFusionOpInfo'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestOperators'),
               # AssertionError
               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive'),
               # AssertionError
               DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick'),
               # AssertionError in CUDA variant
               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', device_type='cuda'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestDeviceUtils', 'test_device_mode_ops'))),
    OpInfo('bernoulli',
           op=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
           # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
           inplace_variant=None,
           method_variant=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
           dtypes=floating_types_and(torch.bfloat16, torch.half),
           supports_out=True,
           supports_forward_ad=True,
supports_fwgrad_bwgrad=True, 19141 sample_inputs_func=sample_inputs_bernoulli, 19142 error_inputs_func=error_inputs_bernoulli, 19143 skips=( 19144 # vmap: We do not yet support calling random operations inside of vmap 19145 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), 19146 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19147 # AssertionError: JIT Test does not execute any logic 19148 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19149 # Expected RuntimeError when doing an unsafe cast from a result of 19150 # dtype torch.float32 into an out= with dtype torch.lon 19151 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 19152 # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 19153 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 19154 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), 19155 OpInfo('scatter_add', 19156 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 19157 sample_inputs_func=sample_inputs_scatter_add, 19158 error_inputs_func=error_inputs_scatter_and_scatter_add, 19159 supports_forward_ad=True, 19160 supports_fwgrad_bwgrad=True, 19161 ), 19162 OpInfo('stack', 19163 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19164 sample_inputs_func=sample_inputs_stack, 19165 assert_autodiffed=True, 19166 supports_forward_ad=True, 19167 supports_fwgrad_bwgrad=True, 19168 skips=( 19169 # https://github.com/pytorch/pytorch/issues/77046 19170 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 19171 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 19172 ), 19173 ), 19174 OpInfo('_chunk_cat', 19175 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19176 sample_inputs_func=sample_inputs_chunk_cat, 19177 error_inputs_func=error_inputs_chunk_cat, 19178 supports_autograd=False, 19179 supports_out=True, 19180 ), 19181 OpInfo('hstack', 19182 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19183 sample_inputs_func=sample_inputs_hstack_dstack_vstack, 19184 error_inputs_func=error_inputs_hstack_dstack_vstack, 19185 supports_forward_ad=True, 19186 supports_fwgrad_bwgrad=True, 19187 ), 19188 BinaryUfuncInfo('hypot', 19189 dtypes=floating_types_and(torch.bfloat16, torch.half), 19190 dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), 19191 supports_forward_ad=True, 19192 supports_fwgrad_bwgrad=True, 19193 supports_rhs_python_scalar=False), 19194 OpInfo('histogram', 19195 dtypes=floating_types(), 19196 dtypesIfCUDA=_dispatch_dtypes(), # histogram is only implemented on CPU 19197 sample_inputs_func=sample_inputs_histogram, 19198 supports_autograd=False, 19199 skips=( 19200 # JIT tests don't work with Tensor keyword arguments 19201 # https://github.com/pytorch/pytorch/issues/58507 19202 # RuntimeError: 19203 # undefined value tensor: 19204 # File "<string>", line 3 19205 # def the_method(i0): 19206 # return torch.histogram(i0, 1, weight=tensor(-0.5735, dtype=torch.float32), density=False) 19207 # ~~~~~~ <--- HERE 19208 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19209 # Not Implemented on XLA. 
19210 DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla'), 19211 )), 19212 OpInfo('histogramdd', 19213 dtypes=floating_types(), 19214 dtypesIfCUDA=_dispatch_dtypes(), # histogramdd is only implemented on CPU 19215 sample_inputs_func=sample_inputs_histogramdd, 19216 error_inputs_func=error_inputs_histogramdd, 19217 supports_autograd=False, 19218 skips=( 19219 # Not implemented on CUDA 19220 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='cuda'), 19221 # JIT tests don't work with Tensor keyword arguments 19222 # https://github.com/pytorch/pytorch/issues/58507 19223 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19224 )), 19225 OpInfo('histc', 19226 dtypes=floating_types_and(torch.bfloat16, torch.float16), 19227 dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64), 19228 sample_inputs_func=sample_inputs_histc, 19229 supports_out=True, 19230 supports_autograd=False, 19231 skips=( 19232 # CUDA histc returns a float tensor but does not correctly warn when passed an integral out tensor 19233 # "AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast 19234 # from a result of dtype torch.float32 into an out= with dtype torch.long" 19235 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cuda'), 19236 )), 19237 OpInfo('bincount', 19238 dtypes=integral_types_and(), 19239 sample_inputs_func=sample_inputs_bincount, 19240 supports_out=False, 19241 supports_autograd=False, 19242 skips=( 19243 # JIT tests don't work with Tensor keyword arguments 19244 # https://github.com/pytorch/pytorch/issues/58507 19245 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 19246 )), 19247 OpInfo('bucketize', 19248 dtypes=all_types_and(torch.float16, torch.bfloat16), 19249 dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), 19250 sample_inputs_func=sample_inputs_bucketize, 19251 reference_inputs_func=reference_inputs_bucketize, 19252 error_inputs_func=error_inputs_bucketize, 19253 supports_autograd=False, 19254 skips=( 19255 # JIT tests don't work with Tensor keyword arguments 19256 DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), 19257 )), 19258 OpInfo('searchsorted', 19259 dtypes=all_types_and(torch.bfloat16, torch.float16), 19260 dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16), 19261 sample_inputs_func=sample_inputs_searchsorted, 19262 supports_autograd=False, 19263 ref=reference_searchsorted, 19264 skips=( 19265 # JIT tests don't work with Tensor keyword arguments 19266 # https://github.com/pytorch/pytorch/issues/58507 19267 DecorateInfo(unittest.skip("Expected failure!"), 'TestJit', 'test_variant_consistency_jit'), 19268 )), 19269 OpInfo('cat', 19270 ref=_cat_np, 19271 aliases=('concat', 'concatenate'), 19272 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), 19273 sample_inputs_func=sample_inputs_cat_concat, 19274 reference_inputs_func=reference_inputs_cat, 19275 error_inputs_func=error_inputs_cat, 19276 # https://github.com/pytorch/pytorch/issues/80411 19277 gradcheck_fast_mode=True, 19278 supports_forward_ad=True, 19279 supports_fwgrad_bwgrad=True, 19280 # See https://github.com/pytorch/pytorch/issues/66357 19281 check_batched_forward_grad=False, 19282 assert_autodiffed=True, 19283 skips=( 19284 # https://github.com/pytorch/pytorch/issues/89353 19285 DecorateInfo(unittest.expectedFailure, 'TestCommon', 
'test_numpy_ref_mps'), 19286 # RuntimeError: Arguments for call not valid. 19287 # Expected a value of type 'List[Tensor]' for argument 19288 # 'tensors' but instead found type 'Tensor (inferred)'. 19289 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'), 19290 # see https://github.com/pytorch/pytorch/issues/71286 19291 DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'), 19292 # see https://github.com/pytorch/pytorch/issues/99806 19293 # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. 19294 DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), 19295 )), 19296 OpInfo('unbind', 19297 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19298 ref=reference_unbind, 19299 sample_inputs_func=sample_inputs_unbind, 19300 error_inputs_func=error_inputs_unbind, 19301 supports_forward_ad=True, 19302 supports_fwgrad_bwgrad=True, 19303 supports_gradgrad=True, 19304 supports_out=False, 19305 ), 19306 OpInfo('vstack', 19307 aliases=('row_stack',), 19308 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19309 sample_inputs_func=sample_inputs_hstack_dstack_vstack, 19310 error_inputs_func=error_inputs_hstack_dstack_vstack, 19311 supports_forward_ad=True, 19312 supports_fwgrad_bwgrad=True, 19313 skips=( 19314 # RuntimeError: _fn() Expected a value of type 19315 # 'Tensor (inferred)' for argument 't0' but instead found type 'tuple'. 19316 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_jit_alias_remapping'),)), 19317 OpInfo('dstack', 19318 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19319 sample_inputs_func=sample_inputs_hstack_dstack_vstack, 19320 error_inputs_func=error_inputs_hstack_dstack_vstack, 19321 supports_forward_ad=True, 19322 supports_fwgrad_bwgrad=True, 19323 # See https://github.com/pytorch/pytorch/pull/78358 19324 check_batched_forward_grad=False, 19325 ), 19326 OpInfo('unfold', 19327 op=lambda x, *args: x.unfold(*args), 19328 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19329 backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 19330 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 19331 gradcheck_fast_mode=True, 19332 supports_out=False, 19333 supports_forward_ad=True, 19334 supports_fwgrad_bwgrad=True, 19335 check_batched_gradgrad=False, 19336 # See https://github.com/pytorch/pytorch/issues/66357 19337 check_batched_forward_grad=False, 19338 skips=( 19339 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19340 # Skip operator schema test because this is a functional and not an operator 19341 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 19342 ), 19343 sample_inputs_func=sample_inputs_unfold), 19344 OpInfo('unfold_copy', 19345 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19346 backward_dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 19347 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 19348 gradcheck_fast_mode=True, 19349 supports_out=True, 19350 supports_forward_ad=True, 19351 supports_fwgrad_bwgrad=True, 19352 check_batched_gradgrad=False, 19353 # See https://github.com/pytorch/pytorch/issues/66357 19354 
check_batched_forward_grad=False, 19355 sample_inputs_func=sample_inputs_unfold), 19356 OpInfo('msort', 19357 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 19358 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 19359 check_batched_gradgrad=False, 19360 supports_forward_ad=True, 19361 supports_fwgrad_bwgrad=True, 19362 sample_inputs_func=sample_inputs_msort, 19363 skips=( 19364 )), 19365 OpInfo('movedim', 19366 aliases=('moveaxis',), 19367 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19368 supports_out=False, 19369 supports_forward_ad=True, 19370 supports_fwgrad_bwgrad=True, 19371 # See https://github.com/pytorch/pytorch/pull/78358 19372 check_batched_forward_grad=False, 19373 sample_inputs_func=sample_movedim_moveaxis, 19374 reference_inputs_func=reference_movedim_moveaxis, 19375 error_inputs_func=error_movedim_moveaxis), 19376 OpInfo('renorm', 19377 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 19378 sample_inputs_func=sample_inputs_renorm, 19379 error_inputs_func=error_inputs_renorm, 19380 supports_forward_ad=True, 19381 supports_fwgrad_bwgrad=True, 19382 skips=( 19383 # RuntimeError: Difference from float64 is larger with decomposition 19384 # linalg_vector_norm.default than original on output 0. 19385 # Original max diff: 2.560596747969157e-07, 19386 # Decomp max diff: 1.8187482915266173e-06 19387 DecorateInfo(unittest.skip("Inconsistent accuracy"), 'TestDecomp', 'test_comprehensive', 19388 device_type='cpu', dtypes=(torch.float16,)), 19389 )), 19390 ShapeFuncInfo('repeat', 19391 op=lambda x, dims: x.repeat(dims), 19392 ref=np.tile, 19393 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19394 # https://github.com/pytorch/pytorch/issues/80411 19395 gradcheck_fast_mode=True, 19396 supports_out=False, 19397 supports_forward_ad=True, 19398 supports_fwgrad_bwgrad=True, 19399 sample_inputs_func=sample_repeat_tile, 19400 skips=( 19401 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19402 )), 19403 OpInfo('squeeze', 19404 ref=_squeeze_ref, 19405 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19406 supports_out=False, 19407 assert_autodiffed=True, 19408 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 19409 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 19410 assert_jit_shape_analysis=True, 19411 supports_forward_ad=True, 19412 supports_fwgrad_bwgrad=True, 19413 # vmap does not support inplace views 19414 check_inplace_batched_forward_grad=False, 19415 # https://github.com/pytorch/pytorch/issues/66357 19416 check_batched_forward_grad=False, 19417 sample_inputs_func=sample_inputs_squeeze), 19418 OpInfo('squeeze', 19419 ref=_squeeze_ref, 19420 variant_test_name="multiple", 19421 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19422 supports_out=False, 19423 assert_autodiffed=True, 19424 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 19425 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 19426 supports_forward_ad=True, 19427 supports_fwgrad_bwgrad=True, 19428 # vmap does not support inplace views 19429 check_inplace_batched_forward_grad=False, 19430 # https://github.com/pytorch/pytorch/issues/66357 19431 check_batched_forward_grad=False, 19432 sample_inputs_func=sample_inputs_squeeze_multiple), 19433 UnaryUfuncInfo( 19434 'fill', 19435 ref=_fill_np, 19436 
method_variant=None, 19437 sample_kwargs=_fill_sample_kwargs, 19438 sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'value': True}), 19439 supports_forward_ad=True, 19440 supports_fwgrad_bwgrad=True, 19441 # https://github.com/pytorch/pytorch/issues/66357 19442 check_batched_forward_grad=False, 19443 dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), 19444 supports_out=False, 19445 skips=( 19446 # JIT has issue when op is passed as lambda 19447 # AssertionError: JIT Test does not execute any logic 19448 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19449 DecorateInfo(unittest.skip("No fill_ op"), 'TestCudaFuserOpInfo'), 19450 DecorateInfo(unittest.skip("No fill_ op"), 'TestNNCOpInfo'), 19451 )), 19452 OpInfo('resize_', 19453 op=lambda x, shape: x.clone().resize_(shape), 19454 method_variant=None, 19455 inplace_variant=torch.Tensor.resize_, 19456 # the test fails because resize_ doesn't work with imag views as expected by the test 19457 # https://github.com/pytorch/pytorch/issues/65945 19458 test_neg_view=False, 19459 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19460 supports_out=False, 19461 supports_autograd=False, 19462 skips=( 19463 # Cannot resize variables that require grad 19464 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), 19465 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19466 DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), 19467 ), 19468 sample_inputs_func=sample_inputs_resize_ops), 19469 OpInfo('resize_as_', 19470 op=lambda x, other: torch.resize_as_(x.clone(), other), 19471 method_variant=None, 19472 inplace_variant=torch.Tensor.resize_as_, 19473 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19474 supports_out=False, 19475 supports_autograd=False, 19476 skips=( 19477 # Cannot resize variables that require grad 19478 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes'), 19479 DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), 19480 ), 19481 sample_inputs_func=sample_inputs_resize_ops), 19482 OpInfo('take_along_dim', 19483 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19484 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19485 supports_inplace_autograd=False, 19486 supports_forward_ad=True, 19487 supports_fwgrad_bwgrad=True, 19488 # See https://github.com/pytorch/pytorch/pull/78358 19489 check_batched_forward_grad=False, 19490 sample_inputs_func=sample_inputs_take_along_dim, 19491 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 19492 decorators=( 19493 # RuntimeError: view size is not compatible with input tensor's size and stride 19494 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 19495 )), 19496 ShapeFuncInfo('tile', 19497 ref=np.tile, 19498 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19499 # https://github.com/pytorch/pytorch/issues/80411 19500 gradcheck_fast_mode=True, 19501 supports_out=False, 19502 supports_forward_ad=True, 19503 supports_fwgrad_bwgrad=True, 19504 sample_inputs_func=sample_repeat_tile), 19505 OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid' 19506 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 19507 
supports_out=False, 19508 supports_forward_ad=True, 19509 supports_fwgrad_bwgrad=True, 19510 # See https://github.com/pytorch/pytorch/pull/78358 19511 check_batched_forward_grad=False, 19512 decorators=[ 19513 DecorateInfo( 19514 toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), 19515 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' 19516 ), 19517 ], 19518 sample_inputs_func=sample_trapezoid), 19519 OpInfo('trapezoid', 19520 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 19521 supports_out=False, 19522 supports_forward_ad=True, 19523 supports_fwgrad_bwgrad=True, 19524 # See https://github.com/pytorch/pytorch/pull/78358 19525 check_batched_forward_grad=False, 19526 decorators=[ 19527 DecorateInfo( 19528 toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}), 19529 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda' 19530 ), 19531 ], 19532 sample_inputs_func=sample_trapezoid), 19533 OpInfo('cumulative_trapezoid', 19534 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), 19535 supports_forward_ad=True, 19536 supports_fwgrad_bwgrad=True, 19537 # See https://github.com/pytorch/pytorch/pull/78358 19538 check_batched_forward_grad=False, 19539 supports_out=False, 19540 decorators=( 19541 DecorateInfo( 19542 toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}), 19543 'TestInductorOpInfo', 'test_comprehensive', 19544 ), 19545 ), 19546 sample_inputs_func=sample_cumulative_trapezoid,), 19547 OpInfo('unsqueeze', 19548 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19549 supports_out=False, 19550 supports_forward_ad=True, 19551 supports_fwgrad_bwgrad=True, 19552 # See https://github.com/pytorch/pytorch/pull/78358 19553 check_batched_forward_grad=False, 19554 # vmap does not support inplace views 19555 check_inplace_batched_forward_grad=False, 19556 assert_jit_shape_analysis=True, 19557 assert_autodiffed=True, 19558 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 19559 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 19560 sample_inputs_func=sample_unsqueeze), 19561 OpInfo('unsqueeze_copy', 19562 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19563 supports_out=True, 19564 supports_forward_ad=True, 19565 supports_fwgrad_bwgrad=True, 19566 # See https://github.com/pytorch/pytorch/pull/78358 19567 check_batched_forward_grad=False, 19568 # vmap does not support inplace views 19569 check_inplace_batched_forward_grad=False, 19570 assert_jit_shape_analysis=True, 19571 assert_autodiffed=True, 19572 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 19573 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 19574 sample_inputs_func=sample_unsqueeze, 19575 skips=( 19576 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), 19577 DecorateInfo( 19578 unittest.expectedFailure, 19579 'TestJit', 19580 'test_variant_consistency_jit', 19581 dtypes=(torch.float32,), 19582 ), 19583 )), 19584 BinaryUfuncInfo('xlogy', 19585 aliases=('special.xlogy',), 19586 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 19587 promotes_int_to_float=True, 19588 supports_forward_ad=True, 19589 supports_fwgrad_bwgrad=True, 19590 supports_one_python_scalar=True, 19591 # We don't test 0 as the gradient will be NaN and it'll break 19592 rhs_make_tensor_kwargs=dict(low=0.01)), 19593 OpInfo('zero_', 19594 op=lambda x: torch.zero_(x.clone()), 19595 method_variant=None, 19596 
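           # Note: torch.zero_ mutates its argument, so the functional `op` above works on
           # x.clone(); the shared sample inputs are left untouched while the genuine in-place
           # path is still covered by inplace_variant below.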
inplace_variant=torch.Tensor.zero_, 19597 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19598 # https://github.com/pytorch/pytorch/issues/80411 19599 gradcheck_fast_mode=True, 19600 supports_out=False, 19601 supports_forward_ad=True, 19602 supports_fwgrad_bwgrad=True, 19603 supports_gradgrad=True, 19604 skips=( 19605 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 19606 ), 19607 sample_inputs_func=sample_inputs_zero_), 19608 OpInfo('logsumexp', 19609 aliases=('special.logsumexp',), 19610 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 19611 assert_autodiffed=True, 19612 supports_forward_ad=True, 19613 supports_fwgrad_bwgrad=True, 19614 gradcheck_fast_mode=False, 19615 sample_inputs_func=sample_inputs_logsumexp, 19616 reference_inputs_func=reference_inputs_logsumexp), 19617 OpInfo('trace', 19618 dtypes=all_types_and_complex(), 19619 dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16), 19620 error_inputs_func=error_inputs_trace, 19621 supports_inplace_autograd=False, 19622 supports_out=False, 19623 supports_forward_ad=True, 19624 supports_fwgrad_bwgrad=True, 19625 sample_inputs_func=sample_inputs_trace), 19626 OpInfo('transpose', 19627 ref=_numpy_ref_transpose, 19628 aliases=('swapdims', 'swapaxes'), 19629 assert_jit_shape_analysis=True, 19630 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 19631 supports_out=False, 19632 supports_forward_ad=True, 19633 supports_fwgrad_bwgrad=True, 19634 # vmap does not support inplace views 19635 check_inplace_batched_forward_grad=False, 19636 sample_inputs_func=sample_inputs_transpose_swapdims), 19637 OpInfo('T', 19638 op=lambda x: x.T, 19639 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 19640 supports_out=False, 19641 supports_forward_ad=True, 19642 supports_fwgrad_bwgrad=True, 19643 skips=( 19644 # lambda impl 19645 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 19646 DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), 19647 sample_inputs_func=sample_inputs_T, 19648 error_inputs_func=error_inputs_T), 19649 OpInfo('H', 19650 op=lambda x: x.H, 19651 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 19652 supports_out=False, 19653 supports_forward_ad=True, 19654 supports_fwgrad_bwgrad=True, 19655 # See https://github.com/pytorch/pytorch/pull/78358 19656 check_batched_forward_grad=False, 19657 skips=( 19658 # lambda impl 19659 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 19660 DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), 19661 sample_inputs_func=sample_inputs_T), 19662 OpInfo('mT', 19663 op=lambda x: x.mT, 19664 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 19665 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 19666 gradcheck_fast_mode=True, 19667 supports_out=False, 19668 supports_forward_ad=True, 19669 supports_fwgrad_bwgrad=True, 19670 skips=( 19671 # lambda impl 19672 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 19673 DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), 19674 sample_inputs_func=sample_inputs_adjoint), 19675 OpInfo('mH', 19676 op=lambda x: x.mH, 19677 
aliases=('adjoint',), 19678 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), 19679 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 19680 gradcheck_fast_mode=True, 19681 supports_out=False, 19682 supports_forward_ad=True, 19683 supports_fwgrad_bwgrad=True, 19684 # See https://github.com/pytorch/pytorch/pull/78358 19685 check_batched_forward_grad=False, 19686 skips=( 19687 # lambda impl 19688 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 19689 DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),), 19690 sample_inputs_func=sample_inputs_adjoint), 19691 OpInfo('tril', 19692 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 19693 supports_forward_ad=True, 19694 supports_fwgrad_bwgrad=True, 19695 error_inputs_func=error_inputs_tril_triu, 19696 sample_inputs_func=sample_inputs_tril_triu), 19697 OpInfo('triu', 19698 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), 19699 supports_forward_ad=True, 19700 supports_fwgrad_bwgrad=True, 19701 error_inputs_func=error_inputs_tril_triu, 19702 sample_inputs_func=sample_inputs_tril_triu), 19703 OpInfo('triu_indices', 19704 dtypes=_dispatch_dtypes((torch.int32, torch.int64)), 19705 sample_inputs_func=sample_inputs_trilu_indices, 19706 ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.triu_indices(h, ofs, w), dtype=dtype), 19707 supports_out=False, 19708 supports_autograd=False, 19709 skips=( 19710 # skip these tests since we have non tensor input 19711 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), 19712 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 19713 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 19714 DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), 19715 )), 19716 OpInfo('tril_indices', 19717 dtypes=_dispatch_dtypes((torch.int32, torch.int64)), 19718 sample_inputs_func=sample_inputs_trilu_indices, 19719 ref=lambda h, w, ofs=0, dtype=torch.long, device='cpu' : np.array(np.tril_indices(h, ofs, w), dtype=dtype), 19720 supports_out=False, 19721 supports_autograd=False, 19722 skips=( 19723 # skip these tests since we have non tensor input 19724 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), 19725 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 19726 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 19727 DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), 19728 )), 19729 OpInfo('kron', 19730 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 19731 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), 19732 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 19733 gradcheck_fast_mode=True, 19734 supports_inplace_autograd=False, 19735 supports_forward_ad=True, 19736 supports_fwgrad_bwgrad=True, 19737 sample_inputs_func=sample_inputs_kron, 19738 decorators=( 19739 # RuntimeError: view size is not compatible with input tensor's size and stride 19740 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 19741 )), 19742 OpInfo('inner', 19743 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 19744 
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 19745 dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), 19746 supports_forward_ad=True, 19747 supports_fwgrad_bwgrad=True, 19748 # See https://github.com/pytorch/pytorch/pull/78358 19749 check_batched_forward_grad=False, 19750 sample_inputs_func=sample_inputs_inner, 19751 ), 19752 OpInfo('tensordot', 19753 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 19754 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 19755 dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16), 19756 supports_forward_ad=True, 19757 supports_fwgrad_bwgrad=True, 19758 # See https://github.com/pytorch/pytorch/pull/78358 19759 check_batched_forward_grad=False, 19760 sample_inputs_func=sample_inputs_tensordot, 19761 skips=( 19762 # Skip operator schema test because this is a functional and not an operator. 19763 # Reference: https://github.com/pytorch/pytorch/issues/54574 19764 DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 19765 ) 19766 ), 19767 OpInfo('to_sparse', 19768 op=lambda x, *args: x.to_sparse(*args), 19769 sample_inputs_func=sample_inputs_to_sparse, 19770 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19771 backward_dtypes=floating_types(), 19772 backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 19773 supports_out=False, 19774 supports_sparse_csr=True, 19775 supports_sparse_csc=True, 19776 check_batched_grad=False, 19777 check_batched_gradgrad=False, 19778 skips=( 19779 # NotImplementedError: Could not run 'aten::normal_' with arguments from the 'SparseCPU' backend 19780 DecorateInfo(unittest.skip(""), 'TestCommon', 'test_noncontiguous_samples'), 19781 # TODO: FIXME: complex inputs requiring grad error in forward 19782 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), 19783 # lambda impl 19784 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 19785 # Allowed exception: sparse tensors don't have strides 19786 DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_operator'), 19787 DecorateInfo(unittest.skip("Allowed exception"), 'TestCompositeCompliance', 'test_backward'), 19788 DecorateInfo(unittest.skip("Allowed exception"), 'TestTags', 'test_tags'), 19789 # TODO: implement csr.to_sparse(sample_dim) where sampled_dim is 1. 19790 DecorateInfo(unittest.skip("csr.to_sparse(1) not implemented. Skipped!"), 19791 'TestSparseCSR', 'test_sparse_csr_consistency'), 19792 # Compiler issue on ROCm. Might need to skip until ROCm5.5 19793 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', 19794 dtypes=[torch.bool], active_if=TEST_WITH_ROCM), 19795 ) 19796 ), 19797 OpInfo('logcumsumexp', 19798 dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half), 19799 backward_dtypes=floating_and_complex_types_and(torch.bfloat16), 19800 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16), 19801 supports_forward_ad=True, 19802 supports_fwgrad_bwgrad=True, 19803 skips=( 19804 # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 
19805 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='cuda'), 19806 # RuntimeError: "max_values_cpu" not implemented for 'ComplexDouble' 19807 # Falling back to non-numerically stablized exp, causing nan in the results. 19808 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]), 19809 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]), 19810 DecorateInfo( 19811 toleranceOverride({ 19812 torch.float16: tol(atol=7e-5, rtol=6e-3), 19813 }), 19814 "TestInductorOpInfo", 19815 "test_comprehensive", 19816 device_type="cuda" 19817 ), 19818 ), 19819 sample_inputs_func=sample_inputs_logcumsumexp, 19820 error_inputs_func=error_inputs_logcumsumexp), 19821 UnaryUfuncInfo('sigmoid', 19822 aliases=('special.expit', 'nn.functional.sigmoid'), 19823 aten_backward_name='sigmoid_backward', 19824 ref=reference_sigmoid if TEST_SCIPY else None, 19825 decorators=(precisionOverride({torch.float16: 1e-2, 19826 torch.complex64: 1e-1, 19827 torch.bfloat16: 1e-2}),), 19828 skips=( 19829 # Reference: https://github.com/pytorch/pytorch/issues/56012 19830 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 19831 dtypes=[torch.complex64, torch.cdouble]), 19832 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 19833 dtypes=[torch.chalf, torch.complex64, torch.cdouble])), 19834 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 19835 dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16), 19836 supports_forward_ad=True, 19837 supports_fwgrad_bwgrad=True, 19838 promotes_int_to_float=True, 19839 assert_autodiffed=True, 19840 # sigmoid(z) = 1 / (1 + exp(-z)), at z = j * pi * odd_number, the denominator is zero 19841 reference_numerics_filter=NumericsFilter( 19842 condition=lambda x: (close_to_int(x / (math.pi * 1j)) 19843 if x.is_complex() else x.new_tensor(False, dtype=torch.bool)), 19844 safe_val=0)), 19845 UnaryUfuncInfo('digamma', 19846 ref=scipy.special.digamma if TEST_SCIPY else None, 19847 aliases=('special.psi', 'special.digamma',), 19848 decorators=(precisionOverride({torch.float16: 5e-1}),), 19849 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 19850 dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), 19851 supports_forward_ad=True, 19852 supports_fwgrad_bwgrad=True, 19853 promotes_int_to_float=True), 19854 UnaryUfuncInfo('erf', 19855 ref=scipy.special.erf if TEST_SCIPY else None, 19856 aliases=('special.erf', ), 19857 decorators=(precisionOverride({torch.float16: 1e-2, 19858 torch.bfloat16: 1e-2}),), 19859 skips=( 19860 DecorateInfo(unittest.skip("Skipped! 
sparse backward not supported"), 19861 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), 19862 19863 ), 19864 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 19865 assert_autodiffed=True, 19866 assert_jit_shape_analysis=True, 19867 supports_sparse=True, 19868 supports_sparse_csr=True, 19869 supports_sparse_csc=True, 19870 supports_sparse_bsr=True, 19871 supports_sparse_bsc=True, 19872 supports_forward_ad=True, 19873 supports_fwgrad_bwgrad=True, 19874 promotes_int_to_float=True), 19875 UnaryUfuncInfo('erfc', 19876 ref=scipy.special.erfc if TEST_SCIPY else None, 19877 aliases=('special.erfc', ), 19878 decorators=(precisionOverride({torch.float16: 1e-2, 19879 torch.bfloat16: 1e-2}),), 19880 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 19881 assert_autodiffed=True, 19882 supports_forward_ad=True, 19883 supports_fwgrad_bwgrad=True, 19884 promotes_int_to_float=True), 19885 UnaryUfuncInfo('erfinv', 19886 ref=scipy.special.erfinv if TEST_SCIPY else None, 19887 aliases=('special.erfinv', ), 19888 decorators=(precisionOverride({torch.float16: 1e-2, 19889 torch.bfloat16: 1e-2, 19890 torch.float32: 1e-4}),), 19891 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 19892 dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), 19893 supports_sparse_csr=True, 19894 supports_sparse_csc=True, 19895 supports_sparse_bsr=True, 19896 supports_sparse_bsc=True, 19897 supports_forward_ad=True, 19898 supports_fwgrad_bwgrad=True, 19899 promotes_int_to_float=True, 19900 domain=(-1, 1), 19901 skips=( 19902 # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 19903 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 19904 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 19905 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', 19906 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 19907 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', 19908 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 19909 )), 19910 OpInfo("nn.functional.smooth_l1_loss", 19911 ref=reference_smooth_l1_loss, 19912 sample_inputs_func=sample_inputs_smooth_l1_loss, 19913 dtypes=floating_types_and(torch.float16, torch.bfloat16), 19914 backward_dtypes=floating_types_and(torch.bfloat16), 19915 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 19916 backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 19917 supports_out=False, 19918 supports_forward_ad=True, 19919 supports_fwgrad_bwgrad=True, 19920 skips=( 19921 # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED 19922 # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch. 
               DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)),
    OpInfo(
        "nn.functional.l1_loss",
        ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)),
        sample_inputs_func=sample_inputs_l1_loss,
        error_inputs_func=error_inputs_l1_loss,
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
            # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
                dtypes=(torch.float32,),
            ),
        ),
    ),
    UnaryUfuncInfo('lgamma',
                   ref=reference_lgamma if TEST_SCIPY else None,
                   aliases=('special.gammaln', ),
                   decorators=(precisionOverride({torch.float16: 7e-1}),),
                   dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   promotes_int_to_float=True,
                   skips=(
                       # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                    dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
                   ),
                   # lgamma has multiple singularities at x <= 0
                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
    OpInfo(
        'logdet',
        dtypes=floating_and_complex_types(),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_linalg_det_logdet_slogdet,
        decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
    # `log_softmax` supports different dtypes based on whether the `dtype` argument
    # is passed or not. Hence two OpInfo entries, one with dtype and other without.
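    # For example (illustrative): integer inputs only work when an explicit dtype is given,
    #   x = torch.arange(3)
    #   torch.log_softmax(x, dim=0)                       # errors: floating point input expected
    #   torch.log_softmax(x, dim=0, dtype=torch.float32)  # ok: input is cast to float32 first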
19972 OpInfo( 19973 'log_softmax', 19974 aliases=('special.log_softmax', 'nn.functional.log_softmax'), 19975 supports_out=True, 19976 aten_backward_name='_log_softmax_backward_data', 19977 dtypes=floating_types_and(torch.float16, torch.bfloat16), 19978 sample_inputs_func=sample_inputs_softmax_variant, 19979 supports_forward_ad=True, 19980 supports_fwgrad_bwgrad=True, 19981 assert_autodiffed=True), 19982 OpInfo( 19983 'log_softmax', 19984 variant_test_name='with_dtype', 19985 aliases=('special.log_softmax', 'nn.functional.log_softmax'), 19986 supports_out=True, 19987 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 19988 sample_inputs_func=partial(sample_inputs_softmax_variant, with_dtype=True), 19989 supports_forward_ad=True, 19990 supports_fwgrad_bwgrad=True, 19991 assert_autodiffed=True), 19992 UnaryUfuncInfo('logit', 19993 aten_backward_name='logit_backward', 19994 ref=scipy.special.logit if TEST_SCIPY else None, 19995 domain=(0, 1), 19996 aliases=('special.logit', ), 19997 supports_forward_ad=True, 19998 supports_fwgrad_bwgrad=True, 19999 promotes_int_to_float=True, 20000 decorators=(precisionOverride({torch.bfloat16: 5e-1, 20001 torch.float16: 5e-1}),), 20002 dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), 20003 sample_inputs_func=sample_inputs_logit), 20004 OpInfo('where', 20005 # Currently only the `input` is tested in gradcheck. 20006 # If we pass `condition` first, none of the input which supports 20007 # autograd will be tested. Hence the following lambda. 20008 op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs), 20009 ref=lambda self, condition, other: np.where(condition, self, other), 20010 sample_inputs_func=sample_inputs_where, 20011 reference_inputs_func=reference_inputs_where, 20012 error_inputs_func=error_inputs_where, 20013 supports_forward_ad=True, 20014 supports_fwgrad_bwgrad=True, 20015 decorators=( 20016 DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),), 20017 skips=( 20018 # lambda impl 20019 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 20020 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 20021 ), 20022 dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf)), 20023 OpInfo('nonzero', 20024 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 20025 sample_inputs_func=sample_inputs_nonzero, 20026 supports_autograd=False, 20027 skips=( 20028 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20029 # nonzero(): argument 'out' must be Tensor, not tuple 20030 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 20031 # https://github.com/pytorch/pytorch/issues/67458 20032 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20033 # nonzero is not raising a warning when the out is resized 20034 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 20035 # Can't find schemas for this operator for some reason 20036 DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), 20037 # Compiler issue on ROCm. 
Might need to skip until ROCm5.5 20038 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', 20039 dtypes=[torch.bool], active_if=TEST_WITH_ROCM), 20040 )), 20041 OpInfo('nonzero_static', 20042 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), 20043 sample_inputs_func=sample_inputs_nonzero_static, 20044 supports_out=False, 20045 supports_autograd=False, 20046 decorators=[onlyCPU], 20047 skips=( 20048 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 20049 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), 20050 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), 20051 DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), 20052 DecorateInfo(unittest.expectedFailure, 'TestVmapOperatorsOpInfo', 'test_op_has_batch_rule'), 20053 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values', 20054 dtypes=[torch.bool], active_if=TEST_WITH_ROCM), 20055 )), 20056 # Following tests are for jiterator's python interface 20057 # Jiterator can be used to author elementwise CUDA kernel 20058 # jiterator._create_jit_fn returns a callable that behaves like a regular pytorch op 20059 # See create_jit_fn in jiterator.py for more information 20060 UnaryUfuncInfo( 20061 'jiterator_unary', 20062 op=torch.cuda.jiterator._create_jit_fn("template <typename T> T unary(T x) { return x * x + x; }"), 20063 ref=lambda x: x * x + x, 20064 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 20065 supports_out=False, 20066 supports_autograd=False, # jiterator ops doesn't have backward defined 20067 decorators=[ 20068 onlyCUDA, 20069 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 20070 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), 20071 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 20072 'TestUnaryUfuncs', 'test_reference_numerics_hard'), 20073 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 20074 'TestUnaryUfuncs', 'test_reference_numerics_normal'), 20075 DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}), 20076 'TestUnaryUfuncs', 'test_reference_numerics_small'), 20077 ], 20078 skips=( 20079 # Jiterator ops doesn't support neg or conj view 20080 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 20081 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 20082 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 20083 # Jiterator ops doesn't support CompositeCompliantTensor 20084 # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped 20085 DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), 20086 # Skip reference_numerics tests for bool type, as the defined function doesn't work for bool 20087 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 20088 dtypes=[torch.bool]), 20089 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_hard', 20090 dtypes=[torch.bool]), 20091 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal', 20092 dtypes=[torch.bool]), 20093 # ROCm generates -inf+infj instead of nan+infj for complex64 for some of the results 20094 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 
'test_reference_numerics_large', 20095 dtypes=[torch.complex64], active_if=TEST_WITH_ROCM), 20096 # Expected failure: torch.jiterator_unary is not a valid op 20097 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20098 # Skip Nvfuser 20099 DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), 20100 ) 20101 ), 20102 BinaryUfuncInfo( 20103 'jiterator_binary', 20104 op=torch.cuda.jiterator._create_jit_fn( 20105 "template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1), 20106 ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \ 20107 else np.add(input, np.multiply(alpha, other)), 20108 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 20109 sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14), 20110 supports_out=False, 20111 supports_autograd=False, # jiterator ops doesn't have backward defined 20112 supports_rhs_python_scalar=False, 20113 decorators=[onlyCUDA], 20114 skips=( 20115 # Jiterator ops doesn't support neg or conj view 20116 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 20117 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 20118 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 20119 # Jiterator ops doesn't support CompositeCompliantTensor 20120 # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped 20121 DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), 20122 # Expected failure: torch.jiterator_binary is not a valid op 20123 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20124 # Skip Nvfuser 20125 DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), 20126 ) 20127 ), 20128 OpInfo( 20129 'jiterator_4inputs_with_extra_args', 20130 op=torch.cuda.jiterator._create_jit_fn( 20131 "template <typename T> T binary(T i0, T i1, T i2, T i3, T alpha, T beta) { return alpha * i0 + beta * i1 + i2 + i3; }", 20132 alpha=1, beta=1), 20133 ref=lambda i0, i1, i2, i3, *, alpha=1, beta=1: alpha * i0 + beta * i1 + i2 + i3, 20134 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 20135 sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=4, alpha=3.14, beta=-4.20), 20136 supports_out=False, 20137 supports_autograd=False, # jiterator ops doesn't have backward defined 20138 decorators=[onlyCUDA], 20139 skips=( 20140 # Jiterator ops doesn't support neg or conj view 20141 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 20142 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 20143 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 20144 # Jiterator ops doesn't support CompositeCompliantTensor 20145 # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped 20146 DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), 20147 # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op 20148 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20149 # Skip Nvfuser 20150 DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), 20151 ) 20152 ), 20153 BinaryUfuncInfo( 20154 'jiterator_binary_return_by_ref', 20155 op=torch.cuda.jiterator._create_multi_output_jit_fn( 20156 """ 20157 template <typename T> 20158 void binary_return_by_ref(T i0, T i1, 
T& out0) { 20159 out0 = i0 + i1; 20160 } 20161 """, 20162 num_outputs=1), 20163 ref=operator.add, 20164 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 20165 sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42), 20166 supports_out=False, 20167 supports_autograd=False, # jiterator ops doesn't have backward defined 20168 supports_rhs_python_scalar=False, 20169 decorators=[onlyCUDA], 20170 skips=( 20171 # Jiterator ops doesn't support neg or conj view 20172 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 20173 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 20174 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 20175 # Jiterator ops doesn't support CompositeCompliantTensor 20176 # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped 20177 DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), 20178 # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op 20179 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20180 # Skip Nvfuser 20181 DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), 20182 ) 20183 ), 20184 OpInfo( 20185 'jiterator_2inputs_2outputs', 20186 op=torch.cuda.jiterator._create_multi_output_jit_fn( 20187 """ 20188 template <typename T> 20189 void binary_2outputs(T i0, T i1, T& out0, T& out1) { 20190 out0 = i0 + i1; 20191 out1 = i0 - i1; 20192 } 20193 """, 20194 num_outputs=2), 20195 ref=lambda i0, i1, *, alpha=1: (i0 + i1, i0 - i1), 20196 dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool), 20197 sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2), 20198 supports_out=False, 20199 supports_autograd=False, # jiterator ops doesn't have backward defined 20200 decorators=[onlyCUDA], 20201 skips=( 20202 # Jiterator ops doesn't support neg or conj view 20203 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 20204 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 20205 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 20206 # Jiterator ops doesn't support CompositeCompliantTensor 20207 # Following test should expectedFailure, but it's causing cascading failures in CUDA, thus skipped 20208 DecorateInfo(unittest.skip("skip"), 'TestCompositeCompliance', 'test_operator'), 20209 # Expected failure: torch.jiterator_4inputs_with_extra_args is not a valid op 20210 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20211 # Skip Nvfuser 20212 DecorateInfo(unittest.skip('Skipped!'), 'TestCudaFuserOpInfo'), 20213 ) 20214 ), 20215 # `torch.norm` has multiple code paths depending on the value of `p`. 20216 # These paths have different dtype support. Also JIT supports, 20217 # most variants but not all of them. So we split the OpInfo entries, 20218 # for `norm` based on the code-paths and JIT support. 20219 OpInfo( 20220 "norm", 20221 sample_inputs_func=sample_inputs_norm, 20222 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 20223 # TODO Benchmark again with the new implementation 20224 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 20225 gradcheck_fast_mode=True, 20226 check_batched_forward_grad=False, 20227 supports_forward_ad=True, 20228 supports_fwgrad_bwgrad=True, 20229 skips=( 20230 # Dispatches in Python to vector_norm. 
Not sure how to make this test happy 20231 # Happens to pass on complex64. Also a mystery 20232 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 20233 dtypes=(torch.float32,)),) 20234 ), 20235 OpInfo('norm', 20236 variant_test_name='nuc', 20237 sample_inputs_func=sample_inputs_norm_nuc, 20238 decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], 20239 check_batched_gradgrad=False, 20240 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients 20241 # got: Could not allocate memory to change Tensor SizesAndStrides! 20242 check_batched_forward_grad=False, 20243 supports_forward_ad=True, 20244 supports_fwgrad_bwgrad=True, 20245 dtypes=floating_and_complex_types(), 20246 dtypesIfCUDA=floating_and_complex_types(), 20247 skips=( 20248 # Dispatches in Python to matrix_norm. Not sure how to make this test happy 20249 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 20250 dtypes=(torch.complex64, torch.float32,)),) 20251 ), 20252 OpInfo('norm', 20253 variant_test_name='fro', 20254 sample_inputs_func=sample_inputs_norm_fro, 20255 dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), 20256 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), 20257 supports_forward_ad=True, 20258 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients 20259 # got: Could not allocate memory to change Tensor SizesAndStrides! 20260 check_batched_forward_grad=False, 20261 supports_fwgrad_bwgrad=True, 20262 skips=( 20263 # MPS has some mild accuracy issues for float16. We divide the tolerances by 10 20264 DecorateInfo( 20265 toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}), 20266 'TestConsistency', 20267 'test_output_match', 20268 20269 ), 20270 # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 20271 DecorateInfo( 20272 unittest.skip("Skipped!"), 20273 'TestSchemaCheckModeOpInfo', 20274 'test_schema_correctness', 20275 dtypes=(torch.complex64, torch.complex128)), 20276 # Dispatches in Python to vector_norm. Not sure how to make this test happy 20277 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 20278 dtypes=(torch.complex64, torch.float32,)),) 20279 ), 20280 OpInfo( 20281 "norm", 20282 variant_test_name="inf", 20283 sample_inputs_func=sample_inputs_norm_inf, 20284 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 20285 supports_forward_ad=True, 20286 check_batched_forward_grad=False, 20287 supports_fwgrad_bwgrad=True, 20288 # fast gradcheck produces NaNs 20289 gradcheck_fast_mode=False, 20290 skips=( 20291 DecorateInfo( 20292 toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), 20293 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', 20294 ), 20295 # Dispatches in Python to vector_norm. Not sure how to make this test happy 20296 # Happens to pass on complex64. 
Also a mystery 20297 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', 20298 dtypes=(torch.float32,)) 20299 ), 20300 ), 20301 OpInfo('t', 20302 sample_inputs_func=sample_inputs_t, 20303 supports_out=False, 20304 supports_forward_ad=True, 20305 supports_fwgrad_bwgrad=True, 20306 # See https://github.com/pytorch/pytorch/pull/78358 20307 check_batched_forward_grad=False, 20308 # vmap does not support inplace views 20309 check_inplace_batched_forward_grad=False, 20310 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 20311 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 20312 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20313 assert_autodiffed=True, 20314 error_inputs_func=error_inputs_t), 20315 OpInfo('t_copy', 20316 sample_inputs_func=sample_inputs_t, 20317 supports_out=True, 20318 supports_forward_ad=True, 20319 supports_fwgrad_bwgrad=True, 20320 # See https://github.com/pytorch/pytorch/pull/78358 20321 check_batched_forward_grad=False, 20322 # vmap does not support inplace views 20323 check_inplace_batched_forward_grad=False, 20324 autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused 20325 autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused 20326 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20327 assert_autodiffed=True, 20328 error_inputs_func=error_inputs_t), 20329 OpInfo( 20330 "nn.functional.dropout", 20331 op=lambda input, *args, **kwargs: 20332 wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs), 20333 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20334 skips=( 20335 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20336 # Probably because we have used lambda for the op here 20337 # AssertionError: JIT Test does not execute any logic 20338 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20339 # inplace variant dispatches to dropout kernel, while on CUDA 20340 # the op dispatches to _fused_dropout (with a few more conditions) 20341 # hence, different values and this skip here 20342 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), 20343 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), 20344 supports_forward_ad=True, 20345 supports_fwgrad_bwgrad=True, 20346 # https://github.com/pytorch/pytorch/issues/66357 20347 check_batched_forward_grad=False, 20348 supports_out=False, 20349 sample_inputs_func=sample_inputs_dropout, 20350 inplace_variant=lambda input, *args, **kwargs: 20351 wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), 20352 OpInfo( 20353 "native_dropout_backward", 20354 op=torch.ops.aten.native_dropout_backward.default, 20355 aten_name="native_dropout_backward", 20356 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 20357 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 20358 supports_out=False, 20359 sample_inputs_func=sample_inputs_dropout_backward, 20360 skips=( 20361 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 20362 # Lazy tensor failures 20363 DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), 20364 # These tests fail only when built with ASAN 20365 DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN), 20366 
DecorateInfo( 20367 unittest.skip("Fails with ASAN"), 20368 'TestLazyOpInfo', 20369 'test_correctness_with_reusing_ir', 20370 active_if=TEST_WITH_ASAN 20371 ), 20372 ), 20373 ), 20374 OpInfo( 20375 "nn.functional.dropout2d", 20376 op=lambda input, *args, **kwargs: 20377 wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs), 20378 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20379 skips=( 20380 # lambda impl 20381 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20382 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20383 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), 20384 supports_forward_ad=True, 20385 supports_fwgrad_bwgrad=True, 20386 supports_out=False, 20387 check_batched_forward_grad=False, 20388 # As per the docs, valid input dims are (3, 4) 20389 sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(3, 4)), 20390 inplace_variant=lambda input, *args, **kwargs: 20391 wrapper_set_seed(torch.nn.functional.dropout2d, input, *args, **kwargs, inplace=True)), 20392 OpInfo( 20393 "nn.functional.dropout3d", 20394 op=lambda input, *args, **kwargs: 20395 wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs), 20396 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20397 skips=( 20398 # lambda impl 20399 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20400 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20401 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), 20402 supports_forward_ad=True, 20403 supports_fwgrad_bwgrad=True, 20404 supports_out=False, 20405 check_batched_forward_grad=False, 20406 # As per the docs, valid input dims are (4, 5) 20407 sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), 20408 inplace_variant=lambda input, *args, **kwargs: 20409 wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), 20410 OpInfo( 20411 "nn.functional.alpha_dropout", 20412 op=lambda input, *args, **kwargs: 20413 wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs), 20414 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20415 gradcheck_wrapper=wrapper_set_seed, 20416 supports_forward_ad=True, 20417 supports_fwgrad_bwgrad=True, 20418 supports_out=False, 20419 sample_inputs_func=sample_inputs_dropout, 20420 check_batched_forward_grad=False, 20421 inplace_variant=lambda input, *args, **kwargs: 20422 wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True), 20423 skips=( 20424 # lambda impl 20425 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20426 # AssertionError: Tensor-likes are not close! 
20427 # Fails in cuda11.7 20428 # Error Log: https://github.com/pytorch/pytorch/actions/runs/3440108478/jobs/5738475757 20429 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), 20430 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), 20431 # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype 20432 # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases 20433 OpInfo( 20434 "nn.functional.feature_alpha_dropout", 20435 op=lambda input, *args, **kwargs: 20436 wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), 20437 variant_test_name="with_train", 20438 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20439 skips=( 20440 # lambda impl 20441 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20442 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20443 # torch.autograd.gradcheck.GradcheckError: While computing batched gradients, got: 20444 # vmap: We do not yet support calling random operations inside of vmap. 20445 # Please perform random operations outside of vmap as a workaround 20446 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_forward_mode_AD"), 20447 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', "test_inplace_forward_mode_AD"), 20448 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')), 20449 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 20450 gradcheck_fast_mode=True, 20451 supports_forward_ad=True, 20452 supports_fwgrad_bwgrad=True, 20453 supports_out=False, 20454 # As per the docs, valid input dims are (4, 5) 20455 sample_inputs_func=partial(sample_inputs_dropout, train=True, valid_input_dim=(4, 5)), 20456 inplace_variant=lambda input, *args, **kwargs: 20457 wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), 20458 OpInfo( 20459 "nn.functional.feature_alpha_dropout", 20460 op=lambda input, *args, **kwargs: 20461 wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs), 20462 variant_test_name="without_train", 20463 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20464 skips=( 20465 # lambda impl 20466 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20467 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),), 20468 gradcheck_wrapper=wrapper_set_seed, 20469 supports_forward_ad=True, 20470 supports_fwgrad_bwgrad=True, 20471 supports_out=False, 20472 sample_inputs_func=partial(sample_inputs_dropout, train=False), 20473 inplace_variant=lambda input, *args, **kwargs: 20474 wrapper_set_seed(torch.nn.functional.feature_alpha_dropout, input, *args, **kwargs, inplace=True)), 20475 OpInfo( 20476 "nn.functional.one_hot", 20477 ref=reference_one_hot, 20478 supports_out=False, 20479 dtypes=_dispatch_dtypes((torch.int64,)), 20480 sample_inputs_func=sample_inputs_one_hot, 20481 ), 20482 OpInfo( 20483 "nn.functional.embedding", 20484 aten_backward_name="embedding_dense_backward", 20485 # We use lambda to reshuffle the positional arguments. 20486 # This is because currently only the `input` field of SampleInput 20487 # is tested in gradient tests. 
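        # For example (illustrative sketch of the reshuffle; assumes sample_inputs_embedding
        # places the weight matrix in SampleInput.input and the index tensor in its args):
        #   SampleInput(weight, args=(idx,))  ->  torch.nn.functional.embedding(idx, weight)
        # so gradcheck exercises `weight`, the only differentiable argument of embedding.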
20488 op=lambda weight, idx, **kwargs: torch.nn.functional.embedding(idx, weight, **kwargs), 20489 dtypes=floating_types_and(torch.bfloat16, torch.float16), 20490 sample_inputs_func=sample_inputs_embedding, 20491 allow_cow_input_materialize_forward=[0], 20492 error_inputs_func=error_inputs_embedding, 20493 supports_forward_ad=True, 20494 supports_fwgrad_bwgrad=True, 20495 skips=( 20496 # lambda impl 20497 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20498 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20499 # Fails on CI https://github.com/pytorch/pytorch/issues/85377 20500 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_compare_cpu'), 20501 # Reference: https://github.com/pytorch/pytorch/issues/67084 20502 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'), 20503 # Not a problem: embedding does weird stuff to its input (it renormalizes) 20504 DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), 20505 # Fails due to non-determinism (see issue #74679) 20506 # TODO: Investigate why more granular skips in the test don't work in CI 20507 DecorateInfo(unittest.skip('Skipped!'), 20508 'TestExpandedWeightFunctional', 20509 'test_expanded_weight_forward'), 20510 ), 20511 supports_expanded_weight=True, 20512 supports_out=False, 20513 ), 20514 OpInfo( 20515 "nn.functional.embedding_bag", 20516 # We use lambda to reshuffle the positional arguments. 20517 # This is because currently only the `input` field of SampleInput 20518 # is tested in gradient tests. 20519 op=lambda weight, idx, **kwargs: torch.nn.functional.embedding_bag(idx, weight, **kwargs), 20520 dtypes=floating_types_and(torch.bfloat16, torch.float16), 20521 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 20522 # backward is not supported for mode `max` and dtype `bfloat16` 20523 backward_dtypesIfCUDA=floating_types_and(torch.float16), 20524 sample_inputs_func=sample_inputs_embedding_bag, 20525 skips=( 20526 # lambda impl 20527 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20528 DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 20529 # Not a problem: embedding_bag does weird stuff to its input (it renormalizes) 20530 DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'), 20531 ), 20532 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 20533 supports_out=False, 20534 supports_gradgrad=False, 20535 allow_cow_input_materialize_forward=[0], 20536 ), 20537 OpInfo( 20538 "nn.functional.multi_head_attention_forward", 20539 op=lambda input, *args, **kwargs: 20540 wrapper_set_seed(torch.nn.functional.multi_head_attention_forward, input, *args, **kwargs), 20541 dtypes=floating_types_and(torch.bfloat16, torch.float16), 20542 sample_inputs_func=sample_inputs_multi_head_attention_forward, 20543 skips=( 20544 # Tensor-likes are not close 20545 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', dtypes=(torch.float32,)), 20546 DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-3, rtol=0)}), 'TestDecomp', 'test_comprehensive'), 20547 20548 # TODO skip this for now since we can't skip on runtime arch support (taken from scaled_dot_product_attention) 20549 DecorateInfo(unittest.skip("Skipped!"), 'TestInductorOpInfo', 'test_comprehensive'), 20550 # randomness 20551 
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), 20552 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 20553 # lambda impl 20554 # AssertionError: JIT Test does not execute any logic 20555 DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), 20556 DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), 20557 # tests running very slowly break slow tests, so we skip them instead of using `slowTest`. 20558 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), 20559 DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'), 20560 DecorateInfo( 20561 unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), 20562 'TestDecomp', 20563 'test_comprehensive', 20564 dtypes=(torch.bfloat16, torch.float16), 20565 ), 20566 DecorateInfo( 20567 unittest.skip("Skipped - baddbmm decomp does not have enough precision for 16-bit float"), 20568 'TestDecomp', 20569 'test_quick', 20570 dtypes=(torch.bfloat16, torch.float16))), 20571 supports_out=False, 20572 supports_gradgrad=True, 20573 supports_forward_ad=True, 20574 supports_fwgrad_bwgrad=True, 20575 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 20576 gradcheck_fast_mode=True, 20577 ), 20578 UnaryUfuncInfo( 20579 "nn.functional.softplus", 20580 aten_backward_name='softplus_backward', 20581 ref=reference_softplus, 20582 sample_kwargs=lambda device, dtype, input: ({'beta': 3, 'threshold': .2}, {'beta': 3, 'threshold': .2}), 20583 sample_inputs_func=partial(sample_inputs_elementwise_unary, op_kwargs={'beta': 3, 'threshold': .2}), 20584 supports_forward_ad=True, 20585 supports_fwgrad_bwgrad=True, 20586 dtypes=floating_types_and(torch.bfloat16, torch.float16), 20587 decorators=( 20588 DecorateInfo( 20589 toleranceOverride 20590 ({ 20591 torch.half: tol(atol=1e-2, rtol=1e-2), 20592 torch.bfloat16: tol(atol=1e-2, rtol=1e-2), 20593 }), 20594 'TestUnaryUfuncs'), 20595 ), 20596 ), 20597 OpInfo( 20598 "nn.functional.mse_loss", 20599 aten_backward_name='mse_loss_backward', 20600 ref=loss_reference_reduction_wrapper(lambda input, target: (input - target) ** 2), 20601 sample_inputs_func=sample_inputs_loss, 20602 supports_out=False, 20603 supports_forward_ad=True, 20604 supports_fwgrad_bwgrad=True, 20605 dtypes=floating_types_and(torch.float16), 20606 backward_dtypes=floating_types(), 20607 dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 20608 backward_dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), 20609 skips=( 20610 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 20611 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252, 20612 # please report a bug to PyTorch. 
20613 DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), 20614 ), 20615 ), 20616 OpInfo( 20617 "nn.functional.grid_sample", 20618 dtypes=floating_types(), 20619 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 20620 supports_out=False, 20621 sample_inputs_func=sample_inputs_grid_sample, 20622 reference_inputs_func=reference_inputs_grid_sample, 20623 supports_gradgrad=False, 20624 gradcheck_nondet_tol=1e-15), 20625 # TODO: delete this OpInfo once we add meta support for grid_sampler_3d 20626 OpInfo( 20627 "grid_sampler_2d", 20628 dtypes=floating_types(), 20629 dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), 20630 supports_out=False, 20631 sample_inputs_func=sample_inputs_grid_sampler_2d, 20632 supports_gradgrad=False, 20633 gradcheck_nondet_tol=1e-15, 20634 skips=( 20635 DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), 20636 active_if=IS_WINDOWS), 20637 ),), 20638 OpInfo( 20639 "argwhere", 20640 ref=np.argwhere, 20641 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20642 supports_out=False, 20643 supports_autograd=False, 20644 sample_inputs_func=sample_inputs_argwhere, 20645 skips=( 20646 # Compiler issue on ROCm. Might need to skip until ROCm5.5 20647 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', 20648 dtypes=[torch.bool], active_if=TEST_WITH_ROCM), 20649 ), 20650 ), 20651 ReductionOpInfo( 20652 'all', 20653 identity=True, 20654 supports_autograd=False, 20655 result_dtype=torch.bool, 20656 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20657 ref=reference_reduction_numpy(np.all), 20658 skips=( 20659 # FIXME: uint8 input returns uint8 instead of bool 20660 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), 20661 ), 20662 ), 20663 ReductionOpInfo( 20664 'any', 20665 identity=False, 20666 supports_autograd=False, 20667 result_dtype=torch.bool, 20668 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20669 ref=reference_reduction_numpy(np.any), 20670 skips=( 20671 # FIXME: uint8 input returns uint8 instead of bool 20672 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), 20673 ), 20674 ), 20675 ReductionOpInfo( 20676 'amax', 20677 nan_policy='propagate', 20678 supports_forward_ad=True, 20679 check_batched_forward_grad=False, 20680 supports_fwgrad_bwgrad=True, 20681 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 20682 ref=reference_reduction_numpy(np.amax), 20683 skips=( 20684 # FIXME: reduces all dimensions when dim=[] 20685 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 20686 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 20687 ), 20688 error_inputs_func=error_inputs_aminmax_amax_amin, 20689 ), 20690 ReductionOpInfo( 20691 'amin', 20692 nan_policy='propagate', 20693 supports_forward_ad=True, 20694 check_batched_forward_grad=False, 20695 supports_fwgrad_bwgrad=True, 20696 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 20697 ref=reference_reduction_numpy(np.amin), 20698 skips=( 20699 # FIXME: reduces all dimensions when dim=[] 20700 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 20701 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 20702 ), 20703 
error_inputs_func=error_inputs_aminmax_amax_amin, 20704 ), 20705 ReductionOpInfo( 20706 'argmax', 20707 supports_multiple_dims=False, 20708 supports_autograd=False, 20709 assert_jit_shape_analysis=True, 20710 result_dtype=torch.int64, 20711 dtypes=all_types_and(torch.float16, torch.bfloat16), 20712 ref=reference_reduction_numpy(np.argmax, supports_keepdims=False), 20713 ), 20714 ReductionOpInfo( 20715 'argmin', 20716 supports_multiple_dims=False, 20717 supports_autograd=False, 20718 result_dtype=torch.int64, 20719 dtypes=all_types_and(torch.float16, torch.bfloat16), 20720 ref=reference_reduction_numpy(np.argmin, supports_keepdims=False), 20721 ), 20722 ReductionOpInfo( 20723 'count_nonzero', 20724 identity=0, 20725 supports_out=False, 20726 supports_autograd=False, 20727 result_dtype=torch.int64, 20728 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20729 sample_inputs_func=sample_inputs_reduction_count_nonzero, 20730 ref=reference_reduction_numpy(np.count_nonzero), 20731 skips=( 20732 # FIXME: count_nonzero does not accept keepdim kwarg 20733 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 20734 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), 20735 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), 20736 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20737 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), 20738 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_unsorted_keepdim'), 20739 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_offbounds_keepdim'), 20740 # FIXME: dim=[] reduces all dimensions 20741 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20742 ), 20743 ), 20744 ReductionOpInfo( 20745 'mean', 20746 nan_policy='propagate', 20747 supports_forward_ad=True, 20748 supports_fwgrad_bwgrad=True, 20749 # FIXME: mean needs 'dim' parameter when using the 'out' overload. 20750 # Adding it with 'generate_args_kwargs' does not work, since these also get passed 20751 # onto the reference implementations. 
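        # Concretely (illustrative, per the FIXME above): the out= form that works is
        #   torch.mean(t, dim=0, out=o)
        # but adding dim/out through generate_args_kwargs would also hand those kwargs to
        # the NumPy reference (np.mean) below, so the out-with-dim case is left untested here.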
20752 supports_out=True, 20753 assert_autodiffed=True, 20754 assert_jit_shape_analysis=True, 20755 promotes_int_to_float=True, 20756 dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), 20757 ref=reference_reduction_numpy(np.mean), 20758 error_inputs_func=error_inputs_mean, 20759 skips=( 20760 # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result 20761 # of dtype torch.float32 into an out= with dtype torch.long 20762 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), 20763 # FIXME: mean does not support passing keepdim without passing dim 20764 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 20765 # FIXME: mean reduces all dimensions when dim=[] 20766 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20767 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20768 # FIXME: improve precision 20769 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20770 dtypes=[torch.float16]), 20771 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', 20772 device_type='cuda', dtypes=[torch.complex64]), 20773 ), 20774 ), 20775 ReductionOpInfo( 20776 'nanmean', 20777 nan_policy='omit', 20778 assert_autodiffed=True, 20779 promotes_int_to_float=True, 20780 supports_forward_ad=True, 20781 check_batched_forward_grad=False, 20782 supports_fwgrad_bwgrad=True, 20783 dtypes=floating_types_and(torch.float16, torch.bfloat16), 20784 dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), 20785 sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), 20786 ref=reference_reduction_numpy(np.nanmean), 20787 skips=( 20788 # AssertionError: False is not true : 20789 # Failure in testing nodes' autodifferentiation. 
20790 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 20791 # FIXME: prod reduces all dimensions when dim=[] 20792 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20793 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20794 # FIXME: improve precision 20795 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20796 dtypes=[torch.float16]), 20797 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', 20798 device_type='cuda', dtypes=[torch.float16]), 20799 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_extremal_values', 20800 device_type='cuda', dtypes=[torch.complex64]), 20801 ), 20802 ), 20803 ReductionOpInfo( 20804 'std', 20805 nan_policy='propagate', 20806 supports_out=True, 20807 complex_to_real=True, 20808 supports_forward_ad=True, 20809 supports_fwgrad_bwgrad=True, 20810 assert_autodiffed=True, 20811 promotes_int_to_float=True, 20812 check_batched_forward_grad=False, 20813 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 20814 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 20815 sample_inputs_func=sample_inputs_std_var, 20816 ref=reference_std_var(np.std), 20817 generate_args_kwargs=generate_std_var_kwargs, 20818 skips=( 20819 # FIXME: cannot specify keepdim without dim 20820 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 20821 # FIXME: dim=[] reduces all dimensions 20822 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20823 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20824 # FIXME: improve precision 20825 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20826 dtypes=(torch.float16,)), 20827 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', 20828 dtypes=(torch.float16,)), 20829 ), 20830 ), 20831 ReductionOpInfo( 20832 'std', 20833 variant_test_name='unbiased', 20834 nan_policy='propagate', 20835 supports_out=False, 20836 complex_to_real=True, 20837 supports_forward_ad=True, 20838 supports_fwgrad_bwgrad=True, 20839 assert_autodiffed=True, 20840 promotes_int_to_float=True, 20841 check_batched_forward_grad=False, 20842 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 20843 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 20844 sample_inputs_func=sample_inputs_std_var_unbiased, 20845 skips=( 20846 # FIXME: dim=[] reduces all dimensions 20847 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20848 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20849 ), 20850 ), 20851 ReductionOpInfo( 20852 'var', 20853 nan_policy='propagate', 20854 supports_out=True, 20855 assert_autodiffed=True, 20856 promotes_int_to_float=True, 20857 complex_to_real=True, 20858 supports_forward_ad=True, 20859 supports_fwgrad_bwgrad=True, 20860 check_batched_forward_grad=False, 20861 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 20862 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 20863 sample_inputs_func=sample_inputs_std_var, 20864 ref=reference_std_var(np.var), 20865 generate_args_kwargs=generate_std_var_kwargs, 20866 skips=( 20867 # FIXME: cannot specify keepdim without dim 20868 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 
20869 # FIXME: dim=[] reduces all dimensions 20870 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20871 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20872 # FIXME: improve precision 20873 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), 20874 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), 20875 # NumPy is giving NaN for this 20876 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), 20877 ), 20878 ), 20879 ReductionOpInfo( 20880 'var', 20881 variant_test_name='unbiased', 20882 nan_policy='propagate', 20883 supports_out=False, 20884 complex_to_real=True, 20885 supports_forward_ad=True, 20886 supports_fwgrad_bwgrad=True, 20887 assert_autodiffed=True, 20888 promotes_int_to_float=True, 20889 check_batched_forward_grad=False, 20890 dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), 20891 dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), 20892 sample_inputs_func=sample_inputs_std_var_unbiased, 20893 skips=( 20894 # FIXME: dim=[] reduces all dimensions 20895 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20896 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20897 ), 20898 ), 20899 ReductionOpInfo( 20900 'prod', 20901 identity=1, 20902 nan_policy='propagate', 20903 supports_multiple_dims=False, 20904 # https://github.com/pytorch/pytorch/issues/80411 20905 gradcheck_fast_mode=True, 20906 supports_out=False, 20907 supports_forward_ad=True, 20908 supports_fwgrad_bwgrad=True, 20909 promotes_int_to_int64=True, 20910 gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, 20911 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 20912 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 20913 sample_inputs_func=sample_inputs_prod, 20914 ref=prod_numpy, 20915 skips=( 20916 # FIXME: prod does not support passing keepdim without passing dim 20917 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 20918 # FIXME: prod reduces all dimensions when dim=[] 20919 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20920 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20921 # FIXME: prod does not support passing None to dim 20922 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), 20923 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), 20924 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20925 dtypes=[torch.float16, torch.complex64]), 20926 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', 20927 dtypes=[torch.uint8, torch.float16, torch.complex64]), 20928 # FIXME: ValueError: The data in MaskedTensor a and Tensor b do not match 20929 DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', 20930 dtypes=[torch.float16]), 20931 ), 20932 ), 20933 ReductionOpInfo( 20934 'sum', 20935 identity=0, 20936 nan_policy='propagate', 20937 supports_out=False, 20938 supports_forward_ad=True, 20939 supports_fwgrad_bwgrad=True, 20940 promotes_int_to_int64=True, 20941 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 20942 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 20943 
ref=reference_reduction_numpy(np.sum), 20944 error_inputs_sparse_func=error_inputs_sparse_reduction_sum, 20945 sample_inputs_sparse_coo_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_coo), 20946 sample_inputs_sparse_csr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csr), 20947 sample_inputs_sparse_csc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_csc), 20948 sample_inputs_sparse_bsr_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsr), 20949 sample_inputs_sparse_bsc_func=partial(sample_inputs_sparse_reduction_sum, layout=torch.sparse_bsc), 20950 skips=( 20951 # FIXME: sum does not support passing keepdim without passing dim 20952 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), 20953 # FIXME: sum reduces all dimensions when dim=[] 20954 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 20955 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 20956 # FIXME: improve precision 20957 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20958 dtypes=[torch.float16]), 20959 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values', 20960 dtypes=[torch.float16]), 20961 DecorateInfo(unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', 20962 dtypes=[torch.float32]), 20963 ), 20964 ), 20965 ReductionOpInfo( 20966 'nansum', 20967 identity=0, 20968 nan_policy='omit', 20969 supports_out=True, 20970 promotes_int_to_int64=True, 20971 supports_forward_ad=True, 20972 check_batched_forward_grad=False, 20973 supports_fwgrad_bwgrad=True, 20974 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 20975 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 20976 sample_inputs_func=sample_inputs_nan_reduction(supports_multiple_dims=True), 20977 ref=reference_reduction_numpy(np.nansum), 20978 skips=( 20979 # please report a bug to PyTorch. 
20980 DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), 20981 # FIXME: nansum reduces all dimensions when dim=[] 20982 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 20983 DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 20984 # FIXME: flaky test so skipped instead of xfailed 20985 # possibly bad low precision reference in numpy 20986 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 20987 dtypes=[torch.float16]), 20988 ), 20989 ), 20990 OpInfo( 20991 "nn.functional.ctc_loss", 20992 dtypes=floating_types(), 20993 supports_out=False, 20994 sample_inputs_func=sample_inputs_ctc_loss, 20995 skips=( 20996 # https://github.com/pytorch/pytorch/issues/67462 20997 # torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0 20998 DecorateInfo( 20999 unittest.expectedFailure, 21000 "TestBwdGradients", 21001 "test_fn_grad", 21002 dtypes=(torch.float64,), 21003 ), 21004 # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented 21005 DecorateInfo( 21006 unittest.expectedFailure, 21007 "TestBwdGradients", 21008 "test_fn_gradgrad", 21009 dtypes=(torch.float64,), 21010 ), 21011 # RuntimeError: derivative for aten::_ctc_loss_backward is not implemented 21012 DecorateInfo( 21013 unittest.skip("Skipped!"), 21014 "TestJit", 21015 "test_variant_consistency_jit", 21016 dtypes=(torch.float32,), 21017 ), 21018 # Ref: https://github.com/pytorch/pytorch/issues/85231 21019 DecorateInfo(unittest.skip("Fails with ASAN"), 21020 'TestProxyTensorOpInfo', 21021 'test_make_fx_fake_exhaustive', active_if=TEST_WITH_ASAN), 21022 ), 21023 ), 21024 OpInfo( 21025 "nn.functional.cosine_embedding_loss", 21026 dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool), 21027 supports_out=False, 21028 supports_forward_ad=True, 21029 supports_fwgrad_bwgrad=True, 21030 decorators=[ 21031 DecorateInfo( 21032 toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}), 21033 'TestInductorOpInfo', 'test_comprehensive', device_type="cuda", 21034 ), 21035 ], 21036 sample_inputs_func=sample_inputs_cosine_embedding_loss, 21037 ), 21038 OpInfo( 21039 "nn.functional.nll_loss", 21040 dtypes=floating_types_and(torch.float16, torch.bfloat16), 21041 supports_out=False, 21042 sample_inputs_func=sample_inputs_nll_loss, 21043 supports_forward_ad=True, 21044 supports_fwgrad_bwgrad=True, 21045 assert_jit_shape_analysis=True, 21046 skips=( 21047 # RuntimeError: 21048 # undefined value tensor: 21049 # File "<string>", line 3 21050 # def the_method(i0, i1): 21051 # return torch.nn.functional.nll_loss(i0, i1, weight=tensor([8.4784, 1.7658, 4.3228], dtype=torch.float32)) 21052 # ~~~~~~ <--- HERE 21053 DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), 21054 # Fails for unknown reason: https://github.com/pytorch/pytorch/issues/120782 21055 DecorateInfo( 21056 unittest.skip("Skipped!"), 21057 "TestCompositeCompliance", 21058 "test_cow_input", 21059 device_type='cuda', 21060 ), 21061 DecorateInfo(unittest.skip("FP16 nll_loss cases have not been enabled on MPS yet"), 21062 dtypes=(torch.half,), device_type="mps"), 21063 21064 ), 21065 ), 21066 OpInfo( 21067 "nn.functional.gaussian_nll_loss", 21068 dtypes=floating_types_and(torch.half, torch.bfloat16), 21069 # Runs very slowly on slow gradcheck - alternatively reduce input sizes 21070 gradcheck_fast_mode=True, 21071 supports_out=False, 21072 supports_forward_ad=True, 
21073 supports_fwgrad_bwgrad=True, 21074 sample_inputs_func=sample_inputs_gaussian_nll_loss, 21075 error_inputs_func=error_inputs_gaussian_nll_loss, 21076 skips=( 21077 # Pre-existing condition (calls .item); needs to be fixed 21078 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), 21079 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), 21080 # Pre-existing condition (calls .item); needs to be fixed 21081 DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), 21082 # JIT does not support variadic tensors. 21083 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 21084 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, 21085 # please report a bug to PyTorch. 21086 DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), 21087 ), 21088 ), 21089 OpInfo( 21090 "nn.functional.hinge_embedding_loss", 21091 dtypes=floating_types_and(torch.half, torch.bfloat16), 21092 supports_out=False, 21093 supports_forward_ad=True, 21094 supports_fwgrad_bwgrad=True, 21095 sample_inputs_func=sample_inputs_hinge_embedding_loss, 21096 error_inputs_func=error_inputs_hinge_embedding_loss, 21097 reference_inputs_func=reference_inputs_hinge_embedding_loss, 21098 ), 21099 OpInfo( 21100 "nn.functional.huber_loss", 21101 aten_backward_name='huber_loss_backward', 21102 dtypes=floating_types_and(torch.float16, torch.bfloat16), 21103 supports_out=False, 21104 supports_forward_ad=True, 21105 sample_inputs_func=sample_inputs_huber_loss, 21106 error_inputs_func=error_inputs_huber_loss, 21107 skips=( 21108 # JIT does not support variadic tensors. 21109 # RuntimeError: input->type()->kind() == TypeKind::OptionalType 21110 # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, 21111 # please report a bug to PyTorch. 
21112 DecorateInfo(unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit", dtypes=(torch.float32,),), 21113 ) 21114 ), 21115 OpInfo( 21116 "nn.functional.pdist", 21117 ref=reference_pdist, 21118 sample_inputs_func=sample_inputs_pdist, 21119 dtypes=floating_types(), 21120 supports_out=False, 21121 supports_gradgrad=False, 21122 skips=( 21123 DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'), 21124 ) 21125 ), 21126 OpInfo( 21127 "nn.functional.poisson_nll_loss", 21128 dtypes=all_types_and(torch.half, torch.bfloat16), 21129 supports_out=False, 21130 supports_forward_ad=True, 21131 supports_fwgrad_bwgrad=True, 21132 sample_inputs_func=sample_inputs_poisson_nll_loss, 21133 error_inputs_func=error_inputs_poisson_nll_loss, 21134 ), 21135 OpInfo( 21136 "argsort", 21137 dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), 21138 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 21139 sample_inputs_func=sample_inputs_sort, 21140 supports_out=False, 21141 supports_autograd=False, 21142 skips=( 21143 DecorateInfo( 21144 unittest.skip("Skipped!"), 21145 "TestJit", 21146 "test_variant_consistency_jit", 21147 dtypes=(torch.float32,), 21148 ), 21149 ), 21150 ), 21151 OpInfo( 21152 "repeat_interleave", 21153 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), 21154 backward_dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16, torch.chalf), 21155 sample_inputs_func=sample_inputs_repeat_interleave, 21156 supports_out=False, 21157 supports_forward_ad=True, 21158 supports_fwgrad_bwgrad=True, 21159 # See https://github.com/pytorch/pytorch/pull/78358 21160 check_batched_forward_grad=False, 21161 skips=( 21162 DecorateInfo( 21163 unittest.skip("Skipped!"), 21164 "TestJit", 21165 "test_variant_consistency_jit", 21166 dtypes=(torch.float32, torch.complex64), 21167 ), 21168 ), 21169 ), 21170 OpInfo( 21171 "nn.functional.pairwise_distance", 21172 ref=lambda a, b, p=2.0, eps=1e-6, keepdim=False: ( 21173 np.sum(np.abs(a - b + eps) ** p, axis=-1, keepdims=keepdim) ** (1 / p) 21174 ), 21175 sample_inputs_func=sample_inputs_pairwise_distance, 21176 dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), 21177 supports_out=False, 21178 supports_forward_ad=True, 21179 supports_fwgrad_bwgrad=True, 21180 skips=( 21181 DecorateInfo( 21182 unittest.skip("Skipped!"), 21183 "TestJit", 21184 "test_variant_consistency_jit", 21185 dtypes=(torch.float32, torch.complex64), 21186 ), 21187 ), 21188 ), 21189 OpInfo( 21190 "nn.functional.pixel_shuffle", 21191 sample_inputs_func=sample_inputs_pixel_shuffle, 21192 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 21193 supports_out=False, 21194 supports_forward_ad=True, 21195 supports_fwgrad_bwgrad=True, 21196 skips=( 21197 DecorateInfo( 21198 unittest.skip("Skipped!"), 21199 "TestJit", 21200 "test_variant_consistency_jit", 21201 dtypes=(torch.float32, torch.complex64), 21202 ), 21203 ), 21204 ), 21205 OpInfo( 21206 "nn.functional.pixel_unshuffle", 21207 sample_inputs_func=sample_inputs_pixel_unshuffle, 21208 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 21209 supports_out=False, 21210 supports_forward_ad=True, 21211 supports_fwgrad_bwgrad=True, 21212 skips=( 21213 DecorateInfo( 21214 unittest.skip("Skipped!"), 21215 "TestJit", 21216 "test_variant_consistency_jit", 21217 dtypes=(torch.float32, torch.complex64), 21218 ), 21219 ), 21220 ), 21221 OpInfo( 21222 "nn.functional.channel_shuffle", 
21223 sample_inputs_func=sample_inputs_channel_shuffle, 21224 dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 21225 supports_out=False, 21226 supports_forward_ad=True, 21227 supports_fwgrad_bwgrad=True, 21228 allow_cow_input_materialize_forward=[0], 21229 allow_cow_input_materialize_backward=[0, 'output grad 0'], 21230 skips=( 21231 # Skip due to NotImplementedError for MPS device. 21232 DecorateInfo(unittest.expectedFailure, 'TestConsistency'), 21233 DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), 21234 DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), 21235 ), 21236 ), 21237 OpInfo( 21238 "nn.functional.kl_div", 21239 sample_inputs_func=sample_inputs_kl_div, 21240 dtypes=floating_types_and(torch.float16, torch.bfloat16), 21241 supports_out=False, 21242 supports_forward_ad=True, 21243 supports_fwgrad_bwgrad=True, 21244 ), 21245 OpInfo( 21246 "diagflat", 21247 ref=lambda input, offset=0: np.diagflat(input, k=offset), 21248 sample_inputs_func=sample_inputs_diagflat, 21249 dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), 21250 dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 21251 supports_out=False, 21252 supports_forward_ad=True, 21253 supports_fwgrad_bwgrad=True, 21254 # See https://github.com/pytorch/pytorch/pull/78358 21255 check_batched_forward_grad=False, 21256 ), 21257 OpInfo( 21258 'scatter_reduce', 21259 variant_test_name='sum', 21260 # complex not added to dtypes as complex gradients are not properly handled 21261 # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet 21262 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 21263 supports_forward_ad=True, 21264 supports_fwgrad_bwgrad=True, 21265 sample_inputs_func=sample_inputs_scatter_reduce, 21266 ), 21267 OpInfo( 21268 'scatter_reduce', 21269 variant_test_name='prod', 21270 # complex not added to dtypes as complex gradients are not properly handled 21271 # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet 21272 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 21273 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 21274 sample_inputs_func=sample_inputs_scatter_reduce, 21275 skips=( 21276 # Not implemented 21277 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), 21278 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'), 21279 DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'), 21280 ), 21281 ), 21282 OpInfo( 21283 'scatter_reduce', 21284 variant_test_name='mean', 21285 # complex not added to dtypes as complex gradients are not properly handled 21286 # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet 21287 dtypes=all_types_and(torch.float16, torch.bfloat16), 21288 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 21289 supports_forward_ad=True, 21290 supports_fwgrad_bwgrad=True, 21291 sample_inputs_func=sample_inputs_scatter_reduce, 21292 ), 21293 OpInfo( 21294 'scatter_reduce', 21295 variant_test_name='amin', 21296 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 21297 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 21298 supports_forward_ad=True, 21299 check_batched_forward_grad=False, 21300 supports_fwgrad_bwgrad=True, 21301 sample_inputs_func=sample_inputs_scatter_reduce, 21302 ), 21303 
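    # NOTE (illustrative aside on the 'jiterator_*' entries earlier in this list): a minimal
    # usage sketch of the API they wrap, assuming a CUDA device; the names below are invented
    # for the example and nothing in this comment is executed by the test suite:
    #
    #   square_plus = torch.cuda.jiterator._create_jit_fn(
    #       "template <typename T> T f(T x, T alpha) { return x * x + alpha; }", alpha=1.0)
    #   out = square_plus(torch.rand(4, device="cuda"), alpha=2.0)  # behaves like a regular op
    #
    # _create_multi_output_jit_fn follows the same pattern, except the C++ entry point writes
    # its results through reference parameters (see 'jiterator_2inputs_2outputs' above).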
OpInfo( 21304 'scatter_reduce', 21305 variant_test_name='amax', 21306 dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), 21307 dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), 21308 supports_forward_ad=True, 21309 check_batched_forward_grad=False, 21310 supports_fwgrad_bwgrad=True, 21311 sample_inputs_func=sample_inputs_scatter_reduce, 21312 ), 21313 OpInfo( 21314 '_segment_reduce', 21315 aten_name='segment_reduce', 21316 variant_test_name='lengths', 21317 dtypes=floating_types_and(torch.float16, torch.bfloat16), 21318 supports_out=False, 21319 # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented 21320 supports_gradgrad=False, 21321 sample_inputs_func=sample_inputs_segment_reduce, 21322 skips=( 21323 # FIXME: CUDA driver API confirmed a leak in 21324 # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 21325 DecorateInfo( 21326 unittest.skip("Skipped!"), 21327 "TestJit", 21328 "test_variant_consistency_jit", 21329 device_type="cuda", 21330 ), 21331 ), 21332 ), 21333 OpInfo( 21334 '_segment_reduce', 21335 aten_name='segment_reduce', 21336 variant_test_name='offsets', 21337 dtypes=floating_types_and(torch.float16, torch.bfloat16), 21338 supports_out=False, 21339 # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented 21340 supports_gradgrad=False, 21341 sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'), 21342 skips=( 21343 # FIXME: CUDA driver API confirmed a leak in 21344 # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32 21345 DecorateInfo( 21346 unittest.skip("Skipped!"), 21347 "TestJit", 21348 "test_variant_consistency_jit", 21349 device_type="cuda", 21350 ), 21351 ), 21352 ), 21353] 21354op_db += opinfo.definitions.op_db 21355 21356 21357# Separate registry for experimental Python Reference OpInfos. 
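# Each entry names the torch._refs implementation under test and, via torch_opinfo_name,
# the eager OpInfo it mirrors (e.g. "_refs.abs" <-> "abs"); roughly speaking, the reference
# reuses that OpInfo's metadata and sample inputs when it is compared against the eager op.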
21358python_ref_db = [ 21359 # 21360 # Elementwise Unary OpInfos 21361 # 21362 ElementwiseUnaryPythonRefInfo( 21363 "_refs.abs", 21364 torch_opinfo_name="abs", 21365 skips=( 21366 # Reference: https://github.com/pytorch/pytorch/issues/49224 21367 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21368 'test_reference_numerics_small', 21369 dtypes=[torch.int8], active_if=TEST_WITH_ASAN), 21370 ), 21371 ), 21372 ElementwiseUnaryPythonRefInfo( 21373 "_refs.acos", 21374 torch_opinfo_name="acos", 21375 skips=( 21376 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21377 'test_reference_numerics_normal', 21378 device_type='cuda', dtypes=[torch.cdouble], 21379 active_if=IS_WINDOWS), 21380 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21381 'test_reference_numerics_extremal', 21382 device_type='cuda', dtypes=[torch.cdouble], 21383 active_if=IS_WINDOWS), 21384 # Failing with wrong imaginary sign on at least some Windows jobs 21385 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21386 'test_reference_numerics_small', 21387 device_type='cuda', dtypes=[torch.cdouble], 21388 active_if=IS_WINDOWS), 21389 # Failing with wrong imaginary sign on at least some Windows jobs 21390 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21391 'test_reference_numerics_large', 21392 device_type='cuda', dtypes=[torch.cdouble], 21393 active_if=IS_WINDOWS), 21394 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21395 'test_reference_numerics_large', 21396 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21397 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21398 'test_reference_numerics_extremal', 21399 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21400 ) 21401 ), 21402 ElementwiseUnaryPythonRefInfo( 21403 "_refs.acosh", 21404 torch_opinfo_name="acosh", 21405 skips=( 21406 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21407 'test_reference_numerics_normal', 21408 device_type='cuda', dtypes=[torch.cdouble], 21409 active_if=IS_WINDOWS), 21410 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21411 'test_reference_numerics_extremal', 21412 device_type='cuda', dtypes=[torch.cdouble], 21413 active_if=IS_WINDOWS), 21414 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21415 'test_reference_numerics_extremal', 21416 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21417 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21418 'test_reference_numerics_large', 21419 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21420 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21421 'test_reference_numerics_extremal', 21422 device_type='cuda', dtypes=[torch.cdouble], 21423 active_if=IS_WINDOWS), 21424 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21425 'test_reference_numerics_large', 21426 device_type='cuda', dtypes=[torch.cdouble], 21427 active_if=IS_WINDOWS), 21428 # Failing with wrong imaginary sign on at least some Windows jobs 21429 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21430 'test_reference_numerics_small', 21431 device_type='cuda', dtypes=[torch.cdouble], 21432 active_if=IS_WINDOWS), 21433 ), 21434 ), 21435 ElementwiseUnaryPythonRefInfo( 21436 "_refs.asin", 21437 torch_opinfo_name="asin", 21438 decorators=[ 21439 DecorateInfo( 21440 toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}), 21441 'TestUnaryUfuncs', device_type='cuda'), 21442 precisionOverride({torch.bfloat16: 1e-2}), 21443 ], 21444 skips=( 21445 
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21446 'test_reference_numerics_extremal', 21447 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21448 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21449 'test_reference_numerics_large', 21450 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21451 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21452 'test_reference_numerics_extremal', 21453 device_type='cuda', dtypes=[torch.cdouble], 21454 active_if=IS_WINDOWS), 21455 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21456 'test_reference_numerics_large', 21457 device_type='cuda', dtypes=[torch.cdouble], 21458 active_if=IS_WINDOWS), 21459 ), 21460 ), 21461 ElementwiseUnaryPythonRefInfo( 21462 "_refs.asinh", 21463 torch_opinfo_name="asinh", 21464 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 21465 skips=( 21466 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21467 'test_reference_numerics_extremal', 21468 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21469 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21470 'test_reference_numerics_large', 21471 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21472 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21473 'test_reference_numerics_small', 21474 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21475 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21476 'test_reference_numerics_normal', 21477 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21478 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21479 'test_reference_numerics_extremal', 21480 device_type='cuda', dtypes=[torch.cdouble], 21481 active_if=IS_WINDOWS), 21482 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21483 'test_reference_numerics_large', 21484 device_type='cuda', dtypes=[torch.cdouble], 21485 active_if=IS_WINDOWS), 21486 ), 21487 ), 21488 PythonRefInfo( 21489 "_refs.lerp", 21490 torch_opinfo_name="lerp", 21491 ), 21492 PythonRefInfo( 21493 "_refs.ones", 21494 torch_opinfo_name="ones", 21495 skips=( 21496 # Tests that assume input is a tensor or sequence of tensors 21497 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21498 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21499 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21500 ), 21501 ), 21502 PythonRefInfo( 21503 "_refs.zeros", 21504 torch_opinfo_name="zeros", 21505 skips=( 21506 # Tests that assume input is a tensor or sequence of tensors 21507 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21508 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21509 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21510 ), 21511 ), 21512 PythonRefInfo( 21513 "_refs.cauchy", 21514 torch_opinfo_name="cauchy", 21515 decorators=( 21516 # TODO: RuntimeError: no _refs support for torch.rand_like 21517 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21518 'TestCommon', 21519 'test_python_ref'), 21520 # AssertionError: Tensor-likes are not close! 
21521 DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), 21522 'TestCommon', 21523 'test_out'), 21524 DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), 21525 'TestCommon', 21526 'test_out_warning'), 21527 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), 21528 DecorateInfo(unittest.skip("Expected: cauchy is not comparable"), 21529 'TestCommon', 21530 'test_python_ref_torch_fallback'), 21531 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21532 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21533 ) 21534 ), 21535 PythonRefInfo( 21536 "_refs.exponential", 21537 torch_opinfo_name="exponential", 21538 supports_out=True, 21539 decorators=( 21540 # dtypes that do not support check_uniform_bounds of rand_like 21541 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', 21542 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), 21543 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), 21544 21545 # TODO: RuntimeError: no _refs support for torch.rand_like 21546 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21547 'TestCommon', 21548 'test_python_ref'), 21549 21550 # AssertionError: Tensor-likes are not close! 21551 DecorateInfo(unittest.skip("Expected: exponential is not comparable"), 21552 'TestCommon', 21553 'test_out'), 21554 DecorateInfo(unittest.skip("Expected: exponential is not comparable"), 21555 'TestCommon', 21556 'test_out_warning'), 21557 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), 21558 DecorateInfo(unittest.skip("Expected: exponential is not comparable"), 21559 'TestCommon', 21560 'test_python_ref_torch_fallback'), 21561 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21562 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21563 ) 21564 ), 21565 PythonRefInfo( 21566 "_refs.geometric", 21567 torch_opinfo_name="geometric", 21568 supports_out=True, 21569 decorators=( 21570 # dtypes that do not support check_uniform_bounds of rand_like 21571 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_dtypes'), 21572 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', 21573 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), 21574 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21575 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)), 21576 21577 # TODO: RuntimeError: no _refs support for torch.rand_like 21578 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21579 'TestCommon', 21580 'test_python_ref'), 21581 DecorateInfo(unittest.skip("Expected: geometric is not comparable"), 21582 'TestCommon', 21583 'test_python_ref_executor', device_type='cuda'), 21584 21585 # AssertionError: Tensor-likes are not close! 
21586 DecorateInfo(unittest.skip("Expected: geometric is not comparable"), 21587 'TestCommon', 21588 'test_out'), 21589 DecorateInfo(unittest.skip("Expected: geometric is not comparable"), 21590 'TestCommon', 21591 'test_out_warning'), 21592 DecorateInfo(unittest.skip("Expected: geometric is not comparable"), 21593 'TestCommon', 21594 'test_python_ref_torch_fallback'), 21595 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21596 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21597 ) 21598 ), 21599 PythonRefInfo( 21600 "_refs.log_normal", 21601 torch_opinfo_name="log_normal", 21602 supports_out=True, 21603 decorators=( 21604 # TODO: RuntimeError: no _refs support for torch.rand_like 21605 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21606 'TestCommon', 21607 'test_python_ref'), 21608 DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), 21609 'TestCommon', 21610 'test_python_ref_executor', device_type='cuda'), 21611 21612 # AssertionError: Tensor-likes are not close! 21613 DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), 21614 'TestCommon', 21615 'test_out'), 21616 DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), 21617 'TestCommon', 21618 'test_out_warning'), 21619 DecorateInfo(unittest.skip("Expected: log_normal is not comparable"), 21620 'TestCommon', 21621 'test_python_ref_torch_fallback'), 21622 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21623 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21624 ) 21625 ), 21626 PythonRefInfo( 21627 "_refs.normal", 21628 torch_opinfo_name="normal", 21629 supports_out=True, 21630 decorators=( 21631 # TODO: RuntimeError: no _refs support for torch.rand_like 21632 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21633 'TestCommon', 21634 'test_python_ref'), 21635 21636 # AssertionError: Tensor-likes are not close! 21637 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21638 'TestCommon', 21639 'test_out'), 21640 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21641 'TestCommon', 21642 'test_out_warning'), 21643 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21644 'TestCommon', 21645 'test_python_ref_torch_fallback'), 21646 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), 21647 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21648 DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), 21649 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21650 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21651 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21652 ) 21653 ), 21654 PythonRefInfo( 21655 "_refs.normal", 21656 torch_opinfo_name="normal", 21657 torch_opinfo_variant_name="number_mean", 21658 supports_out=True, 21659 decorators=( 21660 # TODO: RuntimeError: no _refs support for torch.rand_like 21661 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21662 'TestCommon', 21663 'test_python_ref'), 21664 21665 # AssertionError: Tensor-likes are not close! 
21666 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21667 'TestCommon', 21668 'test_out'), 21669 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21670 'TestCommon', 21671 'test_out_warning'), 21672 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21673 'TestCommon', 21674 'test_python_ref_torch_fallback'), 21675 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), 21676 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21677 DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), 21678 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21679 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21680 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21681 ) 21682 ), 21683 PythonRefInfo( 21684 "_refs.normal_", 21685 op=torch.Tensor.normal_, 21686 torch_opinfo_name="normal", 21687 torch_opinfo_variant_name="in_place", 21688 supports_out=False, 21689 decorators=( 21690 # TODO: RuntimeError: no _refs support for torch.rand_like 21691 DecorateInfo(unittest.skip("TODO: RuntimeError: no _refs support for torch.rand_like"), 21692 'TestCommon', 21693 'test_python_ref'), 21694 21695 # AssertionError: Tensor-likes are not close! 21696 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21697 'TestCommon', 21698 'test_out'), 21699 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21700 'TestCommon', 21701 'test_out_warning'), 21702 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 21703 'TestCommon', 21704 'test_python_ref_torch_fallback'), 21705 DecorateInfo(unittest.skip("Expected: normal is not comparable"), 'TestDecomp', 'test_comprehensive'), 21706 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 21707 DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 'TestCommon', 'test_python_ref_executor'), 21708 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21709 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21710 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21711 ) 21712 ), 21713 PythonRefInfo( 21714 "_refs.arange", 21715 torch_opinfo_name="arange", 21716 skips=( 21717 # Tests that assume input is a tensor or sequence of tensors 21718 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21719 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21720 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21721 ), 21722 ), 21723 PythonRefInfo( 21724 "_refs.linspace", 21725 torch_opinfo_name="linspace", 21726 skips=( 21727 # Tests that assume input is a tensor or sequence of tensors 21728 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21729 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21730 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21731 21732 # cpu implementation is wrong on some integral types 21733 # https://github.com/pytorch/pytorch/issues/81996 21734 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21735 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), 21736 DecorateInfo(unittest.expectedFailure, 
'TestCommon', 'test_python_ref', 21737 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), 21738 21739 # cuda implementation is off-by-one on some inputs due to precision issues 21740 # https://github.com/pytorch/pytorch/issues/82230 21741 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21742 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21743 device_type="cuda"), 21744 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 21745 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21746 device_type="cuda"), 21747 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 21748 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21749 device_type="cuda"), 21750 ), 21751 ), 21752 PythonRefInfo( 21753 "_refs.linspace", 21754 torch_opinfo_name="linspace", 21755 torch_opinfo_variant_name="tensor_overload", 21756 skips=( 21757 # TypeError: 'int' object is not subscriptable 21758 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21759 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21760 21761 # cpu implementation is wrong on some integral types 21762 # https://github.com/pytorch/pytorch/issues/81996 21763 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21764 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), 21765 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 21766 dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), 21767 21768 # cuda implementation is off-by-one on some inputs due to precision issues 21769 # https://github.com/pytorch/pytorch/issues/82230 21770 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21771 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21772 device_type="cuda"), 21773 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 21774 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21775 device_type="cuda"), 21776 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 21777 dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), 21778 device_type="cuda"), 21779 ), 21780 ), 21781 PythonRefInfo( 21782 "_refs.logspace", 21783 torch_opinfo_name="logspace", 21784 skips=( 21785 # Tests that assume input is a tensor or sequence of tensors 21786 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21787 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21788 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), 21789 21790 # Off-by-one issue when casting floats to ints 21791 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21792 dtypes=(torch.int16, torch.int32, torch.int64), 21793 device_type="cuda"), 21794 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 21795 dtypes=(torch.int16, torch.int32, torch.int64), 21796 device_type="cuda"), 21797 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 21798 dtypes=(torch.int16, torch.int32, torch.int64), 21799 device_type="cuda"), 21800 ), 21801 ), 21802 PythonRefInfo( 21803 "_refs.logspace", 21804 torch_opinfo_name="logspace", 21805 
torch_opinfo_variant_name="tensor_overload", 21806 skips=( 21807 # TypeError: 'int' object is not subscriptable 21808 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), 21809 DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), 21810 21811 # Off-by-one issue when casting floats to ints 21812 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 21813 dtypes=(torch.int16, torch.int32, torch.int64), 21814 device_type="cuda"), 21815 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 21816 dtypes=(torch.int16, torch.int32, torch.int64), 21817 device_type="cuda"), 21818 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 21819 dtypes=(torch.int16, torch.int32, torch.int64), 21820 device_type="cuda"), 21821 ), 21822 ), 21823 PythonRefInfo( 21824 "_refs.meshgrid", 21825 torch_opinfo_name="meshgrid", 21826 torch_opinfo_variant_name="variadic_tensors", 21827 ), 21828 PythonRefInfo( 21829 "_refs.take_along_dim", 21830 torch_opinfo_name="take_along_dim", 21831 skips=( 21832 DecorateInfo(unittest.expectedFailure, 21833 'TestCommon', 21834 'test_python_ref'), 21835 ), 21836 ), 21837 PythonRefInfo( 21838 "_refs.to", 21839 torch_opinfo_name="to", 21840 ), 21841 PythonRefInfo( 21842 "_refs.triu", 21843 torch_opinfo_name="triu", 21844 ), 21845 PythonRefInfo( 21846 "_refs.tril", 21847 torch_opinfo_name="tril", 21848 ), 21849 PythonRefInfo( 21850 "_refs.triu_indices", 21851 torch_opinfo_name="triu_indices", 21852 # the implementation uses torch.stack that violates view consistency 21853 validate_view_consistency=False, 21854 skips=( 21855 # skip these tests since we have non tensor input 21856 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), 21857 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 21858 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 21859 DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), 21860 )), 21861 PythonRefInfo( 21862 "_refs.tril_indices", 21863 torch_opinfo_name="tril_indices", 21864 # the implementation uses torch.stack that violates view consistency 21865 validate_view_consistency=False, 21866 skips=( 21867 # skip these tests since we have non tensor input 21868 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_noncontiguous_samples'), 21869 DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), 21870 DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), 21871 DecorateInfo(unittest.skip('Skipped!'), 'TestMathBits', 'test_neg_view'), 21872 )), 21873 PythonRefInfo( 21874 "_refs.meshgrid", 21875 torch_opinfo_name="meshgrid", 21876 torch_opinfo_variant_name="list_of_tensors", 21877 ), 21878 PythonRefInfo( 21879 "_refs.movedim", 21880 aliases=('moveaxis',), 21881 torch_opinfo_name="movedim", 21882 ), 21883 PythonRefInfo( 21884 "_refs.bucketize", 21885 torch_opinfo_name="bucketize", 21886 skips=( 21887 # RuntimeError: It appears that you're trying to get value out of a tracing tensor with 21888 # aten._local_scalar_dense.default - erroring out! [...] 
21889 # triggered by mid_val = boundaries[mid] 21890 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref_executor"), 21891 ) 21892 ), 21893 PythonRefInfo( 21894 "_refs.equal", 21895 torch_opinfo_name="equal", 21896 skips=( 21897 # RuntimeError: Cannot cast FakeTensor to number 21898 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',), 21899 ) 21900 ), 21901 ElementwiseUnaryPythonRefInfo( 21902 "_refs.atan", 21903 torch_opinfo_name="atan", 21904 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 21905 skips=( 21906 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21907 'test_reference_numerics_extremal', 21908 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21909 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21910 'test_reference_numerics_large', 21911 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21912 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21913 'test_reference_numerics_small', 21914 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21915 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21916 'test_reference_numerics_extremal', 21917 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 21918 active_if=IS_WINDOWS), 21919 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21920 'test_reference_numerics_large', 21921 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 21922 active_if=IS_WINDOWS), 21923 ), 21924 ), 21925 ElementwiseUnaryPythonRefInfo( 21926 "_refs.atanh", 21927 torch_opinfo_name="atanh", 21928 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 21929 skips=( 21930 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21931 'test_reference_numerics_small', 21932 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21933 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21934 'test_reference_numerics_extremal', 21935 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21936 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21937 'test_reference_numerics_large', 21938 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 21939 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21940 'test_reference_numerics_extremal', 21941 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], 21942 active_if=IS_WINDOWS), 21943 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21944 'test_reference_numerics_large', 21945 device_type='cuda', dtypes=[torch.cfloat], 21946 active_if=IS_WINDOWS), 21947 ), 21948 ), 21949 ElementwiseUnaryPythonRefInfo( 21950 "_refs.bitwise_not", 21951 torch_opinfo_name="bitwise_not", 21952 ), 21953 ElementwiseUnaryPythonRefInfo( 21954 "_refs.ceil", 21955 torch_opinfo_name="ceil", 21956 # Fails on int32 21957 # https://github.com/pytorch/pytorch/issues/85258 21958 ), 21959 PythonRefInfo( 21960 "_refs.item", 21961 torch_opinfo_name="item", 21962 skips=( 21963 # RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number 21964 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), 21965 # ValueError: Can't convert a tensor with 10 elements to a number! 
21966 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),), 21967 ), 21968 ElementwiseUnaryPythonRefInfo( 21969 "_refs.conj_physical", 21970 torch_opinfo_name="conj_physical", 21971 ), 21972 ElementwiseUnaryPythonRefInfo( 21973 "_refs.cos", 21974 torch_opinfo_name="cos", 21975 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 21976 skips=( 21977 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21978 'test_reference_numerics_large', 21979 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', 21980 active_if=IS_WINDOWS), 21981 # This fails on CUDA but passes on ROCm 21982 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21983 'test_reference_numerics_large', 21984 dtypes=(torch.cdouble,), device_type='cuda'), 21985 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21986 'test_reference_numerics_extremal', 21987 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 21988 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 21989 'test_reference_numerics_extremal', 21990 device_type='cpu', 21991 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 21992 # AssertionError: Tensor-likes are not close! 21993 # Greatest absolute difference: nan at index (700,) (up to 1e-05 allowed) 21994 # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) 21995 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 21996 'test_reference_numerics_large', 21997 device_type='cuda', 21998 dtypes=(torch.chalf,), active_if=IS_WINDOWS), 21999 ), 22000 ), 22001 ElementwiseUnaryPythonRefInfo( 22002 "_refs.cosh", 22003 torch_opinfo_name="cosh", 22004 skips=( 22005 # Reference: https://github.com/pytorch/pytorch/issues/48641 22006 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22007 'test_reference_numerics_large', 22008 device_type='cpu', dtypes=[torch.int8]), 22009 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22010 'test_reference_numerics_large', 22011 dtypes=[torch.cdouble]), 22012 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22013 'test_reference_numerics_extremal', 22014 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 22015 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22016 'test_reference_numerics_large', 22017 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), 22018 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22019 'test_reference_numerics_extremal', 22020 device_type='cpu', 22021 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 22022 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22023 'test_reference_numerics_large', 22024 device_type='cpu', 22025 dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS), 22026 # AssertionError: Tensor-likes are not close! 
22027 # Greatest absolute difference: nan at index (6000,) (up to 1e-05 allowed) 22028 # Greatest relative difference: nan at index (6000,) (up to 0.001 allowed) 22029 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 22030 'test_reference_numerics_large', 22031 device_type='cuda', 22032 dtypes=(torch.chalf,), active_if=IS_WINDOWS), 22033 ), 22034 ), 22035 ElementwiseUnaryPythonRefInfo( 22036 "_refs.digamma", 22037 torch_opinfo_name="digamma", 22038 ), 22039 ElementwiseUnaryPythonRefInfo( 22040 "_refs.erf", 22041 torch_opinfo_name="erf", 22042 ), 22043 ElementwiseUnaryPythonRefInfo( 22044 "_refs.erfinv", 22045 torch_opinfo_name="erfinv", 22046 decorators=(precisionOverride({torch.float16: 1e-2, 22047 torch.bfloat16: 1e-2, 22048 torch.float32: 1e-4}),), 22049 skips=( 22050 # Reference: https://github.com/pytorch/pytorch/pull/49155#issuecomment-742664611 22051 DecorateInfo( 22052 unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22053 'test_reference_numerics_extremal', 22054 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 22055 DecorateInfo( 22056 unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22057 'test_reference_numerics_large', 22058 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 22059 DecorateInfo( 22060 unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22061 'test_reference_numerics_small', 22062 active_if=TEST_SCIPY and version.parse(scipy.__version__) < version.parse("1.4.0")), 22063 ), 22064 ), 22065 ElementwiseUnaryPythonRefInfo( 22066 "_refs.erfc", 22067 torch_opinfo_name="erfc", 22068 ), 22069 ElementwiseUnaryPythonRefInfo( 22070 "_refs.exp", 22071 torch_opinfo_name="exp", 22072 skips=( 22073 # Reference: https://github.com/pytorch/pytorch/issues/48010 22074 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22075 'test_reference_numerics_extremal', 22076 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 22077 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22078 'test_reference_numerics_large', 22079 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 22080 ), 22081 ), 22082 ElementwiseUnaryPythonRefInfo( 22083 "_refs.expm1", 22084 torch_opinfo_name="expm1", 22085 ), 22086 ElementwiseUnaryPythonRefInfo( 22087 "_refs.exp2", 22088 torch_opinfo_name="exp2", 22089 skips=( 22090 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22091 'test_reference_numerics_large', 22092 dtypes=[torch.cdouble]), 22093 # Reference: https://github.com/pytorch/pytorch/issues/48010 22094 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22095 'test_reference_numerics_extremal', 22096 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 22097 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22098 'test_reference_numerics_large', 22099 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 22100 ), 22101 ), 22102 ElementwiseUnaryPythonRefInfo( 22103 "_refs.fill", 22104 torch_opinfo_name="fill", 22105 supports_out=True, 22106 ), 22107 ElementwiseUnaryPythonRefInfo( 22108 "_refs.floor", 22109 torch_opinfo_name="floor", 22110 # Fails on int32 22111 # https://github.com/pytorch/pytorch/issues/85258 22112 ), 22113 ElementwiseUnaryPythonRefInfo( 22114 "_refs.frexp", 22115 torch_opinfo_name="frexp", 22116 # Skipped due to numerical failures on Windows CI. 22117 # This is also skipped in frexp earlier in the file. 
22118 skips=( 22119 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', 22120 active_if=IS_WINDOWS), 22121 ), 22122 ), 22123 ElementwiseUnaryPythonRefInfo( 22124 "_refs.frac", 22125 torch_opinfo_name="frac", 22126 skips=( 22127 DecorateInfo( 22128 unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22129 'test_reference_numerics_extremal', 22130 dtypes=(torch.bfloat16, torch.float16, torch.float32, torch.float64)), 22131 ), 22132 ), 22133 ElementwiseUnaryPythonRefInfo( 22134 "_refs.imag", 22135 torch_opinfo_name="imag", 22136 ), 22137 ElementwiseUnaryPythonRefInfo( 22138 "_refs.isfinite", 22139 torch_opinfo_name="isfinite", 22140 supports_out=True, 22141 ), 22142 ElementwiseUnaryPythonRefInfo( 22143 "_refs.isinf", 22144 torch_opinfo_name="isinf", 22145 supports_out=True, 22146 ), 22147 ElementwiseUnaryPythonRefInfo( 22148 "_refs.isposinf", 22149 torch_opinfo_name="isposinf", 22150 supports_out=True, 22151 ), 22152 ElementwiseUnaryPythonRefInfo( 22153 "_refs.isneginf", 22154 torch_opinfo_name="isneginf", 22155 supports_out=True, 22156 ), 22157 ElementwiseUnaryPythonRefInfo( 22158 "_refs.isnan", 22159 torch_opinfo_name="isnan", 22160 supports_out=True, 22161 ), 22162 ElementwiseUnaryPythonRefInfo( 22163 "_refs.isreal", 22164 torch_opinfo_name="isreal", 22165 supports_out=True, 22166 ), 22167 ElementwiseUnaryPythonRefInfo( 22168 "_refs.i0", 22169 torch_opinfo_name="i0", 22170 decorators=(precisionOverride({torch.bfloat16: 3e-1, 22171 torch.float16: 5e-1}),), 22172 skips=( 22173 DecorateInfo(unittest.skip("Skipped!"), 22174 'TestUnaryUfuncs', 22175 'test_reference_numerics_large', 22176 dtypes=(torch.int8,)), 22177 ), 22178 ), 22179 ElementwiseUnaryPythonRefInfo( 22180 "_refs.lgamma", 22181 torch_opinfo_name="lgamma", 22182 decorators=(precisionOverride({torch.float16: 7e-1}),), 22183 skips=( 22184 # Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214 22185 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22186 'test_reference_numerics_extremal', 22187 dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), 22188 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22189 'test_reference_numerics_large', 22190 dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS), 22191 ), 22192 ), 22193 ElementwiseUnaryPythonRefInfo( 22194 "_refs.special.multigammaln", 22195 torch_opinfo_name="mvlgamma", 22196 torch_opinfo_variant_name="mvlgamma_p_1", 22197 skips=skips_mvlgamma(), 22198 decorators=( 22199 DecorateInfo(torch.testing._internal.common_utils.markDynamoStrictTest, 'TestUnaryUfuncs', 22200 'test_reference_numerics_large'), 22201 DecorateInfo(torch.testing._internal.common_utils.xfailIfTorchDynamo, 'TestUnaryUfuncs', 22202 'test_reference_numerics_large'), 22203 ), 22204 ), 22205 ElementwiseUnaryPythonRefInfo( 22206 "_refs.special.multigammaln", 22207 torch_opinfo_name="mvlgamma", 22208 torch_opinfo_variant_name="mvlgamma_p_3", 22209 skips=skips_mvlgamma(), 22210 ), 22211 ElementwiseUnaryPythonRefInfo( 22212 "_refs.special.multigammaln", 22213 torch_opinfo_name="mvlgamma", 22214 torch_opinfo_variant_name="mvlgamma_p_5", 22215 skips=skips_mvlgamma(), 22216 ), 22217 ElementwiseUnaryPythonRefInfo( 22218 "_refs.log", 22219 torch_opinfo_name="log", 22220 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 22221 skips=( 22222 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22223 'test_reference_numerics_extremal', 22224 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22225 
active_if=IS_WINDOWS), 22226 ), 22227 ), 22228 ElementwiseUnaryPythonRefInfo( 22229 "_refs.log1p", 22230 torch_opinfo_name="log1p", 22231 ), 22232 ElementwiseUnaryPythonRefInfo( 22233 "_refs.log10", 22234 torch_opinfo_name="log10", 22235 decorators=(precisionOverride({torch.bfloat16: 5e-2}),), 22236 skips=( 22237 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22238 'test_reference_numerics_extremal', 22239 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22240 active_if=IS_WINDOWS), 22241 ), 22242 ), 22243 ElementwiseUnaryPythonRefInfo( 22244 "_refs.log2", 22245 torch_opinfo_name="log2", 22246 decorators=(precisionOverride({torch.bfloat16: 1e-1}),), 22247 skips=( 22248 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22249 'test_reference_numerics_extremal', 22250 dtypes=[torch.cfloat, torch.cdouble]), 22251 ), 22252 ), 22253 PythonRefInfo( 22254 "_refs.logsumexp", 22255 torch_opinfo_name="logsumexp", 22256 # When keepdim=False logsumexp function uses squeeze operation 22257 # that is not yet exposed in nvFuser's Python API. 22258 ), 22259 PythonRefInfo( 22260 "_refs.log_softmax", 22261 torch_opinfo_name="log_softmax", 22262 torch_opinfo_variant_name="with_dtype", 22263 ), 22264 ElementwiseUnaryPythonRefInfo( 22265 "_refs.nan_to_num", 22266 torch_opinfo_name="nan_to_num", 22267 ), 22268 ElementwiseUnaryPythonRefInfo( 22269 "_refs.neg", 22270 torch_opinfo_name="neg", 22271 ), 22272 ElementwiseUnaryPythonRefInfo( 22273 "_refs.positive", 22274 torch_opinfo_name="positive", 22275 ), 22276 ElementwiseUnaryPythonRefInfo( 22277 "_refs.real", 22278 torch_opinfo_name="real", 22279 ), 22280 ElementwiseUnaryPythonRefInfo( 22281 "_refs.reciprocal", 22282 torch_opinfo_name="reciprocal", 22283 skips=( 22284 # Reference: https://github.com/pytorch/pytorch/issues/45690 22285 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22286 'test_reference_numerics_extremal', 22287 dtypes=[torch.cfloat, torch.cdouble]), 22288 ), 22289 ), 22290 ElementwiseUnaryPythonRefInfo( 22291 "_refs.round", 22292 torch_opinfo_name="round", 22293 # Fails on int32 22294 # https://github.com/pytorch/pytorch/issues/85258 22295 skips=( 22296 DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), 22297 "TestUnaryUfuncs", "test_reference_numerics_extremal", 22298 device_type="cuda"), 22299 DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), 22300 "TestUnaryUfuncs", "test_reference_numerics_normal", 22301 device_type="cuda"), 22302 ), 22303 ), 22304 ElementwiseUnaryPythonRefInfo( 22305 "_refs.rsqrt", 22306 torch_opinfo_name="rsqrt", 22307 decorators=(precisionOverride({torch.half: 5e-2}),), 22308 skips=( 22309 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22310 'test_reference_numerics_extremal', 22311 dtypes=(torch.cfloat, torch.cdouble)), 22312 # AssertionError: Tensor-likes are not close! 
22313 # Greatest absolute difference: nan at index (700,) (up to 0.01 allowed) 22314 # Greatest relative difference: nan at index (700,) (up to 0.001 allowed) 22315 DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 22316 'test_reference_numerics_large', 22317 dtypes=(torch.chalf,)), 22318 ), 22319 ), 22320 ElementwiseUnaryPythonRefInfo( 22321 "_refs.sigmoid", 22322 torch_opinfo_name="sigmoid", 22323 aliases=('_refs.special.expit',), 22324 # Reference: https://github.com/pytorch/pytorch/issues/56012 22325 handles_complex_extremal_values=False, 22326 handles_large_floats=False, 22327 decorators=(precisionOverride({torch.float16: 1e-2, 22328 torch.complex64: 1e-1, 22329 torch.bfloat16: 1e-2}),), 22330 skips=( 22331 # Reference: https://github.com/pytorch/pytorch/issues/56012 22332 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22333 'test_reference_numerics_extremal', 22334 dtypes=[torch.complex64, torch.cdouble]), 22335 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22336 'test_reference_numerics_large', 22337 dtypes=[torch.chalf, torch.complex64, torch.cdouble]) 22338 ), 22339 ), 22340 ElementwiseUnaryPythonRefInfo( 22341 "_refs.sign", 22342 torch_opinfo_name="sign", 22343 skips=( 22344 # Reference: https://github.com/pytorch/pytorch/issues/41245 22345 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22346 'test_reference_numerics_extremal', 22347 dtypes=[torch.bfloat16, torch.float16, torch.float32, 22348 torch.float64]), 22349 ), 22350 ), 22351 ElementwiseUnaryPythonRefInfo( 22352 "_refs.sgn", 22353 torch_opinfo_name="sgn", 22354 # This is an issue with the vectorised abs on CPU 22355 handles_complex_extremal_values=False, 22356 handles_large_floats=False, 22357 skips=( 22358 # Reference: https://github.com/pytorch/pytorch/issues/41245 22359 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22360 'test_reference_numerics_extremal', 22361 dtypes=[torch.bfloat16, torch.float16, torch.float32, 22362 torch.float64]), 22363 ), 22364 ), 22365 ElementwiseUnaryPythonRefInfo( 22366 "_refs.signbit", 22367 torch_opinfo_name="signbit", 22368 ), 22369 ElementwiseUnaryPythonRefInfo( 22370 "_refs.sin", 22371 torch_opinfo_name="sin", 22372 decorators=(precisionOverride({torch.bfloat16: 1e-2}),), 22373 skips=( 22374 # Fails on CUDA but passes on ROCm 22375 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22376 'test_reference_numerics_large', 22377 dtypes=(torch.cdouble,), device_type='cuda'), 22378 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22379 'test_reference_numerics_extremal', 22380 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', 22381 active_if=IS_WINDOWS), 22382 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22383 'test_reference_numerics_large', 22384 dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', 22385 active_if=IS_WINDOWS), 22386 ), 22387 ), 22388 ElementwiseUnaryPythonRefInfo( 22389 "_refs.sinc", 22390 torch_opinfo_name="sinc", 22391 decorators=(precisionOverride({torch.bfloat16: 1e-2, 22392 torch.float16: 1e-2}),), 22393 skips=( 22394 # Reference: https://github.com/pytorch/pytorch/issues/49133 22395 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22396 'test_reference_numerics_small', 22397 dtypes=[torch.cfloat]), 22398 ), 22399 ), 22400 ElementwiseUnaryPythonRefInfo( 22401 "_refs.sinh", 22402 torch_opinfo_name="sinh", 22403 decorators=(precisionOverride({torch.float16: 1e-2}),), 22404 skips=( 22405 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22406 
'test_reference_numerics_extremal', 22407 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22408 active_if=(IS_MACOS or IS_WINDOWS)), 22409 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22410 'test_reference_numerics_large', 22411 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22412 active_if=(IS_MACOS or IS_WINDOWS)), 22413 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22414 'test_reference_numerics_large', 22415 dtypes=(torch.cdouble,)), 22416 # Reference: https://github.com/pytorch/pytorch/issues/48641 22417 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22418 'test_reference_numerics_large', 22419 device_type='cpu', dtypes=[torch.int8]), 22420 ), 22421 ), 22422 PythonRefInfo( 22423 "_refs.softmax", 22424 torch_opinfo_name="softmax", 22425 torch_opinfo_variant_name="with_dtype", 22426 ), 22427 ElementwiseUnaryPythonRefInfo( 22428 "_refs.sqrt", 22429 torch_opinfo_name="sqrt", 22430 decorators=( 22431 precisionOverride({torch.bfloat16: 7e-2}), 22432 DecorateInfo( 22433 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 22434 'TestUnaryUfuncs', 'test_reference_numerics_large'), 22435 ), 22436 skips=( 22437 # Reference: https://github.com/pytorch/pytorch/issues/47358 22438 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22439 'test_reference_numerics_large', 22440 device_type='cpu', dtypes=(torch.cfloat, torch.cdouble), 22441 active_if=IS_MACOS), 22442 # Reference: https://github.com/pytorch/pytorch/pull/47293#issuecomment-721774436 22443 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22444 'test_reference_numerics_large', 22445 dtypes=(torch.bfloat16,)), 22446 ), 22447 ), 22448 ElementwiseUnaryPythonRefInfo( 22449 "_refs.square", 22450 torch_opinfo_name="square", 22451 decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), 22452 skips=( 22453 # AssertionError: Reference result was farther (2.2417024338305655e-07) from the precise computation 22454 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', dtypes=(torch.complex64,)), 22455 # Reference: https://github.com/pytorch/pytorch/issues/52549 22456 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22457 'test_reference_numerics_large', 22458 dtypes=[torch.cfloat, torch.cdouble]), 22459 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22460 'test_reference_numerics_extremal', 22461 device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), 22462 ), 22463 ), 22464 ElementwiseUnaryPythonRefInfo( 22465 "_refs.tan", 22466 torch_opinfo_name="tan", 22467 decorators=[ 22468 DecorateInfo( 22469 toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=1e-05)}), 22470 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), 22471 ], 22472 skips=( 22473 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22474 'test_reference_numerics_extremal', 22475 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22476 active_if=(IS_MACOS or IS_WINDOWS)), 22477 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22478 'test_reference_numerics_large', 22479 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22480 active_if=(IS_MACOS or IS_WINDOWS)), 22481 ) 22482 ), 22483 ElementwiseUnaryPythonRefInfo( 22484 "_refs.tanh", 22485 torch_opinfo_name="tanh", 22486 decorators=[ 22487 DecorateInfo( 22488 toleranceOverride({torch.complex64: tol(atol=1e-04, rtol=2e-05)}), 22489 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), 22490 ], 22491 
skips=( 22492 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22493 'test_reference_numerics_extremal', 22494 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22495 active_if=(IS_MACOS or IS_WINDOWS)), 22496 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22497 'test_reference_numerics_large', 22498 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], 22499 active_if=(IS_MACOS or IS_WINDOWS)), 22500 ), 22501 ), 22502 ElementwiseUnaryPythonRefInfo( 22503 "_refs.trunc", 22504 torch_opinfo_name="trunc", 22505 # Fails on int32 22506 # https://github.com/pytorch/pytorch/issues/85258 22507 ), 22508 PythonRefInfo( 22509 "_refs.special.log_softmax", 22510 torch_opinfo_name="log_softmax", # alias 22511 torch_opinfo_variant_name="with_dtype", 22512 supports_out=False, 22513 ), 22514 PythonRefInfo( 22515 "_refs.special.softmax", 22516 torch_opinfo_name="softmax", # alias 22517 torch_opinfo_variant_name="with_dtype", 22518 supports_out=False, 22519 ), 22520 # 22521 # Elementwise Unary Special OpInfos 22522 # 22523 ElementwiseUnaryPythonRefInfo( 22524 "_refs.special.logit", 22525 torch_opinfo_name="logit", 22526 ), 22527 # 22528 # Elementwise Unary nn.functional OpInfos 22529 # 22530 PythonRefInfo( 22531 "_refs.nn.functional.alpha_dropout", 22532 torch_opinfo_name="nn.functional.alpha_dropout", 22533 decorators=( 22534 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22535 'TestCommon', 22536 'test_python_ref'), 22537 # AssertionError: Tensor-likes are not close! 22538 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22539 'TestCommon', 22540 'test_python_ref_torch_fallback'), 22541 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22542 'TestCommon', 22543 'test_python_ref_executor', device_type='cuda'), 22544 # AssertionError: Tensor-likes are not close! 22545 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22546 'TestMathBits', 22547 'test_neg_view'), 22548 # AssertionError: Tensor-likes are not close! 
22549 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22550 'TestCommon', 22551 'test_compare_cpu'), 22552 ) 22553 ), 22554 ElementwiseUnaryPythonRefInfo( 22555 "_refs.nn.functional.celu", 22556 torch_opinfo_name="nn.functional.celu", 22557 supports_out=True, 22558 ), 22559 PythonRefInfo( 22560 "_refs.nn.functional.channel_shuffle", 22561 torch_opinfo_name="nn.functional.channel_shuffle", 22562 supports_out=True, 22563 ), 22564 ElementwiseUnaryPythonRefInfo( 22565 "_refs.nn.functional.threshold", 22566 torch_opinfo_name="nn.functional.threshold", 22567 supports_out=True, 22568 ), 22569 PythonRefInfo( 22570 "_refs.nn.functional.dropout", 22571 torch_opinfo_name="nn.functional.dropout", 22572 decorators=( 22573 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22574 'TestCommon', 22575 'test_python_ref'), 22576 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22577 'TestCommon', 22578 'test_python_ref_torch_fallback'), 22579 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22580 'TestCommon', 22581 'test_out'), 22582 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22583 'TestCommon', 22584 'test_out_warning'), 22585 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22586 'TestMathBits', 22587 'test_conj_view'), 22588 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22589 'TestMathBits', 22590 'test_neg_conj_view'), 22591 DecorateInfo(unittest.skip("Expected: dropout is not comparable"), 22592 'TestMathBits', 22593 'test_neg_view'), 22594 # dropout is not comparable 22595 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), 22596 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 22597 ) 22598 ), 22599 ElementwiseUnaryPythonRefInfo( 22600 "_refs.nn.functional.elu", 22601 torch_opinfo_name="nn.functional.elu", 22602 supports_out=True, 22603 decorators=[ 22604 DecorateInfo( 22605 toleranceOverride({ 22606 torch.float16: tol(atol=1e-03, rtol=1.2e-03), 22607 torch.bfloat16: tol(atol=1e-03, rtol=1.2e-03) 22608 }), 22609 'TestUnaryUfuncs', device_type='cuda', 22610 ), ], 22611 ), 22612 ElementwiseUnaryPythonRefInfo( 22613 "_refs.nn.functional.hardtanh", 22614 torch_opinfo_name="nn.functional.hardtanh", 22615 supports_out=True, 22616 ), 22617 PythonRefInfo( # TODO: Port this to an UnaryOpInfo 22618 "_refs.nn.functional.gelu", 22619 torch_opinfo_name="nn.functional.gelu", 22620 ), 22621 PythonRefInfo( 22622 "_refs.nn.functional.layer_norm", 22623 torch_opinfo_name="nn.functional.layer_norm", 22624 skips=( 22625 # Reference result was farther (3.5762786809723224e-07) from the precise computation 22626 # than the torch result was (2.5068410824946596e-07)! 
22627 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', 22628 dtypes=(torch.float32,), device_type='cpu'), 22629 ), 22630 ), 22631 PythonRefInfo( 22632 "_refs.nn.functional.glu", 22633 torch_opinfo_name="nn.functional.glu", 22634 supports_out=True, 22635 ), 22636 PythonRefInfo( 22637 "_refs.nn.functional.pairwise_distance", 22638 torch_opinfo_name="nn.functional.pairwise_distance", 22639 supports_out=True, 22640 ), 22641 PythonRefInfo( 22642 "_refs.nn.functional.pdist", 22643 torch_opinfo_name="nn.functional.pdist", 22644 supports_out=True, 22645 skips=( 22646 # RunTimeError: no _refs support for torch.Tensor.index_select 22647 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), 22648 # Reference result was farther (1.946091651916504e-05) from the precise 22649 # computation than the torch result was (1.1920928955078125e-06)! 22650 DecorateInfo( 22651 unittest.expectedFailure, 22652 'TestCommon', 22653 'test_python_ref_torch_fallback', 22654 dtypes=(torch.float32,), 22655 device_type='cpu', 22656 ), 22657 )), 22658 PythonRefInfo( 22659 "_refs.nn.functional.leaky_relu", 22660 torch_opinfo_name="nn.functional.leaky_relu", 22661 supports_out=True, 22662 ), 22663 PythonRefInfo( 22664 "_refs.nn.functional.log_softmax", 22665 torch_opinfo_name="log_softmax", # alias 22666 torch_opinfo_variant_name="with_dtype", 22667 supports_out=False, 22668 ), 22669 PythonRefInfo( 22670 "_refs.nn.functional.pixel_shuffle", 22671 torch_opinfo_name="nn.functional.pixel_shuffle", 22672 ), 22673 PythonRefInfo( 22674 "_refs.nn.functional.pixel_unshuffle", 22675 torch_opinfo_name="nn.functional.pixel_unshuffle", 22676 ), 22677 PythonRefInfo( 22678 "_refs.nn.functional.poisson_nll_loss", 22679 torch_opinfo_name="nn.functional.poisson_nll_loss", 22680 ), 22681 ElementwiseUnaryPythonRefInfo( 22682 "_refs.nn.functional.prelu", 22683 torch_opinfo_name="nn.functional.prelu", 22684 ), 22685 ElementwiseUnaryPythonRefInfo( 22686 "_refs.nn.functional.relu", 22687 torch_opinfo_name="nn.functional.relu", 22688 supports_out=True, 22689 ), 22690 ElementwiseUnaryPythonRefInfo( 22691 "_refs.nn.functional.relu6", 22692 torch_opinfo_name="nn.functional.relu6", 22693 supports_out=True, 22694 ), 22695 ElementwiseUnaryPythonRefInfo( 22696 "_refs.nn.functional.mish", 22697 torch_opinfo_name="nn.functional.mish", 22698 supports_out=True, 22699 decorators=[ 22700 DecorateInfo( 22701 toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}), 22702 'TestUnaryUfuncs',), ], 22703 ), 22704 ElementwiseUnaryPythonRefInfo( 22705 "_refs.nn.functional.selu", 22706 torch_opinfo_name="nn.functional.selu", 22707 supports_out=True, 22708 decorators=[ 22709 DecorateInfo( 22710 toleranceOverride({ 22711 torch.float16: tol(atol=1e-2, rtol=1.8e-2), 22712 torch.bfloat16: tol(atol=1e-2, rtol=1.8e-2) 22713 }), 22714 'TestUnaryUfuncs', device_type='cuda', 22715 ), ], 22716 ), 22717 PythonRefInfo( 22718 "_refs.nn.functional.softmax", 22719 torch_opinfo_name="softmax", # alias 22720 torch_opinfo_variant_name="with_dtype", 22721 supports_out=False, 22722 ), 22723 PythonRefInfo( 22724 "_refs.nn.functional.softmin", 22725 torch_opinfo_name="nn.functional.softmin", 22726 torch_opinfo_variant_name="with_dtype", 22727 supports_out=False, 22728 ), 22729 ElementwiseUnaryPythonRefInfo( 22730 "_refs.nn.functional.softplus", 22731 torch_opinfo_name="nn.functional.softplus", 22732 ), 22733 PythonRefInfo( 22734 "_refs.nn.functional.l1_loss", 22735 torch_opinfo_name="nn.functional.l1_loss", 22736 ), 22737 PythonRefInfo( 22738 
"_refs.nn.functional.margin_ranking_loss", 22739 torch_opinfo_name="nn.functional.margin_ranking_loss", 22740 ), 22741 PythonRefInfo( 22742 "_refs.nn.functional.mse_loss", 22743 torch_opinfo_name="nn.functional.mse_loss", 22744 ), 22745 PythonRefInfo( 22746 "_refs.nn.functional.smooth_l1_loss", 22747 torch_opinfo_name="nn.functional.smooth_l1_loss", 22748 ), 22749 PythonRefInfo( 22750 "_refs.nn.functional.hinge_embedding_loss", 22751 torch_opinfo_name="nn.functional.hinge_embedding_loss", 22752 skips=( 22753 # Reference result was farther (0.29562714856322714) from the precise 22754 # computation than the torch result was (0.20437285143677286)! 22755 DecorateInfo( 22756 unittest.expectedFailure, 'TestCommon', 'test_python_ref', 22757 dtypes=(torch.bfloat16,), device_type="cpu" 22758 ), 22759 ), 22760 ), 22761 PythonRefInfo( 22762 "_refs.nn.functional.nll_loss", 22763 torch_opinfo_name="nn.functional.nll_loss", 22764 # The corresponding PyTorch op doesn't support out. But the ref is 22765 # registered as a decomp and ATen has an out variant. 22766 supports_out=True, 22767 # For simpler indexing, we flatten target indices, then reshape the result tensor. 22768 # This creates inconsistent view state with reference impl. 22769 validate_view_consistency=False, 22770 skips=( 22771 # RuntimeError: It appears that you're trying to get value out of a tracing tensor - erroring out! 22772 DecorateInfo( 22773 unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', device_type="cuda" 22774 ), 22775 ), 22776 ), 22777 PythonRefInfo( 22778 "_refs.nn.functional.huber_loss", 22779 torch_opinfo_name="nn.functional.huber_loss", 22780 # The corresponding PyTorch op doesn't support out. But the ref is 22781 # registered as a decomp and ATen has an out variant. 
22782 supports_out=True, 22783 ), 22784 ElementwiseUnaryPythonRefInfo( 22785 "_refs.nn.functional.tanhshrink", 22786 torch_opinfo_name="nn.functional.tanhshrink", 22787 decorators=[ 22788 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 22789 'test_reference_numerics_normal', 22790 device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]), 22791 DecorateInfo( 22792 toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1.6e-02), 22793 torch.complex64: tol(atol=6e-04, rtol=1e-05)}), 22794 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda'), 22795 ], 22796 skips=( 22797 # in each case, pytorch will produce a nan while numpy will not 22798 DecorateInfo(unittest.skip("Fails on some jobs works on others!"), 22799 'TestUnaryUfuncs', "test_reference_numerics_large", 22800 dtypes=(torch.complex64, torch.complex128), 22801 active_if=(IS_MACOS)), 22802 DecorateInfo(unittest.skip("Fails on some jobs works on others!"), 22803 'TestUnaryUfuncs', "test_reference_numerics_extremal", 22804 dtypes=(torch.complex64, torch.complex128), 22805 device_type='cpu', 22806 active_if=(IS_MACOS or IS_WINDOWS)), 22807 ), 22808 ), 22809 ElementwiseUnaryPythonRefInfo( 22810 "_refs.nn.functional.hardshrink", 22811 torch_opinfo_name="nn.functional.hardshrink", 22812 ), 22813 ElementwiseUnaryPythonRefInfo( 22814 "_refs.nn.functional.softshrink", 22815 torch_opinfo_name="nn.functional.softshrink", 22816 ), 22817 # 22818 # Elementwise Binary Reference OpInfos 22819 # 22820 ElementwiseBinaryPythonRefInfo( 22821 "_refs.add", 22822 torch_opinfo_name="add", 22823 # https://github.com/pytorch/pytorch/issues/76944 22824 supports_two_python_scalars=True, 22825 supports_one_python_scalar=True, 22826 decorators=( 22827 DecorateInfo( 22828 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 22829 'TestBinaryUfuncs', 'test_reference_numerics'), 22830 ), 22831 skips=( 22832 DecorateInfo(unittest.skip("Skipped!"), 22833 'TestBinaryUfuncs', 22834 'test_reference_numerics_extremal_values', 22835 dtypes=(torch.complex64, torch.complex128)), 22836 ), 22837 ), 22838 ElementwiseBinaryPythonRefInfo( 22839 "_refs.atan2", 22840 torch_opinfo_name="atan2", 22841 ), 22842 ElementwiseBinaryPythonRefInfo( 22843 "_refs.bitwise_and", 22844 torch_opinfo_name="bitwise_and", 22845 ), 22846 ElementwiseBinaryPythonRefInfo( 22847 "_refs.bitwise_left_shift", 22848 torch_opinfo_name="bitwise_left_shift", 22849 skips=( 22850 # https://github.com/pytorch/pytorch/issues/70904 22851 DecorateInfo(unittest.skip("Some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), 22852 ), 22853 ), 22854 ElementwiseBinaryPythonRefInfo( 22855 "_refs.bitwise_right_shift", 22856 torch_opinfo_name="bitwise_right_shift", 22857 skips=( 22858 # # https://github.com/pytorch/pytorch/issues/70904 22859 DecorateInfo(unittest.skip("Skipped some inputs produce undefined outputs"), 'TestCommon', 'test_compare_cpu'), 22860 ), 22861 ), 22862 ElementwiseBinaryPythonRefInfo( 22863 "_refs.bitwise_or", 22864 torch_opinfo_name="bitwise_or", 22865 ), 22866 ElementwiseBinaryPythonRefInfo( 22867 "_refs.bitwise_xor", 22868 torch_opinfo_name="bitwise_xor", 22869 ), 22870 ElementwiseBinaryPythonRefInfo( 22871 "_refs.copysign", 22872 torch_opinfo_name="copysign", 22873 skips=( 22874 # RuntimeError: Expected divisor (b) to be on the same device (cuda:0) as dividend (a), but it is found on cpu! 
22875 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_type_promotion'), 22876 # FIXME output 0: meta disagrees with real impl 22877 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 22878 ) 22879 ), 22880 ElementwiseBinaryPythonRefInfo( 22881 "_refs.div", 22882 torch_opinfo_name="div", 22883 torch_opinfo_variant_name="no_rounding_mode", 22884 # https://github.com/pytorch/pytorch/issues/76944 22885 supports_two_python_scalars=True, 22886 supports_one_python_scalar=True, 22887 skips=( 22888 # NotImplementedError: argument of type: <class 'complex'> 22889 DecorateInfo( 22890 unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_executor', 22891 dtypes=(torch.complex32, torch.complex64, torch.complex128,) 22892 ), 22893 # Reference result was farther (0.7433461727239705) from the precise 22894 # computation than the torch result was (nan)! 22895 DecorateInfo( 22896 unittest.expectedFailure, 'TestCommon', 'test_python_ref', 22897 dtypes=(torch.complex32,), device_type="cuda" 22898 ), 22899 # Reference result was farther (0.7433461727239705) from the precise 22900 # computation than the torch result was (nan)! 22901 DecorateInfo( 22902 unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 22903 dtypes=(torch.complex32,), device_type="cuda" 22904 ), 22905 ), 22906 ), 22907 ElementwiseBinaryPythonRefInfo( 22908 "_refs.div", 22909 torch_opinfo_name="div", 22910 torch_opinfo_variant_name="trunc_rounding", 22911 # https://github.com/pytorch/pytorch/issues/76944 22912 supports_two_python_scalars=True, 22913 supports_one_python_scalar=True, 22914 decorators=( 22915 # See https://github.com/pytorch/pytorch/issues/111126 22916 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 22917 ), 22918 ), 22919 ElementwiseBinaryPythonRefInfo( 22920 "_refs.div", 22921 torch_opinfo_name="div", 22922 torch_opinfo_variant_name="floor_rounding", 22923 # https://github.com/pytorch/pytorch/issues/76944 22924 supports_two_python_scalars=True, 22925 supports_one_python_scalar=True, 22926 decorators=( 22927 # See https://github.com/pytorch/pytorch/issues/111126 22928 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 22929 # Reference result was farther (nan) from the precise computation than the 22930 # torch result was (inf)! 
22931 DecorateInfo( 22932 unittest.expectedFailure, 22933 "TestCommon", 22934 "test_python_ref", 22935 dtypes=(torch.bfloat16,), 22936 device_type="cpu", 22937 ), 22938 ), 22939 ), 22940 ElementwiseBinaryPythonRefInfo( 22941 "_refs.eq", 22942 torch_opinfo_name="eq", 22943 ), 22944 ElementwiseBinaryPythonRefInfo( 22945 "_refs.float_power", 22946 torch_opinfo_name="float_power", 22947 skips=( 22948 # Test doesn't account for float -> double type promotion 22949 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 22950 # Complex values error with: Greatest absolute difference: nan at index 22951 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 22952 'test_reference_numerics_small_values', 22953 dtypes=[torch.complex64, torch.complex128]), 22954 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 22955 'test_reference_numerics_large_values', 22956 dtypes=[torch.complex64, torch.complex128]), 22957 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 22958 'test_reference_numerics_extremal_values', 22959 dtypes=[torch.complex64, torch.complex128]), 22960 ), 22961 ), 22962 ElementwiseBinaryPythonRefInfo( 22963 "_refs.logaddexp", 22964 torch_opinfo_name="logaddexp", 22965 skips=( 22966 # failure due to mismatch in edge cases, which boils down to what torch.exp(inf + infj) should be 22967 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='cpu', 22968 dtypes=(torch.complex64, torch.complex128)), 22969 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='cpu', 22970 dtypes=(torch.complex64, torch.complex128)), 22971 ), 22972 ), 22973 PythonRefInfo( 22974 "_refs.logaddexp2", 22975 torch_opinfo_name="logaddexp2", 22976 ), 22977 ElementwiseBinaryPythonRefInfo( 22978 "_refs.floor_divide", 22979 torch_opinfo_name="floor_divide", 22980 rhs_make_tensor_kwargs=dict(exclude_zero=True), 22981 # https://github.com/pytorch/pytorch/issues/76944 22982 supports_two_python_scalars=True, 22983 supports_one_python_scalar=True, 22984 # bfloat16 floor_divide compared with a float32 reference works inconsistently 22985 skips=( 22986 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', 22987 dtypes=(torch.bfloat16,)), 22988 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', 22989 dtypes=(torch.bfloat16,)), 22990 # bfloat16 floor_divide compared with a float32 reference works inconsistently 22991 DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs', 22992 dtypes=(torch.bfloat16,)), 22993 # int8 floor divide has different results for -128 // -1 vs. 
NumPy
            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
                         'test_reference_numerics_small_values',
                         dtypes=(torch.int8,)),
            # The following tests fail on some jobs
            DecorateInfo(unittest.skip('Skipped!'), 'TestBinaryUfuncs',
                         'test_reference_numerics_extremal_values',
                         dtypes=(torch.float16,)),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=5e-3)}),
                         'TestBinaryUfuncs', 'test_reference_numerics'),
            # FIXME output 0: meta disagrees with real impl
            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
        ),
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.fmax",
        torch_opinfo_name="fmax",
        supports_rhs_python_scalar=False,
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.fmin",
        torch_opinfo_name="fmin",
        supports_rhs_python_scalar=False,
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.fmod",
        torch_opinfo_name="fmod",
        rhs_make_tensor_kwargs={'exclude_zero': True},
        supports_rhs_python_scalar=True,
        skips=(
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
                         dtypes=(torch.bfloat16,), device_type='cpu'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback',
                         dtypes=(torch.bfloat16,), device_type='cpu'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                         'test_contig_vs_every_other',
                         dtypes=(torch.bfloat16,)),
            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                         'test_non_contig',
                         dtypes=(torch.bfloat16,)),
            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                         'test_reference_numerics',
                         dtypes=(torch.bfloat16,)),
            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
                         'test_reference_numerics_small_values',
                         dtypes=(torch.uint8,)),
        ),
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.gcd",
        torch_opinfo_name="gcd",
        skips=(
            DecorateInfo(unittest.expectedFailure,
                         'TestBinaryUfuncs',
                         'test_reference_numerics_small_values',
                         dtypes=(torch.int8,)),
        ),
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.ge",
        torch_opinfo_name="ge",
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.gt",
        torch_opinfo_name="gt",
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.heaviside",
        torch_opinfo_name="heaviside",
        supports_rhs_python_scalar=False,
        skips=(
            # PyTorch's heaviside does not appear to propagate NaNs
            DecorateInfo(unittest.skip("Skipped!"),
                         'TestBinaryUfuncs',
                         'test_reference_numerics_extremal_values'),
        ),
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.hypot",
        torch_opinfo_name="hypot",
        supports_rhs_python_scalar=False,
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.igamma",
        torch_opinfo_name="igamma",
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.igammac",
        torch_opinfo_name="igammac",
    ),
    ElementwiseBinaryPythonRefInfo(
        "_refs.isclose",
        torch_opinfo_name="isclose",
        skips=(
            # Intentional xfail -- isclose does not type promote
            DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
            DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'),
            DecorateInfo(unittest.skip("Skipped!"),
                         'TestBinaryUfuncs',
'test_reference_numerics_extremal_values'), 23093 ), 23094 ), 23095 ElementwiseBinaryPythonRefInfo( 23096 "_refs.lcm", 23097 torch_opinfo_name="lcm", 23098 ), 23099 ElementwiseBinaryPythonRefInfo( 23100 "_refs.le", 23101 torch_opinfo_name="le", 23102 ), 23103 ElementwiseBinaryPythonRefInfo( 23104 "_refs.logical_and", 23105 torch_opinfo_name="logical_and", 23106 ), 23107 ElementwiseUnaryPythonRefInfo( 23108 "_refs.logical_not", 23109 torch_opinfo_name="logical_not", 23110 ), 23111 ElementwiseBinaryPythonRefInfo( 23112 "_refs.logical_or", 23113 torch_opinfo_name="logical_or", 23114 ), 23115 ElementwiseBinaryPythonRefInfo( 23116 "_refs.logical_xor", 23117 torch_opinfo_name="logical_xor", 23118 ), 23119 ElementwiseBinaryPythonRefInfo( 23120 "_refs.lt", 23121 torch_opinfo_name="lt", 23122 ), 23123 ElementwiseBinaryPythonRefInfo( 23124 "_refs.maximum", 23125 torch_opinfo_name="maximum", 23126 skips=( 23127 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23128 ), 23129 ), 23130 ElementwiseBinaryPythonRefInfo( 23131 "_refs.minimum", 23132 torch_opinfo_name="minimum", 23133 skips=( 23134 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23135 ), 23136 ), 23137 ElementwiseBinaryPythonRefInfo( 23138 "_refs.mul", 23139 torch_opinfo_name="mul", 23140 # https://github.com/pytorch/pytorch/issues/76944 23141 supports_two_python_scalars=True, 23142 supports_one_python_scalar=True, 23143 skips=( 23144 # Reference result was farther (0.0) from the precise computation 23145 # than the torch result was (nan)! 23146 DecorateInfo( 23147 unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 23148 dtypes=(torch.complex32,), 23149 ), 23150 # Reference result was farther (0.0) from the precise computation 23151 # than the torch result was (nan)! 23152 DecorateInfo( 23153 unittest.expectedFailure, 'TestCommon', 'test_python_ref', 23154 dtypes=(torch.complex32,), device_type='cuda' 23155 ), 23156 # Reference result was farther (0.0) from the precise computation 23157 # than the torch result was (nan)! 23158 DecorateInfo( 23159 unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 23160 dtypes=(torch.complex32,), device_type='cuda' 23161 ), 23162 ) 23163 ), 23164 ElementwiseBinaryPythonRefInfo( 23165 "_refs.ne", 23166 torch_opinfo_name="ne", 23167 ), 23168 ElementwiseBinaryPythonRefInfo( 23169 "_refs.nextafter", 23170 torch_opinfo_name="nextafter", 23171 ), 23172 ElementwiseBinaryPythonRefInfo( 23173 "_refs.pow", 23174 torch_opinfo_name="pow", 23175 decorators=( 23176 DecorateInfo( 23177 toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05)}), 23178 'TestBinaryUfuncs', 'test_reference_numerics'), 23179 DecorateInfo( 23180 toleranceOverride({torch.complex64: tol(atol=1e-4, rtol=1.3e-05), 23181 torch.complex128: tol(atol=1e-4, rtol=1.3e-05)}), 23182 'TestBinaryUfuncs', 'test_scalar_support'), 23183 ), 23184 skips=( 23185 # Reference result was farther (inf) from the precise 23186 # computation than the torch result was (nan)! 23187 DecorateInfo( 23188 unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 23189 dtypes=(torch.complex32,), 23190 ), 23191 # Reference result was farther (inf) from the precise 23192 # computation than the torch result was (nan)! 
23193 DecorateInfo( 23194 unittest.expectedFailure, 'TestCommon', 'test_python_ref', 23195 dtypes=(torch.complex32,), device_type="cuda" 23196 ), 23197 # Reference result was farther (inf) from the precise 23198 # computation than the torch result was (nan)! 23199 DecorateInfo( 23200 unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 23201 dtypes=(torch.complex32,), device_type="cuda" 23202 ), 23203 # Skipping integers because they are being raised to negative powers causing an error 23204 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 23205 'test_reference_numerics_small_values', 23206 dtypes=[torch.int8, torch.int16, torch.int32, torch.int64]), 23207 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 23208 'test_reference_numerics_large_values', 23209 dtypes=[torch.int16, torch.int32, torch.int64]), 23210 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23211 'test_reference_numerics', 23212 dtypes=(torch.complex32,)), 23213 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23214 'test_reference_numerics_small_values', 23215 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 23216 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23217 'test_reference_numerics_large_values', 23218 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 23219 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23220 'test_reference_numerics_extremal_values', 23221 dtypes=(torch.complex32, torch.complex64, torch.complex128)), 23222 ), 23223 ), 23224 ElementwiseBinaryPythonRefInfo( 23225 "_refs.remainder", 23226 torch_opinfo_name="remainder", 23227 skips=( 23228 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', 23229 dtypes=(torch.bfloat16,), device_type='cpu'), 23230 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', 23231 dtypes=(torch.bfloat16,), device_type='cpu'), 23232 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23233 'test_reference_numerics', 23234 dtypes=(torch.bfloat16,)), 23235 DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 23236 'test_reference_numerics_small_values', 23237 dtypes=(torch.uint8,)), 23238 ), 23239 ), 23240 ElementwiseBinaryPythonRefInfo( 23241 "_refs.rsub", 23242 torch_opinfo_name="rsub", 23243 # https://github.com/pytorch/pytorch/issues/76944 23244 skips=( 23245 # Reference result was farther (nan) from the precise computation than 23246 # the torch result was (nan)! 23247 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', 23248 dtypes=(torch.chalf,), device_type='cpu'), 23249 # Reference result was farther (nan) from the precise computation than 23250 # the torch result was (nan)! 
23251 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 23252 dtypes=(torch.chalf,), device_type='cpu'), 23253 ), 23254 ), 23255 ElementwiseBinaryPythonRefInfo( 23256 "_refs.sub", 23257 torch_opinfo_name="sub", 23258 # https://github.com/pytorch/pytorch/issues/76944 23259 supports_two_python_scalars=True, 23260 supports_one_python_scalar=True, 23261 decorators=( 23262 DecorateInfo( 23263 toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0), 23264 torch.bfloat16: tol(atol=1e-5, rtol=5e-3), 23265 torch.complex32: tol(atol=1e-5, rtol=1e-3)}), 23266 'TestBinaryUfuncs', 'test_reference_numerics'), 23267 DecorateInfo( 23268 toleranceOverride({torch.chalf: tol(atol=1e-2, rtol=0)}), 23269 'TestCommon', 'test_complex_half_reference_testing', device_type='cpu'), 23270 DecorateInfo( 23271 toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), 23272 'TestDecomp', 'test_comprehensive', device_type='cpu'), 23273 DecorateInfo( 23274 toleranceOverride({torch.chalf: tol(atol=5e-3, rtol=0)}), 23275 'TestDecomp', 'test_quick', device_type='cpu'), 23276 ), 23277 skips=( 23278 DecorateInfo(unittest.skip("Skipped!"), 23279 'TestBinaryUfuncs', 23280 'test_reference_numerics', 23281 dtypes=(torch.uint8,)), 23282 DecorateInfo(unittest.skip("Skipped!"), 23283 'TestBinaryUfuncs', 23284 'test_reference_numerics_small_values', 23285 dtypes=(torch.uint8,)), 23286 ), 23287 ), 23288 ElementwiseBinaryPythonRefInfo( 23289 "_refs.true_divide", 23290 torch_opinfo_name="true_divide", 23291 # https://github.com/pytorch/pytorch/issues/76944 23292 supports_two_python_scalars=True, 23293 supports_one_python_scalar=True, 23294 skips=( 23295 # Reference result was farther (0.7433461727239705) from the precise 23296 # computation than the torch result was (nan)! 23297 DecorateInfo( 23298 unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', 23299 dtypes=(torch.complex32,), 23300 ), 23301 # Reference result was farther (0.7433461727239705) from the precise 23302 # computation than the torch result was (nan)! 23303 DecorateInfo( 23304 unittest.expectedFailure, 'TestCommon', 'test_python_ref', 23305 dtypes=(torch.complex32,), device_type="cuda" 23306 ), 23307 # Reference result was farther (0.7433461727239705) from the precise 23308 # computation than the torch result was (nan)! 23309 DecorateInfo( 23310 unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', 23311 dtypes=(torch.complex32,), device_type="cuda" 23312 ), 23313 ), 23314 ), 23315 # 23316 # Elementwise Ternary Reference OpInfos 23317 # 23318 PythonRefInfo( 23319 "_refs.addcdiv", 23320 torch_opinfo_name="addcdiv", 23321 ), 23322 PythonRefInfo( 23323 "_refs.addcmul", 23324 torch_opinfo_name="addcmul", 23325 skips=( 23326 # Reference result was farther (1.3343989849090576e-05) 23327 # from the precise computation than the torch result 23328 # was (9.592622518539429e-06)! 
23329 # FIXME: enable dtype-based tolerances in test_ops.py:TestCommon._ref_test_helper 23330 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', 23331 dtypes=(torch.float16,), device_type="cpu"), 23332 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', 23333 dtypes=(torch.float16,), device_type="cpu"), 23334 ), 23335 ), 23336 ElementwiseBinaryPythonRefInfo( 23337 "_refs.clamp_min", 23338 torch_opinfo_name="clamp_min", 23339 skips=( 23340 # test error disabled since rhs non-tensor python scalar is supported 23341 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23342 ), 23343 ), 23344 ElementwiseBinaryPythonRefInfo( 23345 "_refs.clamp_max", 23346 torch_opinfo_name="clamp_max", 23347 skips=( 23348 # test error disabled since rhs non-tensor python scalar is supported 23349 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23350 ), 23351 ), 23352 PythonRefInfo( 23353 "_refs.clamp", 23354 torch_opinfo_name="clamp", 23355 ), 23356 PythonRefInfo( 23357 "_refs.nn.functional.triplet_margin_loss", 23358 torch_opinfo_name="nn.functional.triplet_margin_loss", 23359 supports_out=False, 23360 # TODO: Uses minimum and clamp 23361 skips=( 23362 # AssertionError: Tensor-likes are not close! 23363 # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed) 23364 # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed) 23365 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', 23366 dtypes=(torch.uint8,), device_type="cpu"), 23367 ) 23368 ), 23369 ElementwiseBinaryPythonRefInfo( 23370 "_refs.xlogy", 23371 torch_opinfo_name="xlogy", 23372 supports_one_python_scalar=True, 23373 ), 23374 # 23375 # Elementwise Binary Special OpInfos 23376 # 23377 ElementwiseBinaryPythonRefInfo( 23378 "_refs.special.xlog1py", 23379 torch_opinfo_name="special.xlog1py", 23380 supports_one_python_scalar=True, 23381 ), 23382 # 23383 # Data Conversion & Data Movement Opinfos 23384 # 23385 ElementwiseUnaryPythonRefInfo( 23386 "_refs._conversions.bfloat16", 23387 torch_opinfo_name="bfloat16", 23388 # TODO: If self already has the correct dtype and device, then self is 23389 # returned ignoring memory_format. 23390 # https://github.com/pytorch/pytorch/issues/86558 23391 validate_view_consistency=False, 23392 ), 23393 ElementwiseUnaryPythonRefInfo( 23394 "_refs._conversions.bool", 23395 torch_opinfo_name="bool", 23396 # TODO: If self already has the correct dtype and device, then self is 23397 # returned ignoring memory_format. 23398 # https://github.com/pytorch/pytorch/issues/86558 23399 validate_view_consistency=False, 23400 ), 23401 ElementwiseUnaryPythonRefInfo( 23402 "_refs._conversions.byte", 23403 torch_opinfo_name="byte", 23404 # TODO: If self already has the correct dtype and device, then self is 23405 # returned ignoring memory_format. 23406 # https://github.com/pytorch/pytorch/issues/86558 23407 validate_view_consistency=False, 23408 skips=( 23409 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 23410 ) 23411 ), 23412 ElementwiseUnaryPythonRefInfo( 23413 "_refs._conversions.char", 23414 torch_opinfo_name="char", 23415 # TODO: If self already has the correct dtype and device, then self is 23416 # returned ignoring memory_format. 
23417 # https://github.com/pytorch/pytorch/issues/86558 23418 validate_view_consistency=False, 23419 skips=( 23420 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 23421 ) 23422 ), 23423 ElementwiseBinaryPythonRefInfo( 23424 "_refs._conversions.complex", 23425 torch_opinfo_name="complex", 23426 error_inputs_func=partial(error_inputs_complex, is_ref=True), 23427 skips=( 23428 # Tests don't account for complex's type promotion semantics 23429 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 23430 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 23431 ) 23432 ), 23433 ElementwiseBinaryPythonRefInfo( 23434 "_refs._conversions.polar", 23435 torch_opinfo_name="polar", 23436 skips=( 23437 # Tests don't account for complex's type promotion semantics 23438 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), 23439 DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_binary_ufuncs_mixed_dtype'), 23440 ) 23441 ), 23442 ElementwiseUnaryPythonRefInfo( 23443 "_refs._conversions.double", 23444 torch_opinfo_name="double", 23445 # TODO: If self already has the correct dtype and device, then self is 23446 # returned ignoring memory_format. 23447 # https://github.com/pytorch/pytorch/issues/86558 23448 validate_view_consistency=False, 23449 ), 23450 ElementwiseUnaryPythonRefInfo( 23451 "_refs._conversions.float", 23452 torch_opinfo_name="float", 23453 # TODO: If self already has the correct dtype and device, then self is 23454 # returned ignoring memory_format. 23455 # https://github.com/pytorch/pytorch/issues/86558 23456 validate_view_consistency=False, 23457 ), 23458 ElementwiseUnaryPythonRefInfo( 23459 "_refs._conversions.half", 23460 torch_opinfo_name="half", 23461 # TODO: If self already has the correct dtype and device, then self is 23462 # returned ignoring memory_format. 23463 # https://github.com/pytorch/pytorch/issues/86558 23464 validate_view_consistency=False, 23465 ), 23466 ElementwiseUnaryPythonRefInfo( 23467 "_refs._conversions.int", 23468 torch_opinfo_name="int", 23469 # TODO: If self already has the correct dtype and device, then self is 23470 # returned ignoring memory_format. 23471 # https://github.com/pytorch/pytorch/issues/86558 23472 validate_view_consistency=False, 23473 skips=( 23474 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 23475 ) 23476 ), 23477 ElementwiseUnaryPythonRefInfo( 23478 "_refs._conversions.long", 23479 torch_opinfo_name="long", 23480 # TODO: If self already has the correct dtype and device, then self is 23481 # returned ignoring memory_format. 23482 # https://github.com/pytorch/pytorch/issues/86558 23483 validate_view_consistency=False, 23484 skips=( 23485 DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), 23486 ) 23487 ), 23488 ElementwiseUnaryPythonRefInfo( 23489 "_refs._conversions.short", 23490 torch_opinfo_name="short", 23491 # TODO: If self already has the correct dtype and device, then self is 23492 # returned ignoring memory_format. 
        # https://github.com/pytorch/pytorch/issues/86558
        validate_view_consistency=False,
        skips=(
            DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'),
        )
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs._conversions.chalf",
        torch_opinfo_name="chalf",
        # TODO: If self already has the correct dtype and device, then self is
        # returned ignoring memory_format.
        # https://github.com/pytorch/pytorch/issues/86558
        validate_view_consistency=False,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs._conversions.cfloat",
        torch_opinfo_name="cfloat",
        # TODO: If self already has the correct dtype and device, then self is
        # returned ignoring memory_format.
        # https://github.com/pytorch/pytorch/issues/86558
        validate_view_consistency=False,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs._conversions.cdouble",
        torch_opinfo_name="cdouble",
        # TODO: If self already has the correct dtype and device, then self is
        # returned ignoring memory_format.
        # https://github.com/pytorch/pytorch/issues/86558
        validate_view_consistency=False,
    ),
    PythonRefInfo(
        "_refs.clone",
        torch_opinfo_name="clone",
    ),
    #
    # View & Shape OpInfos
    #
    PythonRefInfo(
        "_refs.alias_copy",
        torch_opinfo_name="alias_copy",
        supports_out=True,
    ),
    PythonRefInfo(
        "_refs.atleast_1d",
        torch_opinfo_name="atleast_1d",
        validate_view_consistency=False,
    ),
    PythonRefInfo(
        "_refs.atleast_2d",
        torch_opinfo_name="atleast_2d",
        validate_view_consistency=False,
    ),
    PythonRefInfo(
        "_refs.atleast_3d",
        torch_opinfo_name="atleast_3d",
        validate_view_consistency=False,
    ),
    PythonRefInfo(
        "_refs.as_strided",
        torch_opinfo_name="as_strided",
        # FIXME: doesn't support chalf
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        skips=(
            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
        ),
    ),
    PythonRefInfo(
        "_refs.as_strided_copy",
        torch_opinfo_name="as_strided_copy",
        supports_out=True,
        # FIXME: doesn't support chalf
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        skips=(
            # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
            # The view function this decomposes into does not have a ref
            DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"),
        ),
    ),
    PythonRefInfo(
        "_refs.as_strided",
        torch_opinfo_name="as_strided",
        torch_opinfo_variant_name="partial_views",
        # FIXME: doesn't support chalf
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), 23583 skips=( 23584 # cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED 23585 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'), 23586 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'), 23587 DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'), 23588 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'), 23589 ), 23590 ), 23591 PythonRefInfo( 23592 "_refs.as_strided_scatter", 23593 torch_opinfo_name="as_strided_scatter", 23594 # returns a view of an intermediate tensor (as_strided) 23595 validate_view_consistency=False, 23596 ), 23597 PythonRefInfo( 23598 "_refs.block_diag", 23599 torch_opinfo_name="block_diag", 23600 ), 23601 PythonRefInfo( 23602 "_refs.broadcast_shapes", 23603 torch_opinfo_name="broadcast_shapes", 23604 ), 23605 PythonRefInfo( 23606 "_refs.broadcast_tensors", 23607 torch_opinfo_name="broadcast_tensors", 23608 ), 23609 PythonRefInfo( 23610 "_refs.broadcast_to", 23611 torch_opinfo_name="broadcast_to", 23612 ), 23613 PythonRefInfo( 23614 "_refs.cat", 23615 torch_opinfo_name="cat", 23616 skips=( 23617 # FIXME: AssertionError: RuntimeError not raised 23618 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23619 ), 23620 ), 23621 PythonRefInfo( 23622 "_refs.chunk", 23623 torch_opinfo_name="chunk", 23624 ), 23625 PythonRefInfo( 23626 "_refs.column_stack", 23627 torch_opinfo_name="column_stack", 23628 ), 23629 ElementwiseUnaryPythonRefInfo( 23630 "_refs.conj", 23631 torch_opinfo_name="conj", 23632 ), 23633 PythonRefInfo( 23634 "_refs.constant_pad_nd", 23635 torch_opinfo_name="constant_pad_nd", 23636 ), 23637 PythonRefInfo( 23638 "_refs.contiguous", 23639 torch_opinfo_name="contiguous", 23640 ), 23641 ElementwiseUnaryPythonRefInfo( 23642 "_refs.deg2rad", 23643 torch_opinfo_name="deg2rad", 23644 decorators=(precisionOverride({torch.bfloat16: 7e-1, 23645 torch.float16: 7e-1}),), 23646 ), 23647 PythonRefInfo( 23648 "_refs.dsplit", 23649 torch_opinfo_name="dsplit", 23650 ), 23651 PythonRefInfo( 23652 "_refs.diag", 23653 torch_opinfo_name="diag", 23654 ), 23655 PythonRefInfo( 23656 "_refs.diagonal", 23657 torch_opinfo_name="diagonal", 23658 ), 23659 PythonRefInfo( 23660 "_refs.diagonal_copy", 23661 torch_opinfo_name="diagonal_copy", 23662 supports_out=True, 23663 ), 23664 PythonRefInfo( 23665 "_refs.diagonal_scatter", 23666 torch_opinfo_name="diagonal_scatter", 23667 supports_out=True, 23668 # returns a view of an intermediate tensor (as_strided) 23669 validate_view_consistency=False, 23670 ), 23671 PythonRefInfo( 23672 "_refs.diag_embed", 23673 torch_opinfo_name="diag_embed", 23674 supports_out=True, 23675 ), 23676 PythonRefInfo( 23677 "_refs.dstack", 23678 torch_opinfo_name="dstack", 23679 skips=( 23680 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23681 ), 23682 ), 23683 PythonRefInfo( 23684 "_refs.expand", 23685 torch_opinfo_name="expand", 23686 ), 23687 PythonRefInfo( 23688 "_refs.expand_as", 23689 torch_opinfo_name="expand_as", 23690 ), 23691 PythonRefInfo( 23692 "_refs.expand_copy", 23693 torch_opinfo_name="expand_copy", 23694 supports_out=True, 23695 ), 23696 PythonRefInfo( 23697 "_refs.flatten", 23698 torch_opinfo_name="flatten", 23699 ), 23700 PythonRefInfo( 23701 "_refs.flip", 23702 torch_opinfo_name="flip", 23703 ), 23704 PythonRefInfo( 
23705 "_refs.fliplr", 23706 torch_opinfo_name="fliplr", 23707 ), 23708 PythonRefInfo( 23709 "_refs.flipud", 23710 torch_opinfo_name="flipud", 23711 ), 23712 PythonRefInfo( 23713 "_refs.hstack", 23714 torch_opinfo_name="hstack", 23715 skips=( 23716 # https://github.com/pytorch/pytorch/issues/78613 23717 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23718 ), 23719 ), 23720 PythonRefInfo( 23721 "_refs.narrow", 23722 torch_opinfo_name="narrow", 23723 error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), 23724 ), 23725 PythonRefInfo( 23726 "_refs.narrow_copy", 23727 torch_opinfo_name="narrow_copy", 23728 supports_out=True, 23729 error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), 23730 skips=( 23731 # The view function this decompose into does not have a ref 23732 DecorateInfo(unittest.expectedFailure, "TestCommon", "test_python_ref"), 23733 ), 23734 ), 23735 PythonRefInfo( 23736 "_refs.nn.functional.group_norm", 23737 torch_opinfo_name="nn.functional.group_norm", 23738 validate_view_consistency=False, 23739 ), 23740 PythonRefInfo( 23741 "_refs.native_layer_norm", 23742 torch_opinfo_name="native_layer_norm", 23743 skips=( 23744 DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref", 23745 device_type="cpu", dtypes=(torch.float32,)), 23746 DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_python_ref_torch_fallback", 23747 device_type="cpu", dtypes=(torch.float32,)), 23748 ), 23749 ), 23750 PythonRefInfo( 23751 "_refs.permute", 23752 torch_opinfo_name="permute", 23753 ), 23754 ElementwiseUnaryPythonRefInfo( 23755 "_refs.rad2deg", 23756 torch_opinfo_name="rad2deg", 23757 decorators=(precisionOverride({torch.bfloat16: 7e-1, 23758 torch.float16: 7e-1}),), 23759 ), 23760 PythonRefInfo( 23761 "_refs.ravel", 23762 torch_opinfo_name="ravel", 23763 ), 23764 PythonRefInfo( 23765 "_refs.renorm", 23766 torch_opinfo_name="renorm", 23767 ), 23768 PythonRefInfo( 23769 "_refs.repeat", 23770 torch_opinfo_name="repeat", 23771 validate_view_consistency=False, 23772 ), 23773 PythonRefInfo( 23774 "_refs.reshape", 23775 torch_opinfo_name="reshape", 23776 ), 23777 PythonRefInfo( 23778 "_refs.reshape_as", 23779 torch_opinfo_name="reshape_as", 23780 ), 23781 PythonRefInfo( 23782 "_refs.roll", 23783 torch_opinfo_name="roll", 23784 validate_view_consistency=False, 23785 ), 23786 PythonRefInfo( 23787 "_refs.rot90", 23788 torch_opinfo_name="rot90", 23789 validate_view_consistency=False, 23790 ), 23791 PythonRefInfo( 23792 "_refs.select_scatter", 23793 torch_opinfo_name="select_scatter", 23794 ), 23795 PythonRefInfo( 23796 "_refs.stack", 23797 torch_opinfo_name="stack", 23798 validate_view_consistency=False, 23799 ), 23800 PythonRefInfo( 23801 "_refs.squeeze", 23802 torch_opinfo_name="squeeze", 23803 ), 23804 PythonRefInfo( 23805 "_refs.squeeze", 23806 torch_opinfo_name="squeeze", 23807 torch_opinfo_variant_name="multiple", 23808 ), 23809 PythonRefInfo( 23810 "_refs.tensor_split", 23811 torch_opinfo_name="tensor_split", 23812 skips=( 23813 # RuntimeError: no _refs support for torch.Tensor.tolist 23814 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), 23815 ), 23816 ), 23817 PythonRefInfo( 23818 "_refs.hsplit", 23819 torch_opinfo_name="hsplit", 23820 ), 23821 PythonRefInfo( 23822 "_refs.vsplit", 23823 torch_opinfo_name="vsplit", 23824 ), 23825 PythonRefInfo( 23826 "_refs.dot", 23827 torch_opinfo_name="dot", 23828 error_inputs_func=partial(error_inputs_dot_vdot, 
is_ref=True), 23829 # .conj() does not set ._is_view() correctly in ATen 23830 validate_view_consistency=False, 23831 skips=( 23832 # RuntimeError: no _refs support for torch.Tensor.is_conj 23833 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), 23834 ), 23835 ), 23836 PythonRefInfo( 23837 "_refs.vdot", 23838 torch_opinfo_name="vdot", 23839 error_inputs_func=partial(error_inputs_dot_vdot, is_ref=True), 23840 # .conj() does not set ._is_view() correctly in ATen 23841 validate_view_consistency=False, 23842 skips=( 23843 # RuntimeError: no _refs support for torch.Tensor.is_conj 23844 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', dtypes=[torch.complex64, torch.complex128]), 23845 ), 23846 ), 23847 PythonRefInfo( 23848 "_refs.transpose", 23849 torch_opinfo_name="transpose", 23850 ), 23851 PythonRefInfo( 23852 "_refs.t", 23853 torch_opinfo_name="t", 23854 ), 23855 PythonRefInfo( 23856 "_refs.t_copy", 23857 torch_opinfo_name="t_copy", 23858 supports_out=True, 23859 ), 23860 PythonRefInfo( 23861 "_refs.T", 23862 torch_opinfo_name="T", 23863 error_inputs_func=partial(error_inputs_T, has_ndims_error=True), 23864 ), 23865 PythonRefInfo( 23866 "_refs.unfold", 23867 torch_opinfo_name="unfold", 23868 ), 23869 PythonRefInfo( 23870 "_refs.unfold_copy", 23871 torch_opinfo_name="unfold_copy", 23872 supports_out=True, 23873 ), 23874 PythonRefInfo( 23875 "_refs.unsqueeze", 23876 torch_opinfo_name="unsqueeze", 23877 ), 23878 PythonRefInfo( 23879 "_refs.unsqueeze_copy", 23880 torch_opinfo_name="unsqueeze_copy", 23881 supports_out=True, 23882 ), 23883 PythonRefInfo( 23884 "_refs.view", 23885 torch_opinfo_name="view", 23886 ), 23887 PythonRefInfo( 23888 "_refs.view_as", 23889 torch_opinfo_name="view_as", 23890 ), 23891 PythonRefInfo( 23892 "_refs.view_copy", 23893 torch_opinfo_name="view_copy", 23894 supports_out=True, 23895 ), 23896 PythonRefInfo( 23897 "_refs.vstack", 23898 torch_opinfo_name="vstack", 23899 skips=( 23900 # https://github.com/pytorch/pytorch/issues/78613 23901 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 23902 ), 23903 ), 23904 PythonRefInfo( 23905 "_refs.unflatten", 23906 torch_opinfo_name="unflatten", 23907 ), 23908 PythonRefInfo( 23909 "_refs.unbind", 23910 torch_opinfo_name="unbind", 23911 ), 23912 # 23913 # Reduction Reference OpInfos 23914 # 23915 ReductionPythonRefInfo( 23916 "_refs.all", 23917 torch_opinfo_name="all", 23918 skips=( 23919 # FIXME: uint8 input returns uint8 instead of bool 23920 DecorateInfo( 23921 unittest.expectedFailure, 'TestReductions', 'test_result_dtype', 23922 dtypes=[torch.uint8]), 23923 ), 23924 ), 23925 ReductionPythonRefInfo( 23926 "_refs.amax", 23927 torch_opinfo_name="amax", 23928 error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), 23929 skips=( 23930 # FIXME: reduces all dimensions when dim=[] 23931 DecorateInfo( 23932 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 23933 DecorateInfo( 23934 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 23935 ), 23936 ), 23937 ReductionPythonRefInfo( 23938 "_refs.amin", 23939 torch_opinfo_name="amin", 23940 error_inputs_func=partial(error_inputs_aminmax_amax_amin, is_ref=True), 23941 skips=( 23942 # FIXME: reduces all dimensions when dim=[] 23943 DecorateInfo( 23944 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 23945 DecorateInfo( 23946 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 23947 ), 23948 ), 
23949 ReductionPythonRefInfo( 23950 "_refs.any", 23951 torch_opinfo_name="any", 23952 skips=( 23953 # FIXME: uint8 input returns uint8 instead of bool 23954 DecorateInfo( 23955 unittest.expectedFailure, 'TestReductions', 'test_result_dtype', 23956 dtypes=[torch.uint8]), 23957 ), 23958 ), 23959 ReductionPythonRefInfo( 23960 "_refs.count_nonzero", 23961 torch_opinfo_name="count_nonzero", 23962 skips=( 23963 # FIXME: count_nonzero does not accept keepdim kwarg 23964 DecorateInfo( 23965 unittest.skip("Skipped!"), 'TestReductions', 23966 'test_dim_default_keepdim'), 23967 DecorateInfo( 23968 unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), 23969 DecorateInfo( 23970 unittest.skip("Skipped!"), 'TestReductions', 'test_dim_single_keepdim'), 23971 DecorateInfo( 23972 unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 23973 DecorateInfo( 23974 unittest.skip("Skipped!"), 'TestReductions', 'test_dim_multi_keepdim'), 23975 DecorateInfo( 23976 unittest.skip("Skipped!"), 'TestReductions', 23977 'test_dim_multi_unsorted_keepdim'), 23978 # FIXME: dim=[] reduces all dimensions 23979 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 23980 ), 23981 ), 23982 ReductionPythonRefInfo( 23983 "_refs.mean", 23984 torch_opinfo_name="mean", 23985 supports_out=True, 23986 error_inputs_func=partial(error_inputs_mean, is_ref=True), 23987 skips=( 23988 # FIXME: reduces all dimensions when dim=[] 23989 DecorateInfo( 23990 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 23991 DecorateInfo( 23992 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 23993 ), 23994 ), 23995 ReductionPythonRefInfo( 23996 "_refs.std", 23997 torch_opinfo_name="std", 23998 supports_out=True, 23999 skips=( 24000 # FIXME: reduces all dimensions when dim=[] 24001 DecorateInfo( 24002 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 24003 DecorateInfo( 24004 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 24005 # FIXME: improve precision 24006 DecorateInfo( 24007 unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 24008 dtypes=(torch.float16,)), 24009 DecorateInfo( 24010 unittest.skip("Skipped!"), 'TestReductions', 24011 'test_ref_duplicate_values', 24012 dtypes=(torch.float16,)), 24013 ), 24014 ), 24015 # std_mean and var_mean are not ReductionInfos 24016 PythonRefInfo( 24017 "_refs.std_mean", 24018 torch_opinfo_name="std_mean", 24019 ), 24020 ReductionPythonRefInfo( 24021 "_refs.sum", 24022 torch_opinfo_name="sum", 24023 supports_out=True, 24024 skips=( 24025 # FIXME: doesn't test out behavior properly for this operator 24026 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 24027 # FIXME: mean reduces all dimensions when dim=[] 24028 DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), 24029 DecorateInfo( 24030 unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), 24031 # FIXME: improve precision 24032 DecorateInfo( 24033 unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 24034 dtypes=[torch.float16]), 24035 DecorateInfo( 24036 unittest.skip("Skipped!"), 'TestReductions', 24037 'test_ref_duplicate_values', 24038 dtypes=[torch.float16]), 24039 DecorateInfo( 24040 unittest.skip("Skipped!"), 'TestOperators', 'test_reduction_all', 24041 dtypes=[torch.float32]), 24042 ), 24043 ), 24044 PythonRefInfo( 24045 "_refs.cumsum", 24046 torch_opinfo_name="cumsum", 24047 supports_out=True, 24048 skips=( 24049 # doesn't test out behavior 
properly for this operator 24050 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 24051 ), 24052 ), 24053 PythonRefInfo( 24054 "_refs.cumprod", 24055 torch_opinfo_name="cumprod", 24056 supports_out=True, 24057 skips=( 24058 # doesn't test out behavior properly for this operator 24059 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 24060 ), 24061 ), 24062 PythonRefInfo( 24063 "_refs.sum_to_size", 24064 torch_opinfo_name="sum_to_size", 24065 validate_view_consistency=False, 24066 ), 24067 ReductionPythonRefInfo( 24068 "_refs.prod", 24069 torch_opinfo_name="prod", 24070 supports_out=True, 24071 supports_multiple_dims=True, 24072 skips=( 24073 # FIXME: doesn't test out behavior properly for this operator 24074 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), 24075 # FIXME: reduces all dimensions when dim=[] 24076 DecorateInfo( 24077 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 24078 DecorateInfo( 24079 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 24080 # FIXME: improve precision 24081 DecorateInfo( 24082 unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', 24083 dtypes=[torch.float16, torch.complex64]), 24084 ), 24085 ), 24086 ReductionPythonRefInfo( 24087 "_refs.var", 24088 torch_opinfo_name="var", 24089 supports_out=True, 24090 skips=( 24091 # FIXME: reduces all dimensions when dim=[] 24092 DecorateInfo( 24093 unittest.expectedFailure, 'TestReductions', 'test_dim_empty'), 24094 DecorateInfo( 24095 unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'), 24096 # FIXME: improve precision 24097 DecorateInfo( 24098 unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), 24099 ), 24100 ), 24101 PythonRefInfo( 24102 "_refs.var_mean", 24103 torch_opinfo_name="var_mean", 24104 validate_view_consistency=False, 24105 ), 24106 # 24107 # Linear Algebra Operators 24108 # 24109 PythonRefInfo( 24110 "_refs.addr", 24111 torch_opinfo_name="addr", 24112 decorators=( 24113 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), 24114 ), 24115 ), 24116 PythonRefInfo( 24117 "_refs.trace", 24118 torch_opinfo_name="trace", 24119 ), 24120 PythonRefInfo( 24121 "_refs.norm", 24122 torch_opinfo_name="norm", 24123 supports_out=True, 24124 # Uses vector_norm inside and vector_norm is affected by 24125 # https://github.com/pytorch/pytorch/issues/77216 24126 validate_view_consistency=False, 24127 ), 24128 # 24129 # Tensor Creation Reference OpInfos 24130 # 24131 PythonRefInfo( 24132 "_refs.empty", 24133 torch_opinfo_name="empty", 24134 skips=( 24135 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24136 'TestCommon', 24137 'test_python_ref'), 24138 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24139 'TestCommon', 24140 'test_python_ref_torch_fallback'), 24141 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24142 'TestCommon', 24143 'test_out'), 24144 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24145 'TestCommon', 24146 'test_out_warning'), 24147 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24148 'TestMathBits', 24149 'test_conj_view'), 24150 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24151 'TestMathBits', 24152 'test_neg_conj_view'), 24153 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24154 'TestMathBits', 24155 'test_neg_view'), 24156 # FIXME: shouldn't check empty results 24157 DecorateInfo(unittest.skip("Can't check result for 
empty"), 'TestCommon', 'test_python_ref_executor'), 24158 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 24159 ), 24160 ), 24161 PythonRefInfo( 24162 "_refs.empty_like", 24163 torch_opinfo_name="empty_like", 24164 skips=( 24165 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24166 'TestCommon', 24167 'test_python_ref'), 24168 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24169 'TestCommon', 24170 'test_python_ref_torch_fallback'), 24171 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24172 'TestCommon', 24173 'test_out'), 24174 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24175 'TestCommon', 24176 'test_out_warning'), 24177 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24178 'TestMathBits', 24179 'test_conj_view'), 24180 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24181 'TestMathBits', 24182 'test_neg_conj_view'), 24183 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24184 'TestMathBits', 24185 'test_neg_view'), 24186 # FIXME: should not compare results of empty_like 24187 DecorateInfo(unittest.skip("Can't check result for empty_like"), 'TestCommon', 'test_python_ref_executor'), 24188 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 24189 ), 24190 ), 24191 PythonRefInfo( 24192 "_refs.randn", 24193 torch_opinfo_name="randn", 24194 op=lambda *args, **kwargs: wrapper_set_seed(refs.randn, *args, **kwargs), 24195 skips=( 24196 # see https://github.com/pytorch/pytorch/issues/85121 24197 DecorateInfo(unittest.skip("make_traced() doesn't set seed properly!"), 24198 'TestCommon', 24199 'test_python_ref_executor'), 24200 # These tests expect the input to be a tensor or a sequence of tensors 24201 DecorateInfo(unittest.skip("Test expects tensor input"), "TestCommon", "test_noncontiguous_samples"), 24202 DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_view'), 24203 DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_conj_view'), 24204 DecorateInfo(unittest.skip("Test expects tensor input"), 'TestMathBits', 'test_neg_conj_view'), 24205 ), 24206 ), 24207 PythonRefInfo( 24208 "_refs.eye", 24209 torch_opinfo_name="eye", 24210 skips=( 24211 # skip these tests since we have non tensor input 24212 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), 24213 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), 24214 DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), 24215 ), 24216 ), 24217 PythonRefInfo( 24218 "_refs.new_empty", 24219 torch_opinfo_name="new_empty", 24220 skips=( 24221 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24222 'TestCommon', 24223 'test_python_ref'), 24224 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24225 'TestCommon', 24226 'test_python_ref_torch_fallback'), 24227 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24228 'TestCommon', 24229 'test_out'), 24230 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24231 'TestCommon', 24232 'test_out_warning'), 24233 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24234 'TestMathBits', 24235 'test_conj_view'), 24236 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24237 'TestMathBits', 24238 'test_neg_conj_view'), 24239 DecorateInfo(unittest.skip("Expected: empty is not comparable"), 24240 
'TestMathBits', 24241 'test_neg_view'), 24242 # FIXME: should not compare results of empty_like 24243 DecorateInfo(unittest.skip("Can't check result for new_empty"), 'TestCommon', 'test_python_ref_executor'), 24244 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 24245 ), 24246 ), 24247 PythonRefInfo( 24248 "_refs.new_empty_strided", 24249 torch_opinfo_name="new_empty_strided", 24250 skips=( 24251 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24252 'TestCommon', 24253 'test_python_ref'), 24254 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24255 'TestCommon', 24256 'test_python_ref_torch_fallback'), 24257 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24258 'TestMathBits', 24259 'test_conj_view'), 24260 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24261 'TestMathBits', 24262 'test_neg_conj_view'), 24263 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24264 'TestMathBits', 24265 'test_neg_view'), 24266 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24267 'TestCommon', 24268 'test_python_ref_executor'), 24269 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 24270 24271 ), 24272 ), 24273 PythonRefInfo( 24274 "_refs.empty_strided", 24275 torch_opinfo_name="empty_strided", 24276 skips=( 24277 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24278 'TestCommon', 24279 'test_python_ref'), 24280 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24281 'TestCommon', 24282 'test_python_ref_torch_fallback'), 24283 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24284 'TestMathBits', 24285 'test_conj_view'), 24286 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24287 'TestMathBits', 24288 'test_neg_conj_view'), 24289 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24290 'TestMathBits', 24291 'test_neg_view'), 24292 DecorateInfo(unittest.skip("Expected: empty_strided is not comparable"), 24293 'TestCommon', 24294 'test_python_ref_executor'), 24295 DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), 24296 ), 24297 ), 24298 PythonRefInfo( 24299 "_refs.new_full", 24300 torch_opinfo_name="new_full", 24301 ), 24302 PythonRefInfo( 24303 "_refs.new_ones", 24304 torch_opinfo_name="new_ones", 24305 ), 24306 PythonRefInfo( 24307 "_refs.new_zeros", 24308 torch_opinfo_name="new_zeros", 24309 ), 24310 # 24311 # Conditional Reference OpInfos 24312 # 24313 PythonRefInfo( 24314 "_refs.masked_fill", 24315 torch_opinfo_name="masked_fill", 24316 skips=( 24317 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), 24318 ), 24319 ), 24320 PythonRefInfo( 24321 "_refs.where", 24322 torch_opinfo_name="where", 24323 op=lambda self, condition, other: refs.where(condition, self, other), 24324 supports_out=False, 24325 skips=( 24326 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'), 24327 ), 24328 ), 24329 PythonRefInfo( 24330 "_refs.index_select", 24331 torch_opinfo_name="index_select", 24332 # empty_strided 24333 skips=( 24334 # no _refs support for Tensor.__setitem__ 24335 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), 24336 # Sample out= with a stride of zero. 
This _out operation checks that the input has no
            # inner overlap
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),)
    ),
    PythonRefInfo(
        "_refs.index_copy",
        torch_opinfo_name="index_copy",
        # empty_strided
        skips=(
            # no _refs support for Tensor.__setitem__
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
        ),
    ),
    PythonRefInfo(
        "_refs.index_add",
        torch_opinfo_name="index_add",
        # empty_strided
        skips=(
            # no _refs support for Tensor.__setitem__
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),
        ),
    ),
    PythonRefInfo(
        "_refs.index_fill",
        torch_opinfo_name="index_fill",
        # empty_strided
        skips=(
            # no _refs support for Tensor.__setitem__
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),)
    ),
    #
    # Test-related functions
    #
    PythonRefInfo(
        "_refs.allclose",
        torch_opinfo_name="allclose",
    ),
    #
    # Misc functions
    #
    PythonRefInfo(
        "_refs.stft",
        torch_opinfo_name="stft",
        skips=[
            # RuntimeError: no _refs support for aten.pad
            DecorateInfo(
                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
            ),
        ],
    ),
    PythonRefInfo(
        "_refs.istft",
        torch_opinfo_name="istft",
        skips=[
            # RuntimeError: no _refs support for aten.unfold_backward
            DecorateInfo(
                unittest.expectedFailure, 'TestCommon', 'test_python_ref'
            ),
            DecorateInfo(
                unittest.skip("Expected: unfold_backward() got an unexpected keyword argument 'input_sizes'"),
                'TestCommon',
                'test_python_ref_executor',
                dtypes=(torch.complex64, torch.complex128),
            ),
        ],
    ),
    PythonRefInfo(
        "_refs.view_as_complex",
        torch_opinfo_name="view_as_complex",
    ),
]
python_ref_db += opinfo.definitions.python_ref_db

# Common operator groupings
ops_and_refs = op_db + python_ref_db
unary_ufuncs = [op for op in ops_and_refs if isinstance(op, UnaryUfuncInfo)]
binary_ufuncs = [op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo)]
binary_ufuncs_and_refs = tuple(op for op in ops_and_refs if isinstance(op, BinaryUfuncInfo))
spectral_funcs = [op for op in ops_and_refs if isinstance(op, SpectralFuncInfo)]
sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse]
sparse_csr_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse_csr]
sparse_reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo) and op.supports_sparse]
shape_funcs = [op for op in ops_and_refs if isinstance(op, ShapeFuncInfo)]
reduction_ops = [op for op in ops_and_refs if isinstance(op, ReductionOpInfo)]
reference_filtered_ops = [op for op in reduction_ops if op.ref is not None]
reference_masked_ops = [op for op in reference_filtered_ops if op.name.startswith('masked.')]
sparse_masked_reduction_ops = [op for op in sparse_reduction_ops if op.name.startswith('masked.')]

# TODO: review porting these to make_tensor
def index_variable(shape, max_indices, device=torch.device('cpu')):
    if not isinstance(shape, tuple):
        shape = (shape,)
    index = torch.rand(*shape, dtype=torch.double, device=device).mul_(max_indices).floor_().long()
    return index

def gather_variable(shape, index_dim, max_indices, duplicate=False, device=torch.device('cpu')):
    # Builds a 2D index tensor for gather/scatter tests: each slice along index_dim
    # holds distinct indices drawn from [0, max_indices).
    assert len(shape) == 2
    assert index_dim < 2
    batch_dim = 1 - index_dim
    index = torch.zeros(*shape, dtype=torch.long, device=device)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices, device=device)[:shape[batch_dim]])
    if duplicate:
        # Make two positions along batch_dim identical so repeated indices are exercised.
        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
    return index

def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.bool).bernoulli_()

def mask_not_all_zeros(shape):
    # Rejection-sample a boolean mask until at least one element is True.
    assert len(shape) > 0
    while True:
        result = torch.randn(shape).gt(0)
        if result.sum() > 0:
            return result

# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, True)


def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, False)


# Appends an expectedFailure/skip DecorateInfo to every OpInfo in op_db that matches
# (op_name, variant_name); the returned decorator leaves the test itself unchanged.
def skipOps(test_case_name, base_test_name, to_skip):
    all_opinfos = op_db
    for xfail in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = xfail
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
        for op in matching_opinfos:
            decorators = list(op.decorators)
            if expected_failure:
                decorator = DecorateInfo(unittest.expectedFailure,
                                         test_case_name, base_test_name,
                                         device_type=device_type, dtypes=dtypes)
                decorators.append(decorator)
            else:
                decorator = DecorateInfo(unittest.skip("Skipped!"),
                                         test_case_name, base_test_name,
                                         device_type=device_type, dtypes=dtypes)
                decorators.append(decorator)
            op.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
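
# Illustrative usage sketch for xfail / skip / skipOps, kept as comments so module
# import behavior is unchanged. This is an assumption about how functorch-style
# OpInfo test suites typically consume these helpers, not part of this module's API:
# the class `TestFooOperators`, its `test_bar` method, and the op/dtype selections
# below are hypothetical, while `ops` and `instantiate_device_type_tests` come from
# torch.testing._internal.common_device_type.
#
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests, ops)
#
#     class TestFooOperators(TestCase):
#         @ops(unary_ufuncs, allowed_dtypes=(torch.float32,))
#         @skipOps('TestFooOperators', 'test_bar', (
#             # expect this op to fail on every device/dtype of the test
#             xfail('nn.functional.tanhshrink'),
#             # unconditionally skip a specific variant for bfloat16
#             skip('div', 'trunc_rounding', dtypes=(torch.bfloat16,)),
#         ))
#         def test_bar(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)
#
#     instantiate_device_type_tests(TestFooOperators, globals())
#
# Because skipOps mutates the matching OpInfo.decorators in place and returns the
# test function unchanged, the xfail/skip markers take effect when @ops expands the
# parametrized test cases.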