# Owner(s): ["module: decompositions"]

import functools
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch._inductor.decomposition
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# TODO: this isn't going to work with non-aten namespaces
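# Maps an OpOverload to its base aten name by splitting the schema name on "::",
# e.g. (illustrative) torch.ops.aten.add.Tensor has schema name "aten::add" -> "add".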
def overload_to_aten_name(op):
    return op._schema.name.split("::")[1]


# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but leaves of the pytree have to all
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
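#   For example (illustrative): with outputs=(y,) and inputs={"a": a, "b": b}, if only
#   `a` participates in the graph, the result is {"a": dy/da, "b": torch.zeros_like(b)}.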
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

    return result, wrapped


dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.


def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol


def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation: {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        # see https://github.com/pytorch/pytorch/pull/96264
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )


def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU, on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
        # there's an off-by-one error. See
        # https://github.com/pytorch/pytorch/issues/81996
        # https://github.com/pytorch/pytorch/issues/82230
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
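# Tensor arguments that are not differentiable are captured as constants inside
# the returned wrapper; only the differentiable leaves become positional inputs.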
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO We should check that the integer outputs also agree
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals


# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    else:
        return x


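# Convenience wrapper: builds the (wrapped fn, primals) pair directly from an
# OpInfo SampleInput (input, args, kwargs, output_process_fn_grad).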
def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
    (None, None, "item"),
    # It's the only in-place op without an out-of-place equivalent in the Python API
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),
    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    # exp_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.exponential"),
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.cosine"),
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    # This decomp runs before autograd.
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
    # Decomposition registered as Autograd
    (None, None, "nn.functional.hardshrink"),
    (None, None, "nn.functional.softshrink"),
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
    (None, None, "diag"),
    # _softmax_backward_data's CPU kernel for bfloat16 always returns the grad_input as float32
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {
    # Decomposed backward formula is not as precise
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""


def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly. Maybe they should.
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
            return True
        else:
            return False

    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
    return any(test_unsupported(x) for x in flat_args)


core_backward_failures = {
    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"),  # slow: fails with --timeout=360 secs
    skip("deg2rad"),  # slow: fails with --timeout=360 secs
    skip("diag_embed"),  # slow: fails with --timeout=360 secs
    skip("frac"),  # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"),  # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs 493*da0073e9SAndroid Build Coastguard Worker xfail("norm"), 494*da0073e9SAndroid Build Coastguard Worker xfail("norm", "fro"), 495*da0073e9SAndroid Build Coastguard Worker xfail("norm", "inf"), 496*da0073e9SAndroid Build Coastguard Worker xfail("norm", "nuc"), 497*da0073e9SAndroid Build Coastguard Worker skip("rad2deg"), # slow: fails with --timeout=360 secs 498*da0073e9SAndroid Build Coastguard Worker skip("renorm"), # slow: fails with --timeout=360 secs 499*da0073e9SAndroid Build Coastguard Worker skip("rot90"), # slow: fails with --timeout=360 secs 500*da0073e9SAndroid Build Coastguard Worker skip("rsub"), # slow: fails with --timeout=360 secs 501*da0073e9SAndroid Build Coastguard Worker skip("sgn"), # slow: fails with --timeout=360 secs 502*da0073e9SAndroid Build Coastguard Worker skip("special.xlog1py"), # slow: fails with --timeout=360 secs 503*da0073e9SAndroid Build Coastguard Worker xfail("stack"), 504*da0073e9SAndroid Build Coastguard Worker skip("tril"), # slow: fails with --timeout=360 secs 505*da0073e9SAndroid Build Coastguard Worker skip("triu"), # slow: fails with --timeout=360 secs 506*da0073e9SAndroid Build Coastguard Worker skip("unfold_copy"), # slow: fails with --timeout=360 secs 507*da0073e9SAndroid Build Coastguard Worker skip("xlogy"), # slow: fails with --timeout=360 secs 508*da0073e9SAndroid Build Coastguard Worker xfail("zero_"), 509*da0073e9SAndroid Build Coastguard Worker} 510*da0073e9SAndroid Build Coastguard Workerif not TEST_WITH_SLOW: 511*da0073e9SAndroid Build Coastguard Worker core_backward_failures.update( 512*da0073e9SAndroid Build Coastguard Worker { 513*da0073e9SAndroid Build Coastguard Worker skip("addr"), # slow: takes 46 sec on A100 514*da0073e9SAndroid Build Coastguard Worker skip("baddbmm"), # slow: takes 800+ sec on A100 515*da0073e9SAndroid Build Coastguard Worker skip("clamp_min"), # slow: takes 800 sec on A100 516*da0073e9SAndroid Build Coastguard Worker skip("clamp_max"), # slow: takes 800 sec on A100 517*da0073e9SAndroid Build Coastguard Worker skip("logit"), # slow: takes 44 sec on A100 518*da0073e9SAndroid Build Coastguard Worker skip("nn.functional.hardswish"), # slow: takes 60 sec on A100 519*da0073e9SAndroid Build Coastguard Worker skip("std_mean"), # slow: takes 170 sec on A100 520*da0073e9SAndroid Build Coastguard Worker skip("split", variant_name="list_args"), # slow: takes 118 sec on A100 521*da0073e9SAndroid Build Coastguard Worker skip("transpose"), # slow: takes 50 sec on A100 522*da0073e9SAndroid Build Coastguard Worker skip("unbind"), # slow: takes 70 sec on A100 523*da0073e9SAndroid Build Coastguard Worker skip("unsafe_split"), # slow: takes 49 sec on A100 524*da0073e9SAndroid Build Coastguard Worker } 525*da0073e9SAndroid Build Coastguard Worker ) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Workercomprehensive_failures = { 528*da0073e9SAndroid Build Coastguard Worker xfail( 529*da0073e9SAndroid Build Coastguard Worker "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,) 530*da0073e9SAndroid Build Coastguard Worker ), # off by one error 531*da0073e9SAndroid Build Coastguard Worker xfail( 532*da0073e9SAndroid Build Coastguard Worker "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,) 533*da0073e9SAndroid Build Coastguard Worker ), # off by one error 534*da0073e9SAndroid Build Coastguard Worker xfail( 535*da0073e9SAndroid Build Coastguard Worker "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,) 
    ),  # off by one error
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
    def test_quick_core_backward(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            aten_name = op.decomp_aten_name or op.aten_name
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

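    # Compares the aten.uniform decomposition against the eager kernel under the
    # same RNG seed, so both paths are expected to draw identical values.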
    def test_uniform(self, device):
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    def test_broadcasting_index_copy(self, device):
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)

        xs_two = torch.ones([2, 10], device=device)
        xs_two[0] = x

        self.assertEqual(xs, xs_two)

    def test_cat_single_input(self, device):
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        inps = [inp for _ in range(10)]

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))

    def test_rrelu_with_noise(self, device):
        # rrelu_with_noise behavior depends on a) whether elements in the input
        # are <= 0, and b) whether we're in training mode. Cover all cases:
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        lower = 1.0
        upper = 4.0
        training = False

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

        # Now with training=True:
        training = True

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @tf32_off()
    # only tests RNNs since we have py dispatcher decomps for them
    @modules(
        filter(
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
) 675*da0073e9SAndroid Build Coastguard Worker def test_rnn_decomp_module(self, device, dtype, module_info, training): 676*da0073e9SAndroid Build Coastguard Worker module_cls = module_info.module_cls 677*da0073e9SAndroid Build Coastguard Worker module_inputs = module_info.module_inputs_func( 678*da0073e9SAndroid Build Coastguard Worker module_info, 679*da0073e9SAndroid Build Coastguard Worker device=device, 680*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 681*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 682*da0073e9SAndroid Build Coastguard Worker training=training, 683*da0073e9SAndroid Build Coastguard Worker ) 684*da0073e9SAndroid Build Coastguard Worker for module_input in module_inputs: 685*da0073e9SAndroid Build Coastguard Worker if module_input.forward_input is None: 686*da0073e9SAndroid Build Coastguard Worker continue 687*da0073e9SAndroid Build Coastguard Worker args, kwargs = ( 688*da0073e9SAndroid Build Coastguard Worker module_input.constructor_input.args, 689*da0073e9SAndroid Build Coastguard Worker module_input.constructor_input.kwargs, 690*da0073e9SAndroid Build Coastguard Worker ) 691*da0073e9SAndroid Build Coastguard Worker m = module_cls(*args, **kwargs) 692*da0073e9SAndroid Build Coastguard Worker m.to(device).to(dtype) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker args, kwargs = ( 695*da0073e9SAndroid Build Coastguard Worker module_input.forward_input.args, 696*da0073e9SAndroid Build Coastguard Worker module_input.forward_input.kwargs, 697*da0073e9SAndroid Build Coastguard Worker ) 698*da0073e9SAndroid Build Coastguard Worker with self.DecompCrossRefMode( 699*da0073e9SAndroid Build Coastguard Worker self, self.precision, self.rel_tol, dtype, run_all=True 700*da0073e9SAndroid Build Coastguard Worker ), enable_python_dispatcher(): 701*da0073e9SAndroid Build Coastguard Worker decomp_out = m(*args, **kwargs) 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker non_decomp_out = m(*args, **kwargs) 704*da0073e9SAndroid Build Coastguard Worker # without this check, incorrect decomps at the python dispatcher level can still pass because 705*da0073e9SAndroid Build Coastguard Worker # they're checking aten decomps at the torch_dispatch level 706*da0073e9SAndroid Build Coastguard Worker self.assertEqual(decomp_out, non_decomp_out) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker def test_batch_norm_unflatten_weight_bias(self, device): 709*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/100970 710*da0073e9SAndroid Build Coastguard Worker shape = (1, 3, 2, 2) 711*da0073e9SAndroid Build Coastguard Worker input = torch.randn(shape, device=device) 712*da0073e9SAndroid Build Coastguard Worker weight = torch.randn((3, 1, 1, 1), device=device) 713*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(3, device=device) 714*da0073e9SAndroid Build Coastguard Worker mean = torch.randn(3, device=device) 715*da0073e9SAndroid Build Coastguard Worker var = torch.randn(3, device=device) 716*da0073e9SAndroid Build Coastguard Worker res = torch._decomp.decompositions.native_batch_norm( 717*da0073e9SAndroid Build Coastguard Worker input, weight, bias, mean, var, False, 1, 1e-05 718*da0073e9SAndroid Build Coastguard Worker ) 719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, res[0].shape) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker def 
    def test_arange_graph(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(x, start):
            le = x.shape[-1]
            if start is None:
                a = torch.arange(le, dtype=torch.float32, device=x.device)
            else:
                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
            return a

        pattern = r", device = device\(.+\), requires_grad = False"

        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(torch.rand(10, device=device), None)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
    add = torch.ops.prims.add.default(mul, 0);  mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
    return convert_element_type""",
        )

        fx_g = cfunc(torch.rand(10, device=device), 1)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
    add = torch.ops.prims.add.default(mul, 1);  mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
    return convert_element_type""",
        )

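    # Checks that masked_fill with a 0-dim value tensor decomposes to a single
    # prims.where call on the supported device backends.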

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        if torch.device(device).type not in [
            "xpu",
            "cuda",
            torch._C._get_privateuse1_backend_name(),
        ]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
            return scores.masked_fill(mask, value)

        scores_t = torch.tensor([1, 2, 3, 4], device=device)
        mask_t = torch.tensor([True, True, True, True], device=device)
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""",
        )
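
    # DecompCrossRefMode below relies on the TorchDispatchMode protocol: while the
    # mode is active, every aten-level call is routed through __torch_dispatch__,
    # which lets the test run both the real op and its decomposition. A minimal
    # sketch of that protocol (illustrative, not used by the tests):
    #
    #   class TraceAll(TorchDispatchMode):
    #       def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    #           kwargs = kwargs or {}
    #           print(func)  # e.g. aten.mul.Tensor
    #           return func(*args, **kwargs)
    #
    #   with TraceAll():
    #       torch.ones(2) * 2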

    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
            self.test_dtype = dtype
            self.run_all = run_all

            # We check the correctness of each decomposition right after running it.
            # So, when we encounter a decomposition, we run the function normally, and
            # then run the decomposition, and ensure they're identical.
            self.called = set()
            self.decomposed = set()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.test_case.precision = self.saved_precision
            self.test_case.rel_tol = self.saved_rel_tol

            self.called.add(func)
            all_called[func] += 1

            # Stuff we shouldn't bother testing
            # (TODO: remove detach from the decomp table?)
            # N.b. Testing in-place ops would need dedicated logic
            in_place = func.name()[-1] == "_"
            ignored_ops = [
                torch.ops.aten.detach.default,
                # non-deterministic ops
                torch.ops.aten.empty.memory_format,
                torch.ops.aten.empty_like.default,
                torch.ops.aten.new_empty.default,
                torch.ops.aten.empty_strided.default,
                torch.ops.aten.new_empty_strided.default,
                torch.ops.aten.randn.default,
                torch.ops.aten.native_dropout.default,
            ]
            if (
                func not in decomposition_table
                or func in ignored_ops
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

            self.decomposed.add(func)
            all_decomposed.add(func)

            # We take 2 main strategies for verifying correctness/numerical stability
            # of decompositions.
            # The first one is simply tolerance checking between decomp_out and pytorch_out.
            # However, for fp16/bf16 and reductions, this becomes very
            # finicky, as there are not many guarantees we can make.
            # So, for fp16/bf16, we instead compare the difference of
            # {decomp_out, pytorch_out_64} and {pytorch_out,
            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition results in more error, we error.

            # We also decompose the decomposition recursively for
            # further coverage, as some paths may not be exercised directly by
            # OpInfos (sadly), but only by other ops.
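
            # The relative check used for fp16/bf16 boils down to a comparison of
            # roughly this shape (a simplified sketch; the real tolerances and
            # special cases live in op_assert_ref):
            #
            #   ref64      = func(*upcast_to_fp64(args))      # "ground truth"
            #   decomp_err = (decomp_out - ref64).abs()
            #   eager_err  = (real_out - ref64).abs()
            #   check: decomp_err <= k * eager_err + atol     # k / atol are per-op fudge factors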

            decomposition = decomposition_table[func]

            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
            if self.run_all:
                # Execute recursively via DFS, to find the root of a possible error first
                with self:
                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
            else:
                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
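
            # Re-entering the mode via `with self:` means every aten call made *inside*
            # the decomposition is dispatched back through __torch_dispatch__, so nested
            # decompositions are cross-checked too and the innermost failing
            # decomposition is reported first.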

            # At this stage we should not be decomposing an in-place op.
            # We'd like decompositions to turn out-of-place ops into out-of-place ops,
            # because decompositions are run after functionalisation and we would not
            # like them to de-functionalise the graph, as that would break AoTAutograd.
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place. If it does,
            # real_out should be different from decomp_out, so we should catch this.
            real_out_unflat = func(*args, **kwargs)
            real_out = pytree.tree_leaves(real_out_unflat)

            assert len(real_out) == len(decomp_out)

            if do_relative_check:
                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_ref(
                        self.test_case,
                        func,
                        self.test_dtype,
                        i,
                        orig,
                        decomp,
                        ref,
                        args,
                        kwargs,
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_equal(
                        self.test_case,
                        func,
                        self.test_dtype,
                        orig,
                        decomp,
                        args,
                        kwargs,
                    )

            return real_out_unflat
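
    # The mode defined above is always used as a context manager around a single call,
    # typically together with the python dispatcher. A condensed sketch of the pattern
    # used by do_cross_ref below (illustrative; the real code also handles the vjp and
    # the run_all flag):
    #
    #   with self.DecompCrossRefMode(
    #       self, self.precision, self.rel_tol, dtype, run_all
    #   ) as mode, enable_python_dispatcher():
    #       op(*args, **kwargs)
    #   assert any(overload_to_aten_name(c) == aten_name for c in mode.decomposed)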

    def check_decomposed(self, aten_name, mode):
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is "
                f"CompositeImplicitAutograd you should skip this test "
                f"by updating CROSS_REF_EXCLUDE_SET."
            ),
        )

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        test_keys = [
            (torch.device(device).type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm. The problem is not
            # complex32 per se (which is supported by data-movement-only ops),
            # but that when we do backwards we expect other ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()

        def run_without_python_dispatcher(mode):
            return any(
                isinstance(op, torch._ops.OpOverload)
                and op.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for op in mode.decomposed.union([func])
            )

        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                # Once https://github.com/pytorch/pytorch/pull/75965/ lands I can
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode, enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if run_without_python_dispatcher(mode):
                        # without this check, incorrect decomps at the python dispatcher level can still pass because
                        # they're checking aten decomps at the torch_dispatch level.
                        with self.DecompCrossRefMode(
                            self, self.precision, self.rel_tol, dtype, run_all
                        ) as mode:
                            decomp_vjp_fn(cotangents)
                    if not run_all:
                        self.check_decomposed(op.aten_backward_name, mode)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                # A failure here might be because the decomposition for the op is wrong or because a
                # decomposition used by the particular op is wrong.
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    func(*args, **kwargs)

                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        func(*args, **kwargs)

                if not run_all:
                    self.check_decomposed(aten_name, mode)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())


class DecompOneOffTests(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._softmax(x, -1, False)
        res = torch._decomp.decompositions._softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_log_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)

        dtype = torch.float32
        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._log_softmax(x, -1, False)
        res = torch._decomp.decompositions._log_softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())
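
    # The two softmax tests above pin down a subtle contract: a decomposition must
    # produce outputs with the same strides as the eager kernel, even for
    # non-contiguous inputs. The probe layout they use can be built like this
    # (illustrative sketch):
    #
    #   x = torch.randn((2, 4, 3, 3))
    #   x = torch.as_strided(x, (2, 4, 3, 3), (9, 18, 3, 1))
    #   # equivalent to a contiguous (4, 2, 3, 3) tensor with dims 0 and 1 transposed
    #   x.is_contiguous()  # False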

    @onlyCUDA
    def test_exponential_non_inf(self, device):
        inp = torch.empty((4, 400, 256), device=device)

        with torch._dynamo.utils.preserve_rng_state():
            exp_ref = inp.exponential_()
            exp = torch._refs.exponential(inp)

        self.assertEqual(exp, exp_ref)
        self.assertFalse(exp.isinf().any())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @onlyCUDA
    def test_amp_batch_norm_backward(self):
        device = "cuda"
        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        weight = torch.randn((2,), dtype=torch.float32, device=device)
        rmean = torch.randn((2,), dtype=torch.float32, device=device)
        rvar = torch.randn((2,), dtype=torch.float32, device=device)
        mean = torch.randn((0,), dtype=torch.float32, device=device)

        ref = torch.ops.aten.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        res = torch._decomp.decompositions.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)
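
    # The batch-norm test above exercises the common AMP layout: half-precision
    # activations and gradients combined with float32 weight and running stats.
    # The decomposition has to reproduce eager's per-output strides and dtypes
    # (here grad_input is expected in fp16 while grad_weight / grad_bias follow
    # the fp32 parameters), which is what the zipped assertions check.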

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_elu_backward(self, device):
        size = (2, 4, 3, 3)
        dtype = torch.float32
        grad_out = torch.randn(size, dtype=dtype, device=device)
        out = torch.randn(size, dtype=dtype, device=device)

        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
        self.assertEqual(ref, res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_threshold_backward_dtype(self, device):
        grad = torch.randint(10, (4,), device=device)
        input_tensor = torch.randint(10, (4,), device=device)

        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
        self.assertEqual(ref.dtype, res.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_weight_norm_interface(self, device):
        g = torch.randn((3, 10, 10), device=device)
        v = torch.randn((1, 1, 10), device=device)

        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(ref[1], res[1]))

        inp = torch.rand([30, 10], device=device)
        inp2 = torch.rand([30, 1], device=device)

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCPU
    @skipIfCrossRef
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half],
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16; this is aligned with aten/src/ATen/native/transformers/attention.cpp.
        # If we add support for float16 over there we should update this test as well.

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                attn_output = op(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )
                return attn_output

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]

        atol, rtol = dtype_precisions[dtype]

        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
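            # The CPU flash-attention decomposition returns more than just the
            # attention output: the first element is the output, and the remaining
            # element(s) are auxiliary values used for backward (e.g. the logsumexp).
            # That is why only decomposed_res[0] is compared with the eager result.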
            eager_res = op(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=is_causal,
            )

            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )


instantiate_device_type_tests(DecompOneOffTests, globals())
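

# A quick, interactive way to answer the question the next test class automates for
# every overload: does a given aten overload have a decomposition registered, and is
# it part of the core set used by inductor? (illustrative sketch)
#
#   aten.addmm.default in decomposition_table
#   aten.addmm.default in torch._decomp.core_aten_decompositions()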
class HasDecompTest(TestCase):
    def setUp(self):
        super().setUp()
        self.maxDiff = None

    @staticmethod
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_key fails for some jit-registered ops, which shouldn't be
            # relevant here anyway
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):
        def all_aten_overloads():
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
                    continue

                name = name[6:]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # This is for operators that are only registered in some CI
        # configurations, so they would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        # If a decomposition isn't included in the core decompositions,
        # then it must decompose a core ATen operator.
        #
        # See NOTE [Core ATen Ops]
        #
        # If this test fails then either:
        # - Add the decomposition to torch._decomp.core_aten_decompositions,
        #   if the decomposition should be used by inductor (not a core operator).
        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
        #   core ATen operators (and inductor will not use the decomposition).

        # Some decompositions are registered for CompositeImplicitAutograd
        # operators, which never appear in AOTAutograd's graph, so they are never used.
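        # Set relationship being asserted, roughly (illustrative):
        #
        #   useful_decomps : decomps for ops that can actually show up in a trace
        #   core_decomps   : decomps inductor is expected to use
        #   useful_decomps - core_decomps == ops that must themselves be core ATen
        #
        # so adding a decomposition without registering it as a core decomposition
        # changes the expected output below.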
        useful_decomps = {
            op
            for op in decomposition_table.keys()
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))


if __name__ == "__main__":
    run_tests()