# mypy: ignore-errors

import unittest
from collections.abc import Sequence
from functools import partial
from typing import List

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
    complex_types,
    floating_and_complex_types_and,
    floating_types_and,
    integral_types,
)
from torch.testing._internal.opinfo.core import (
    DecorateInfo,
    gradcheck_wrapper_masked_operation,
    gradcheck_wrapper_masked_pointwise_operation,
    M,
    OpInfo,
    ReductionOpInfo,
    S,
    sample_inputs_reduction,
    SampleInput,
)
from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy


# Used for log_softmax, softmax, softmin
def sample_inputs_softmax_variant(
    op_info,
    device,
    dtype,
    requires_grad,
    with_dtype=False,
    use_zero_dimensions=True,
    **kwargs,
):
    make_arg = partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    cases = [
        ((S,), (0,)),
        ((S, S), (0,)),
        ((S, S), (1,)),
        ((S, S), (-1,)),
        ((S, M, S), (2,)),
        *([((S, 0, 0), (-1,))] if use_zero_dimensions else []),
    ]
    kwargs = dict(dtype=torch.float64) if with_dtype else None

    # PyTorch on XLA throws an error when a dim argument is passed for a 0d tensor.
    # See https://github.com/pytorch/xla/issues/3061 for more details.
    if torch.device(device).type != "xla":
        cases.append(((), (0,)))

    return (
        SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
    )


def _generate_masked_op_mask(input_shape, device, **kwargs):
    make_arg = partial(
        make_tensor, dtype=torch.bool, device=device, requires_grad=False
    )
    yield None
    yield make_arg(input_shape)
    if len(input_shape) > 2:
        # broadcast last mask dimension:
        yield make_arg(input_shape[:-1] + (1,))
        # broadcast middle mask dimension:
        yield make_arg(input_shape[:1] + (1,) + input_shape[2:])
        # broadcast first mask dimension:
        yield make_arg((1,) + input_shape[1:])
        # mask.ndim < input.ndim
        yield make_arg(input_shape[1:])
        # mask.ndim == 1
        yield make_arg(input_shape[-1:])
        # masks that require broadcasting of inputs (mask.ndim > input.ndim)
        # are not supported; however, we may reconsider this if there is
        # demand for such degenerate cases.
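

# For reference, for a 3-D input of shape (2, 3, 4) the generator above yields,
# in order: None, a full (2, 3, 4) mask, masks broadcasting over the last,
# middle, and first dimensions with shapes (2, 3, 1), (2, 1, 4), and (1, 3, 4),
# a lower-dimensional (3, 4) mask, and a 1-D (4,) mask; for 1-D and 2-D inputs
# only None and the full-shape mask are produced.

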
def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked reduction operators.

    A masked reduction operator is a reduction operator with a trailing
    optional mask argument. A mask is a bool tensor with the same shape
    as the input or a shape that is broadcastable to the input shape.
    """
    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims

    for sample_input in sample_inputs_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            sample_input_args, sample_input_kwargs = sample_input.args, dict(
                mask=mask, **sample_input.kwargs
            )
            yield SampleInput(
                sample_input.input.detach().requires_grad_(requires_grad),
                args=sample_input_args,
                kwargs=sample_input_kwargs,
            )
            if (
                not requires_grad
                and dtype.is_floating_point
                and sample_input.input.ndim == 2
                and mask is not None
                and mask.shape == sample_input.input.shape
            ):
                for v in [torch.inf, -torch.inf, torch.nan]:
                    t = sample_input.input.detach()
                    t.diagonal(0, -2, -1).fill_(v)
                    yield SampleInput(
                        t.requires_grad_(requires_grad),
                        args=sample_input_args,
                        kwargs=sample_input_kwargs,
                    )


def sample_inputs_sparse_coo_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Sample inputs for masked reduction operators that support inputs
    with sparse COO layouts.
    """
    if op_info.supports_sparse:
        op_name = op_info.name.replace("masked.", "")
        for sample_input in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            mask = sample_input.kwargs.get("mask")
            if mask is not None:
                sample_input_kwargs = sample_input.kwargs.copy()
                sample_input_kwargs.update(mask=mask.to_sparse())
                yield SampleInput(
                    sample_input.input.to_sparse(),
                    args=sample_input.args,
                    kwargs=sample_input_kwargs,
                )
            else:
                if op_name in {"prod", "amax", "amin"}:
                    # FIXME: for now, reductions with non-zero reduction identity
                    # and an unspecified mask are not supported for sparse COO
                    # tensors; see the torch.masked.prod implementation for
                    # details.
                    continue
                yield SampleInput(
                    sample_input.input.to_sparse(),
                    args=sample_input.args,
                    kwargs=sample_input.kwargs,
                )


def sample_inputs_sparse_csr_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Sample inputs for masked reduction operators that support inputs
    with sparse CSR layouts.
    """
    if op_info.supports_sparse_csr:
        op_name = op_info.name.replace("masked.", "")
        for sample_input in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            if not (
                sample_input.input.ndim == 2 and sample_input.kwargs.get("keepdim")
            ):
                # - sparse CSR tensors are always 2-D tensors
                # - masked reduction on CSR tensors is defined only if keepdim is True.
                continue
            mask = sample_input.kwargs.get("mask")
            if mask is not None:
                sample_input_kwargs = sample_input.kwargs.copy()
                sample_input_kwargs.update(mask=mask.to_sparse_csr())
                new_sample = SampleInput(
                    sample_input.input.to_sparse_csr(),
                    args=sample_input.args,
                    kwargs=sample_input_kwargs,
                )
            else:
                if op_name in ["prod", "amax", "amin", "mean"]:
                    # reductions with non-zero reduction identity and an
                    # unspecified mask are not supported for sparse CSR
                    # tensors; see the torch.masked.prod implementation for
                    # details.
                    continue
                new_sample = SampleInput(
                    sample_input.input.to_sparse_csr(),
                    args=sample_input.args,
                    kwargs=sample_input.kwargs,
                )
            yield new_sample
            if sample_input.kwargs["dim"] == 0:
                # Reductions of CSR tensors use different implementations for
                # inner and/or outer dimensions. So, at a minimum, testing the
                # CSR implementations requires generating the following kwargs:
                #   dict(dim=0, keepdim=True)
                #   dict(dim=1, keepdim=True)
                #   dict(dim=(0, 1), keepdim=True)
                # Here we generate the dim=1 case from the dim=0 case.
                sample_input_kwargs = new_sample.kwargs.copy()
                sample_input_kwargs.update(dim=1)
                yield SampleInput(
                    new_sample.input.clone(),
                    args=sample_input.args,
                    kwargs=sample_input_kwargs,
                )


def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked norm."""
    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
        for sample_input in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            sample_input_args, sample_input_kwargs = (
                ord,
            ) + sample_input.args, sample_input.kwargs.copy()
            yield SampleInput(
                sample_input.input.clone().requires_grad_(requires_grad),
                args=sample_input_args,
                kwargs=sample_input_kwargs,
            )


def reference_masked_std_var(
    numpy_fn,
):
    ref = reference_reduction_numpy(numpy_fn)

    # Translate unbiased or correction arguments into ddof
    def func(
        input,
        dim=None,
        unbiased=None,
        *,
        correction=None,
        **kwargs,
    ):
        ddof = 1
        if unbiased is not None:
            ddof = 1 if unbiased else 0
        if correction is not None:
            ddof = correction

        if isinstance(dim, Sequence):
            dim = tuple(dim)

        return ref(input, dim, ddof=ddof, **kwargs)

    return func
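

# For example, the wrapper above maps unbiased=False to ddof=0, unbiased=True
# to ddof=1, and an explicit correction=c to ddof=c (correction wins if both
# are given); when neither is specified, ddof defaults to 1, i.e. the unbiased
# estimator.

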
def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked std/var."""
    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims
    from torch.testing._internal.common_methods_invocations import sample_inputs_std_var

    def masked_samples():
        for sample_input in sample_inputs_std_var(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            if len(sample_input.args) and isinstance(sample_input.args[0], bool):
                continue  # masked.{std, var} doesn't support `.var(unbiased)`

            for mask in _generate_masked_op_mask(
                sample_input.input.shape, device, **kwargs
            ):
                sample_input_args, sample_input_kwargs = sample_input.args, dict(
                    mask=mask, **sample_input.kwargs
                )
                yield SampleInput(
                    sample_input.input.detach().requires_grad_(requires_grad),
                    args=sample_input_args,
                    kwargs=sample_input_kwargs,
                )
                if (
                    not requires_grad
                    and dtype.is_floating_point
                    and sample_input.input.ndim == 2
                    and mask is not None
                    and mask.shape == sample_input.input.shape
                ):
                    for v in [torch.inf, -torch.inf, torch.nan]:
                        t = sample_input.input.detach()
                        t.diagonal(0, -2, -1).fill_(v)
                        yield SampleInput(
                            t.requires_grad_(requires_grad),
                            args=sample_input_args,
                            kwargs=sample_input_kwargs,
                        )

    for sample_input in masked_samples():
        correction = sample_input.kwargs.get("correction")
        if correction is None:
            correction = int(sample_input.kwargs.get("unbiased", True))

        dim = sample_input.kwargs.get("dim", None)

        if sample_input.kwargs.get("mask") is None:
            orig_count = torch.masked.sum(
                torch.ones(sample_input.input.shape, dtype=torch.int64),
                dim,
                keepdim=True,
            )
        else:
            inmask = torch.masked._input_mask(
                sample_input.input, *sample_input.args, **sample_input.kwargs
            )
            orig_count = torch.masked.sum(
                inmask.new_ones(sample_input.input.shape, dtype=torch.int64),
                dim,
                keepdim=True,
                mask=inmask,
            )
        if orig_count.min() <= correction + 1:
            # Skip samples that lead to nans in var computation
            continue

        yield sample_input


def sample_inputs_masked_softmax(
    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
):
    """Sample inputs for masked softmax, log_softmax, and softmin.

    A masked normalization operator is a reduction operator with a trailing
    optional mask argument. A mask is a bool tensor with the same shape as
    the input or a shape that is broadcastable to the input shape.
    """
    for sample_input in sample_inputs_softmax_variant(
        op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            yield SampleInput(
                sample_input.input.clone().requires_grad_(requires_grad),
                *sample_input.args,
                mask=mask,
                **sample_input.kwargs,
            )


def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked cumsum and cumprod."""
    inputs: List[SampleInput] = []
    for sample_input in sample_inputs_softmax_variant(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            if type(mask) != torch.Tensor:
                continue
            sample_input_args, sample_input_kwargs = sample_input.args, dict(
                mask=mask, **sample_input.kwargs
            )
            if "keepdim" in sample_input_kwargs:
                sample_input_kwargs.pop("keepdim")
            # dimension is required
            if sample_input_args:
                dim = sample_input.args[0]
            else:
                if "dim" not in sample_input_kwargs:
                    continue
                dim = sample_input_kwargs.pop("dim")
                sample_input_args = (dim,)
            yield SampleInput(
                sample_input.input.clone().requires_grad_(requires_grad),
                *sample_input_args,
                **sample_input_kwargs,
            )


def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked logaddexp."""
    shapes = [(S,), (S, S), (S, M, S)]
    input_mask_lists = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]
    other_mask_lists = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]

    make_arg = partial(
        make_tensor, dtype=dtype, device=device, requires_grad=requires_grad
    )
    for shape, input_masks, other_masks in zip(
        shapes, input_mask_lists, other_mask_lists
    ):
        for input_mask, other_mask in zip(input_masks, other_masks):
            yield SampleInput(
                make_arg(shape),
                make_arg(shape),
                input_mask=input_mask,
                other_mask=other_mask,
            )
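

# Note: masked pointwise ops such as masked.logaddexp take one mask per operand,
# which is why the sampler above passes separate input_mask= and other_mask=
# keyword arguments rather than the single mask= keyword used by the reduction
# and normalization samplers.

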
def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
    """Sample inputs for masked normalize."""
    for ord in [2.0, 1, float("inf"), float("-inf"), 0]:
        for sample_input in sample_inputs_softmax_variant(
            op_info, device, dtype, requires_grad, use_zero_dimensions=False, **kwargs
        ):
            yield SampleInput(
                sample_input.input.clone().requires_grad_(requires_grad),
                ord,
                *sample_input.args,
                **sample_input.kwargs,
            )


op_db: List[OpInfo] = [
    ReductionOpInfo(
        "masked.sum",
        ref=reference_reduction_numpy(np.sum),
        method_variant=None,
        identity=0,
        nan_policy="propagate",
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_sparse=True,
        supports_sparse_csr=True,
        promotes_int_to_int64=True,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        skips=(
            DecorateInfo(
                unittest.skip("Failing on some jobs"),
                "TestReductions",
                "test_reference_masked",
                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: undefined value tensor
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=1e-03, rtol=5e-2),
                        torch.float16: tol(atol=1e-03, rtol=5e-3),
                    }
                ),
                "TestReductions",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-03)}),
                "TestReductions",
                "test_ref_small_input",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=0.1, rtol=0.1),
                        torch.float16: tol(atol=5e-3, rtol=5e-3),
                    }
                ),
                "TestMasked",
                "test_mask_layout",
            ),
        ],
        sample_inputs_func=sample_inputs_masked_reduction,
        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
    ),
    ReductionOpInfo(
        "masked.prod",
        ref=prod_numpy,
        method_variant=None,
        identity=1,
        nan_policy="propagate",
        # https://github.com/pytorch/pytorch/issues/80411
        gradcheck_fast_mode=True,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_sparse=True,
        supports_sparse_csr=True,
        promotes_int_to_int64=True,
        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                unittest.skip("Failing on some jobs"),
                "TestReductions",
                "test_reference_masked",
                dtypes=(torch.bool, torch.int8, torch.int16, torch.int32),
            ),
            DecorateInfo(
                "TestReductions",
                "test_ref_small_input",
                dtypes=(torch.int8, torch.int16, torch.int32),
            ),
            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestMasked",
                "test_mask_layout",
                device_type="cuda",
                dtypes=(torch.bool, *integral_types(), *complex_types()),
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-02)}),
                "TestReductions",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
                "TestReductions",
                "test_ref_duplicate_values",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
                "TestReductions",
                "test_ref_small_input",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
                "TestMasked",
                "test_mask_layout",
                device_type="cpu",
            ),
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
                "TestOperators",
                "test_jvp",
                device_type="cuda",
            ),
        ],
        sample_inputs_func=sample_inputs_masked_reduction,
        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
    ),
    OpInfo(
        "masked.cumsum",
        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
        method_variant=None,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
        ),
        # Can reuse the same inputs; dim is required in both
        sample_inputs_func=sample_inputs_masked_cumops,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    OpInfo(
        "masked.cumprod",
        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
        method_variant=None,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}),
                "TestCompositeCompliance",
                "test_backward",
                device_type="cuda",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
                "TestInductorOpInfo",
                "test_comprehensive",
                device_type="cuda",
            ),
        ),
        # Can reuse the same inputs; dim is required in both
        sample_inputs_func=sample_inputs_masked_cumops,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.amax",
        nan_policy="propagate",
        supports_out=False,
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        supports_sparse=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_sparse_csr=True,
        ref=reference_reduction_numpy(np.amax),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: amax reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: Unknown builtin op: aten::iinfo
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestMasked",
                "test_mask_layout",
                dtypes=(torch.bool, *integral_types(), *complex_types()),
            ),
        ),
        sample_inputs_func=sample_inputs_masked_reduction,
        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.amin",
        nan_policy="propagate",
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        supports_sparse=True,
        supports_sparse_csr=True,
        ref=reference_reduction_numpy(np.amin),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: amin reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: Unknown builtin op: aten::iinfo
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # FIXME: "cuda_scatter_gather_base_kernel_func" not implemented for ... (used for sparse_coo inputs)
            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestMasked",
                "test_mask_layout",
                dtypes=(torch.bool, *integral_types(), *complex_types()),
            ),
        ),
        sample_inputs_func=sample_inputs_masked_reduction,
        sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.argmax",
        supports_out=False,
        supports_multiple_dims=False,
        supports_autograd=False,
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # initial is not a keyword for argmax
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_reference_masked"
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        sample_inputs_func=sample_inputs_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.argmin",
        supports_out=False,
        supports_multiple_dims=False,
        supports_autograd=False,
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # initial is not a keyword for argmin
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_reference_masked"
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        sample_inputs_func=sample_inputs_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.mean",
        ref=reference_reduction_numpy(np.mean)
        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
        else None,
        method_variant=None,
        nan_policy="propagate",
        supports_out=False,
        supports_sparse_csr=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        promotes_int_to_float=True,
        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestReductions",
                "test_ref_duplicate_values",
                dtypes=(torch.bool,),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestReductions",
                "test_reference_masked",
                dtypes=(torch.bool,),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestReductions",
                "test_ref_small_input",
                dtypes=(torch.bool,),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: undefined value tensor
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # FIXME: "_segment_reduce_lengths_cpu/cuda" not implemented for ... (used for sparse_csr inputs)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestMasked",
                "test_mask_layout",
                dtypes=(torch.bool, *integral_types(), *complex_types()),
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=1e-03, rtol=0.05),
                        torch.float16: tol(atol=1e-03, rtol=1e-03),
                    }
                ),
                "TestReductions",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=1e-03)}),
                "TestReductions",
                "test_ref_small_input",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-03, rtol=2e-03)}),
                "TestSparseCompressed",
                "test_consistency",
                device_type="cuda",
            ),
        ],
        sample_inputs_func=sample_inputs_masked_reduction,
        sample_inputs_sparse_csr_func=sample_inputs_sparse_csr_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    OpInfo(
        "masked.median",
        dtypes=floating_types_and(torch.bfloat16, torch.float16),
        method_variant=None,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
        ),
        sample_inputs_func=partial(
            sample_inputs_masked_softmax, use_zero_dimensions=False
        ),
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.norm",
        identity=0,
        method_variant=None,
        nan_policy="propagate",
        supports_out=False,
        promotes_int_to_float=True,
        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # torch.jit.frontend.NotSupportedError: Compiled functions
            # can't take variable number of arguments or use
            # keyword-only arguments with defaults
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_masked_norm,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
    ReductionOpInfo(
        "masked.var",
        ref=reference_masked_std_var(np.var)
        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
        else None,
        method_variant=None,
        nan_policy="propagate",
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        promotes_int_to_float=True,
        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: undefined value tensor
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float16: tol(atol=1e-02, rtol=1e-02),
                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
                    }
                ),
                "TestReductions",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                "TestReductions",
                "test_ref_small_input",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                "TestMasked",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float16: tol(atol=1e-02, rtol=1e-02),
                        torch.bfloat16: tol(atol=1e-03, rtol=1e-03),
                    }
                ),
                "TestMasked",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float16: tol(atol=4e-5, rtol=2e-2),
                    }
                ),
                "TestInductorOpInfo",
                "test_comprehensive",
                device_type="cuda",
            ),
        ],
        sample_inputs_func=sample_inputs_masked_std_var,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        check_batched_grad=True,
    ),
    ReductionOpInfo(
        "masked.std",
        ref=reference_masked_std_var(np.std)
        if np.lib.NumpyVersion(np.__version__) >= "1.20.2"
        else None,
        method_variant=None,
        nan_policy="propagate",
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        # See https://github.com/pytorch/pytorch/pull/78358
        check_batched_forward_grad=False,
        promotes_int_to_float=True,
        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
        skips=(
            # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestSchemaCheckModeOpInfo",
                "test_schema_correctness",
                dtypes=(torch.complex64, torch.complex128),
            ),
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim"
            ),
            # RuntimeError: undefined value tensor
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=1e-02, rtol=1e-02),
                        torch.float16: tol(atol=1e-02, rtol=1e-02),
                    }
                ),
                "TestReductions",
                "test_reference_masked",
            ),
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                "TestReductions",
                "test_ref_small_input",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float16: tol(atol=1e-02, rtol=1e-02),
                        torch.bfloat16: tol(atol=5e-03, rtol=5e-04),
                    }
                ),
                "TestMasked",
                "test_reference_masked",
            ),
        ],
        sample_inputs_func=sample_inputs_masked_std_var,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        check_batched_grad=True,
    ),
    OpInfo(
        "masked.softmax",
        method_variant=None,
        dtypes=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_masked_softmax,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    ),
    OpInfo(
        "masked.log_softmax",
        method_variant=None,
        dtypes=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_masked_softmax,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        decorators=[
            DecorateInfo(
                toleranceOverride({torch.bfloat16: tol(atol=1e-02, rtol=1e-02)}),
                "TestMasked",
                "test_reference_masked",
            ),
        ],
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    ),
    OpInfo(
        "masked.softmin",
        method_variant=None,
        dtypes=floating_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_masked_softmax,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
            # FIXME:
            # Mismatched elements: 2 / 2 (100.0%)
            # Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
            # Greatest relative difference: nan at index (0,) (up to 0.0001 allowed)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestOperators",
                "test_vmapvjpvjp",
                device_type="cpu",
            ),
        ),
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    ),
    OpInfo(
        "masked.normalize",
        method_variant=None,
        dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
        sample_inputs_func=sample_inputs_masked_normalize,
        decorators=[
            DecorateInfo(
                toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
                "TestInductorOpInfo",
                "test_comprehensive",
                device_type="cuda",
            ),
        ],
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            DecorateInfo(
                unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
            ),
        ),
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
        # Runs very slowly on slow gradcheck - alternatively reduce input sizes
        gradcheck_fast_mode=True,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        supports_out=False,
    ),
    OpInfo(
        "masked.logaddexp",
        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        check_batched_forward_grad=False,
        skips=(
            DecorateInfo(
                unittest.expectedFailure,
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
            DecorateInfo(
                unittest.skip("Skipped!"), "TestFwdGradients", "test_fn_gradgrad"
            ),
            DecorateInfo(
                unittest.skip("Skipped!"), "TestBwdGradients", "test_fn_gradgrad"
            ),
        ),
        sample_inputs_func=sample_inputs_masked_logaddexp,
        gradcheck_wrapper=gradcheck_wrapper_masked_pointwise_operation,
    ),
    ReductionOpInfo(
        "masked.logsumexp",
        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
        method_variant=None,
        nan_policy="propagate",
        supports_out=False,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestNormalizeOperators",
                "test_normalize_operator_exhaustive",
            ),
            # FIXME: reduces all dimensions when dim=[]
            DecorateInfo(unittest.skip("Skipped!"), "TestReductions", "test_dim_empty"),
            DecorateInfo(
                unittest.skip("Skipped!"), "TestReductions", "test_dim_empty_keepdim"
            ),
            # Identity can't be -torch.inf without overflow
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestReductions",
                "test_empty_tensor_empty_slice",
            ),
            # NotSupportedError: Compiled functions can't ... use keyword-only arguments with defaults
            DecorateInfo(
                unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit"
            ),
            # all the values are the same except for -inf vs nan
            DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
            # FIXME:
            # Mismatched elements: 2 / 12 (16.7%)
            # Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
            # Greatest relative difference: 0.0 at index (0, 0, 0, 1)
            DecorateInfo(
                unittest.skip("Skipped!"),
                "TestInductorOpInfo",
                "test_comprehensive",
                device_type="cpu",
            ),
        ),
        sample_inputs_func=sample_inputs_masked_reduction,
        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
    ),
]
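

# The entries above are consumed by the standard OpInfo-based test machinery.
# A rough sketch of how a device/dtype-parametrized test exercises them
# (illustrative only; the `ops` decorator and `OpInfo.sample_inputs` live in
# torch.testing._internal, and the test class shown here is hypothetical):
#
#     from torch.testing._internal.common_device_type import ops
#
#     class TestMaskedExample(TestCase):
#         @ops([op for op in op_db if op.name == "masked.sum"])
#         def test_example(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype):
#                 op(sample.input, *sample.args, **sample.kwargs)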