# Owner(s): ["module: decompositions"]

import functools
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch._inductor.decomposition
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(op):
    return op._schema.name.split("::")[1]


# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but leaves of the pytree have to all
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

    return result, wrapped


dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}


# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol


def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation: {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        # see https://github.com/pytorch/pytorch/pull/96264
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )


def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU, on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
        # there's an off-by-one error.
        # See
        # https://github.com/pytorch/pytorch/issues/81996
        # https://github.com/pytorch/pytorch/issues/82230
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO We should check that the integer outputs also agree
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals


# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    else:
        return x


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


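# Each exclusion below is a (device type, dtype, op name) tuple; a None in the
# device or dtype slot matches any device type or dtype (see the test_keys
# constructed in TestDecomp.do_cross_ref).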
CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
    (None, None, "item"),
    # It's the only in-place op without an out-of-place equivalent in the Python API.
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),
    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    # exp_vml_cpu not implemented for Half
    ("cpu", torch.float16, "signal.windows.exponential"),
    ("cpu", torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    ("cpu", torch.float16, "signal.windows.cosine"),
    # CompositeImplicitAutograd
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    # This decomp runs before autograd.
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
    # Decomposition registered as Autograd
    (None, None, "nn.functional.hardshrink"),
    (None, None, "nn.functional.softshrink"),
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
    (None, None, "diag"),
    # _softmax_backward_data's CPU kernel for bfloat16 always returns the grad_input as float32
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {
    # Decomposed backward formula is not as precise
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""


def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly. Maybe they should.
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
            return True
        else:
            return False

    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
    return any(test_unsupported(x) for x in flat_args)


core_backward_failures = {
    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"),  # slow: fails with --timeout=360 secs
    skip("deg2rad"),  # slow: fails with --timeout=360 secs
    skip("diag_embed"),  # slow: fails with --timeout=360 secs
    skip("frac"),  # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"),  # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"),  # slow: fails with --timeout=360 secs
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("rad2deg"),  # slow: fails with --timeout=360 secs
    skip("renorm"),  # slow: fails with --timeout=360 secs
    skip("rot90"),  # slow: fails with --timeout=360 secs
    skip("rsub"),  # slow: fails with --timeout=360 secs
    skip("sgn"),  # slow: fails with --timeout=360 secs
    skip("special.xlog1py"),  # slow: fails with --timeout=360 secs
    xfail("stack"),
    skip("tril"),  # slow: fails with --timeout=360 secs
    skip("triu"),  # slow: fails with --timeout=360 secs
    skip("unfold_copy"),  # slow: fails with --timeout=360 secs
    skip("xlogy"),  # slow: fails with --timeout=360 secs
    xfail("zero_"),
}
if not TEST_WITH_SLOW:
    core_backward_failures.update(
        {
            skip("addr"),  # slow: takes 46 sec on A100
            skip("baddbmm"),  # slow: takes 800+ sec on A100
            skip("clamp_min"),  # slow: takes 800 sec on A100
            skip("clamp_max"),  # slow: takes 800 sec on A100
            skip("logit"),  # slow: takes 44 sec on A100
            skip("nn.functional.hardswish"),  # slow: takes 60 sec on A100
            skip("std_mean"),  # slow: takes 170 sec on A100
            skip("split", variant_name="list_args"),  # slow: takes 118 sec on A100
            skip("transpose"),  # slow: takes 50 sec on A100
            skip("unbind"),  # slow: takes 70 sec on A100
            skip("unsafe_split"),  # slow: takes 49 sec on A100
        }
    )

comprehensive_failures = {
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ),  # off by one error
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
    def test_quick_core_backward(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            aten_name = op.decomp_aten_name or op.aten_name
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def test_uniform(self, device):
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    def test_broadcasting_index_copy(self, device):
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)

        xs_two = torch.ones([2, 10], device=device)
        xs_two[0] = x

        self.assertEqual(xs, xs_two)

    def test_cat_single_input(self, device):
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        inps = [inp for _ in range(10)]

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))

    def test_rrelu_with_noise(self, device):
        # rrelu_with_noise behavior depends on a) whether elements in the input
        # are <= 0, and b) whether we're in training mode.
        # Cover all cases:
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        lower = 1.0
        upper = 4.0
        training = False

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

        # Now with training=True:
        training = True

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @tf32_off()
    # only tests RNNs since we have py dispatcher decomps for them
    @modules(
        filter(
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    def test_rnn_decomp_module(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=True
            ), enable_python_dispatcher():
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
            # without this check, incorrect decomps at the python dispatcher level can still pass because
            # they're checking aten decomps at the torch_dispatch level
            self.assertEqual(decomp_out, non_decomp_out)

    def test_batch_norm_unflatten_weight_bias(self, device):
        # https://github.com/pytorch/pytorch/issues/100970
        shape = (1, 3, 2, 2)
        input = torch.randn(shape, device=device)
        weight = torch.randn((3, 1, 1, 1), device=device)
        bias = torch.randn(3, device=device)
        mean = torch.randn(3, device=device)
        var = torch.randn(3, device=device)
        res = torch._decomp.decompositions.native_batch_norm(
            input, weight, bias, mean, var, False, 1, 1e-05
        )
        self.assertEqual(shape, res[0].shape)

    def test_arange_graph(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(x, start):
            le = x.shape[-1]
            if start is None:
                a = torch.arange(le, dtype=torch.float32, device=x.device)
            else:
                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
            return a

        pattern = r", device = device\(.+\), requires_grad = False"

        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(torch.rand(10, device=device), None)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 0); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""",
        )

        fx_g = cfunc(torch.rand(10, device=device), 1)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 1); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""",
        )

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        if torch.device(device).type not in [
            "xpu",
            "cuda",
            torch._C._get_privateuse1_backend_name(),
        ]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
            return scores.masked_fill(mask, value)

        scores_t = torch.tensor([1, 2, 3, 4], device=device)
        mask_t = torch.tensor([True, True, True, True], device=device)
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""",
        )

    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
            self.test_dtype = dtype
            self.run_all = run_all

            # We check the correctness of each decomposition right after running it.
            # So, when we encounter a decomposition, we run the function normally, and
            # then run the decomposition, and ensure they're identical.
            self.called = set()
            self.decomposed = set()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.test_case.precision = self.saved_precision
            self.test_case.rel_tol = self.saved_rel_tol

            self.called.add(func)
            all_called[func] += 1

            # Stuff we shouldn't bother testing
            # (TODO: remove detach from the decomp table?)
            # N.b. Testing in-place ops would need dedicated logic
            in_place = func.name()[-1] == "_"
            ignored_ops = [
                torch.ops.aten.detach.default,
                # non-deterministic ops
                torch.ops.aten.empty.memory_format,
                torch.ops.aten.empty_like.default,
                torch.ops.aten.new_empty.default,
                torch.ops.aten.empty_strided.default,
                torch.ops.aten.new_empty_strided.default,
                torch.ops.aten.randn.default,
                torch.ops.aten.native_dropout.default,
            ]
            if (
                func not in decomposition_table
                or func in ignored_ops
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

            self.decomposed.add(func)
            all_decomposed.add(func)

            # We take 2 main strategies for verifying correctness/numerical stability of decompositions.
            # The first one is simply tolerance checking between decomp_out and pytorch_out.
            # However, for fp16/bf16 and reductions, this becomes very
            # finicky, as there are not many guarantees we can make.
            # So, for fp16/bf16, we instead compare the difference of
            # {decomp_out, pytorch_out_64} and {pytorch_out,
            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition results in more error, we raise an error.

            # We also decompose the decomposition recursively for
            # further coverage, as some paths may not be exercised directly by
            # OpInfos (sadly) but just by other ops

            decomposition = decomposition_table[func]

            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
            if self.run_all:
                # Execute recursively via DFS, to find the root of a possible error first
                with self:
                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
            else:
                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))

            # At this stage we should not be decomposing an in-place op.
            # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops,
            # because decompositions are run after functionalisation and we would not like them to
            # de-functionalise the graph, as that would break AoTAutograd.
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place.
            # If it does, real_out should be different from decomp_out, so we should catch this.
            real_out_unflat = func(*args, **kwargs)
            real_out = pytree.tree_leaves(real_out_unflat)

            assert len(real_out) == len(decomp_out)

            if do_relative_check:
                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_ref(
                        self.test_case,
                        func,
                        self.test_dtype,
                        i,
                        orig,
                        decomp,
                        ref,
                        args,
                        kwargs,
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_equal(
                        self.test_case,
                        func,
                        self.test_dtype,
                        orig,
                        decomp,
                        args,
                        kwargs,
                    )

            return real_out_unflat

    def check_decomposed(self, aten_name, mode):
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is "
                f"CompositeImplicitAutograd you should skip this test "
                f"by updating CROSS_REF_EXCLUDE_SET."
            ),
        )

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        test_keys = [
            (torch.device(device).type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.
            # The problem is not complex32 per se (which is supported by data
            # movement only ops), but that when we do backwards we expect other
            # ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()

        def run_without_python_dispatcher(mode):
            return any(
                isinstance(op, torch._ops.OpOverload)
                and op.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for op in mode.decomposed.union([func])
            )

        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                # Once https://github.com/pytorch/pytorch/pull/75965/ lands, I can
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode, enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if run_without_python_dispatcher(mode):
                        # without this check, incorrect decomps at the python dispatcher level can still pass because
                        # they're checking aten decomps at the torch_dispatch level.
                        with self.DecompCrossRefMode(
                            self, self.precision, self.rel_tol, dtype, run_all
                        ) as mode:
                            decomp_vjp_fn(cotangents)
                    if not run_all:
                        self.check_decomposed(op.aten_backward_name, mode)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                # A failure here might be because the decomposition for the op is wrong or because a
                # decomposition used by the particular op is wrong.
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    func(*args, **kwargs)

                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        func(*args, **kwargs)

                if not run_all:
                    self.check_decomposed(aten_name, mode)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())


class DecompOneOffTests(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._softmax(x, -1, False)
        res = torch._decomp.decompositions._softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_log_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)

        dtype = torch.float32
        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._log_softmax(x, -1, False)
        res = torch._decomp.decompositions._log_softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @onlyCUDA
    def test_exponential_non_inf(self, device):
        inp = torch.empty((4, 400, 256), device=device)

        with torch._dynamo.utils.preserve_rng_state():
            exp_ref = inp.exponential_()
        exp = torch._refs.exponential(inp)

        self.assertEqual(exp, exp_ref)
        self.assertFalse(exp.isinf().any())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @onlyCUDA
    def test_amp_batch_norm_backward(self):
        device = "cuda"
        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        weight = torch.randn((2,), dtype=torch.float32, device=device)
        rmean = torch.randn((2,), dtype=torch.float32, device=device)
        rvar = torch.randn((2,), dtype=torch.float32, device=device)
        mean = torch.randn((0,), dtype=torch.float32, device=device)

        ref = torch.ops.aten.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        res = torch._decomp.decompositions.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_elu_backward(self, device):
        size = (2, 4, 3, 3)
        dtype = torch.float32
        grad_out = torch.randn(size, dtype=dtype, device=device)
        out = torch.randn(size, dtype=dtype, device=device)

        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
        self.assertEqual(ref, res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_threshold_backward_dtype(self, device):
        grad = torch.randint(10, (4,), device=device)
        input_tensor = torch.randint(10, (4,), device=device)

        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
        self.assertEqual(ref.dtype, res.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_weight_norm_interface(self, device):
        g = torch.randn((3, 10, 10), device=device)
        v = torch.randn((1, 1, 10), device=device)

        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(ref[1], res[1]))

        inp = torch.rand([30, 10], device=device)
        inp2 = torch.rand([30, 1], device=device)

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCPU
    @skipIfCrossRef
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half],
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16; this is aligned with aten/src/ATen/native/transformers/attention.cpp.
        # If we add support for float16 over there, we should update this test as well.

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                attn_output = op(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )
                return attn_output

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]

        atol, rtol = dtype_precisions[dtype]

        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            eager_res = op(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=is_causal,
            )

            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )


instantiate_device_type_tests(DecompOneOffTests, globals())


class HasDecompTest(TestCase):
    def setUp(self):
        super().setUp()
        self.maxDiff = None

    @staticmethod
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_key fails for some jit-registered ops, which shouldn't be
            # relevant here anyway
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):
        def all_aten_overloads():
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
                    continue

                name = name[6:]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # This is for operators that are only registered in some CI
        # configurations, so would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        # If a decomposition isn't included in the core decompositions,
        # then it must decompose a core ATen operator.
        #
        # See NOTE [Core ATen Ops]
        #
        # If this test fails then either:
        # - Add the decomposition to torch._decomp.core_aten_decompositions,
        #   if decomposition should be used by inductor (not a core operator).
        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
        #   core ATen operators (and inductor will not use the decomposition).

        # Some decompositions are registered for CompositeImplicitAutograd
        # operators, which never appear in AOTAutograd's graph so are never used.
        useful_decomps = {
            op
            for op in decomposition_table.keys()
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))


if __name__ == "__main__":
    run_tests()