# Owner(s): ["module: mps"]

import io
import platform
import sys
import math
import random
import unittest
import warnings
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial

from torch.testing._internal.common_methods_invocations import (
    op_db,
    DecorateInfo,
    UnaryUfuncInfo,
    ReductionOpInfo,
    SpectralFuncInfo,
    BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch.utils._pytree as pytree
from itertools import product
import operator
test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)

def xfailIf(condition):
    def wrapper(func):
        if condition:
            return unittest.expectedFailure(func)
        else:
            return func
    return wrapper

def xfailIfMacOS14_4Plus(func):
    return unittest.expectedFailure(func) if product_version > 14.3 else func  # noqa: F821
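
# Illustrative sketch (not part of the upstream suite): `xfailIf` conditionally
# wraps a test in `unittest.expectedFailure`, so the same test can be an expected
# failure on one configuration and a normal test everywhere else. The class name,
# test name and condition below are hypothetical.
#
#     class ExampleTests(TestCase):
#         @xfailIf(platform.machine() == "x86_64")
#         def test_arm_only_behavior(self):
#             self.assertTrue(torch.backends.mps.is_available())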

def mps_ops_grad_modifier(ops):
    XFAILLIST_GRAD = {

        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode; the forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],

    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode; the forward pass succeeds as it falls back to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)
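
    # Illustrative sketch (not from the upstream file): `addDecorator` appends a
    # `DecorateInfo` to an OpInfo, and the test framework later applies the wrapped
    # decorator only to tests that match the given dtypes. The op below is hypothetical.
    #
    #     fake_op = copy.deepcopy(op_db[0])
    #     addDecorator(fake_op, DecorateInfo(unittest.expectedFailure,
    #                                        dtypes=[torch.float16]))
    #     # float16 variants of tests for `fake_op` are now expected failures;
    #     # every other dtype still runs normally.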

    for op in ops:
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.skip,
                         dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=ON_MPS_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op
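
# Illustrative sketch (not from the upstream file): the modifier is a generator,
# so it is meant to wrap an OpInfo list right where the `@ops` decorator consumes
# it; each op passes through with its xfail/skip decorators attached. The class
# and method names below are hypothetical.
#
#     class ExampleConsistencyTests(TestCase):
#         @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)),
#              allowed_dtypes=[torch.float32])
#         def test_example_grad(self, device, dtype, op):
#             ...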

def mps_ops_modifier(ops):
    # Supported complex OPS
    SUPPORTED_COMPLEX_OPS = {
        '__radd__',
        '__rmul__',
        '__getitem__',
        'abs',
        'add',
        'alias_copy',
        'argwhere',
        'atleast_1d',
        'atleast_2d',
        'atleast_3d',
        'as_strided',
        'as_strided_copy',
        'as_strided_scatter',
        'broadcast_tensors',
        'broadcast_to',
        'chalf',
        'cfloat',
        'chunk',
        'clone',
        'conj',
        'conj_physical',
        'contiguous',
        'diag',
        'diag_embed',
        'diagflat',
        'diagonal',
        'diagonal_copy',
        'diagonal_scatter',
        'dsplit',
        'empty',
        'empty_permuted',
        'empty_strided',
        'eye',
        'exp',
        'expand',
        'expand_as',
        'expand_copy',
        'flatten',
        'fill',
        'full',
        'H',
        'hsplit',
        'imag',
        'index_select',
        'isfinite',
        'isinf',
        'isreal',
        'item',
        'kron',
        'linalg.diagonal',
        'linalg.svd',
        'linspace',
        'logspace',
        'linspacetensor_overload',
        'logspacetensor_overload',
        'mH',
        'mT',
        'masked_scatter',
        'masked_select',
        'meshgridlist_of_tensors',
        'meshgridvariadic_tensors',
        'movedim',
        'mul',
        'narrow',
        'narrow_copy',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv_transpose1d',
        'nn.functional.conv_transpose2d',
        'nn.functional.feature_alpha_dropoutwithout_train',
        'nn.functional.padcircular',
        'nn.functional.tanhshrink',
        'nn.functional.unfold',
        'nonzero',
        'ones',
        'outer',
        'permute',
        'positive',
        'randn',
        'ravel',
        'real',
        'repeat_interleave',
        'reshape_as',
        'reshape',
        'resolve_conj',
        'resolve_neg',
        'scalar_tensor',
        'select',
        'sgn',
        'slice',
        'split',
        'split_with_sizes',
        'split_with_sizes_copy',
        'splitlist_args',
        'squeeze',
        'squeezemultiple',
        'sub',
        'svd',
        't',
        't_copy',
        'tanh',
        'tensor_split',
        'transpose',
        'T',
        'unbind',
        'unflatten',
        'unfold',
        'unfold_copy',
        'unsafe_chunk',
        'unsafe_split',
        'unsqueeze',
        'unsqueeze_copy',
        'view_as',
        'view_as_real',
        'view',
        'view_copy',
        'vsplit',
        'zero_',
        'zeros',
    }

    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
        '__rdiv__',
        '__rmatmul__',
        '_chunk_cat',
        '_unsafe_masked_index',
        'acos',
        'acosh',
        'all',
        'allclose',
        'any',
        'addcdiv',
        'addcmul',
        'addmmdecomposed',
        'addmv',
        'asin',
        'atan',
        'atanh',
        'bfloat16',
        'bmm',
        'bool',
        'cartesian_prod',
        'cat',
        'char',
        'column_stack',
        'combinations',
        'corrcoef',
        'constant_pad_nd',
        'cos',
        'cosh',
        'count_nonzero',
        'diff',
        'div',
        'divno_rounding_mode',
        'dot',
        'dstack',
        'einsum',
        'eq',
        'equal',
        'exp2',
        'expm1',
        'fft.fft',
        'fft.fft2',
        'fft.fftn',
        'fft.fftshift',
        'fft.ifft',
        'fft.ifft2',
        'fft.ifftn',
        'fft.ifftshift',
        'fft.irfftn',
        'fft.irfft2',
        'fft.irfft',
        'fft.hfftn',
        'fft.hfft2',
        'fft.hfft',
        'flip',
        'fliplr',
        'flipud',
        'float',
        'gradient',
        'half',
        'hstack',
        'inner',
        'int',
        'isclose',
        'isnan',
        'ldexp',
        'linalg.multi_dot',
        'linalg.pinv',
        'log10',
        'log1p',
        'log2',
        'log',
        'logical_and',
        'logical_not',
        'logical_or',
        'logical_xor',
        'logsumexp',
        'long',
        'masked_fill',
        'masked.mean',
        'masked.prod',
        'masked.std',
        'masked.sum',
        'masked.var',
        'masked.logsumexp',
        'matmul',
        'mean',
        'mm',
        'mv',
        'ne',
        'neg',
        'nn.functional.padconstant',
        'nn.functional.padreflect',
        'nn.functional.padreplicate',
        'nn.functional.pixel_shuffle',
        'nn.functional.pixel_unshuffle',
        'nn.functional.rms_norm',
        'nn.functional.softsign',
        'pinverse',
        'prod',
        'reciprocal',
        'roll',
        'rot90',
        'rsqrt',
        'short',
        'sigmoid',
        'sin',
        'sinh',
        'sqrt',
        'square',
        'stack',
        'stft',
        'sum',
        'sum_to_size',
        'tan',
        'tensordot',
        'trace',
        'trapz',
        'trapezoid',
        'tril',
        'triu',
        'true_divide',
        'vstack',
        'where',
        'byte',
    }
    # Those ops worked on macOS 12, but are broken on macOS 13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # The result of pow(9, 8) shows as 43046716, whereas it should have been 43046721.
        # Fixed in macOS 13.3. Currently no error is raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which is supported starting from macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # The elements at indices 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
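        # Illustrative sketch (not from the upstream file) of the stability issue
        # referenced above, using hypothetical values:
        #
        #     t = torch.tensor([3., 1., 1.])
        #     torch.argsort(t, stable=True)   # always tensor([1, 2, 0])
        #     torch.argsort(t)                # ties may come back in either order,
        #                                     # so CPU and MPS can legitimately differ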
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumprod': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
        'masked.cumprod': [torch.int64],
        'linalg.vander': [torch.int64],
    }

    MACOS_AFTER_13_1_XFAILLIST = {
        # before macOS 13.2 it falls back to cpu and passes the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_13_3_XFAILLIST = {
        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # The elements at indices 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
    }

    MACOS_BEFORE_14_4_XFAILLIST = {
        # These ops work fine in 14.4 but fail in 14.2 or 13.x
        'fft.hfft2': [torch.complex64],
    }

    # Those ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        'login': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'put': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'angle': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'erfc': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_reduceprod': None,
        'index_reducemean': None,
        'index_reduceamax': None,
        'index_reduceamin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatetrilinear': None,
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'qr': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric': None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'cdouble': None,
        'double': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinvhermitian': None,
        'nonzero_static': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv3d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'histc': [torch.float16],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'unravel_index': [torch.int32, torch.int64],

        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # returned output on CPU is float64
        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operations not supported
        '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
    }

    if product_version < 14.0:
        # FFT and BFloat16 support was added in MacOS 14
        UNIMPLEMENTED_XFAILLIST.update({
            'bfloat16': None,
            'fft.fft': None,
            'fft.fft2': None,
            'fft.fftn': None,
            'fft.hfft': None,
            'fft.hfft2': None,
            'fft.hfftn': None,
            'fft.ifft': None,
            'fft.ifft2': None,
            'fft.ifftn': None,
            'fft.ihfft': None,
            'fft.ihfft2': None,
            'fft.ihfftn': None,
            'fft.irfft': None,
            'fft.irfft2': None,
            'fft.irfftn': None,
            'fft.rfft': None,
            'fft.rfft2': None,
            'fft.rfftn': None,
            'stft': None,
            # TestConsistencyCPU.test_output_match_isin_cpu fails for integer dtypes;
            # not reproducible on later OS versions. An assert was added to the op
            # when used on macOS < 14.0.
            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
            'nn.functional.max_pool2d': [torch.uint8],
        })

    if product_version < 15.0:
        UNIMPLEMENTED_XFAILLIST.update({
            'quantile': None,
            'nanquantile': None,
        })

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # Failures due to random output generated with the Philox engine,
        # which does not match the CPU results
        'multinomial': [torch.float16, torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float16, torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
        'normal': [torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
        'nn.functional.dropout': [torch.float16, torch.float32],
        'nn.functional.dropout2d': [torch.float16, torch.float32],
        'nn.functional.dropout3d': [torch.float16, torch.float32],
        # See https://github.com/pytorch/pytorch/issues/111479
        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],

        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        'addr': [torch.bool, torch.int16, torch.int32,
                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values

        # random results
        # mps vs cpu:
        # Mismatched elements: 40 / 96 (41.7%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        # Mismatched elements: 56 / 96 (58.3%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],

        # float output for float16 input on MPS
        'logit': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on the MPS backend
        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,
    }

    EMPTY_OPS_SKIPLIST = {
        # These ops fill tensors with uninitialized data, causing mismatches with CPU.
        # They occasionally match, so they are skipped instead of being marked as expected failures.
        # See https://github.com/pytorch/pytorch/issues/100175
        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
                              torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # CPU: empty is returning all 0's and there is a mismatch with the MPS
        # allocation (MacOS 13). According to
        # https://pytorch.org/docs/2.0/generated/torch.empty.html the returned
        # values are uninitialized, so any content is valid.
        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
                           torch.int32, torch.int64, torch.uint8, torch.int8],
    }

    SKIPLIST = {
        # Unsupported
        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
        'nn.functional.avg_pool2d': [torch.float16],

        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
        'nn.functional.conv3d': None,
    }
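
    # Note on mechanics: entries in EMPTY_OPS_SKIPLIST and SKIPLIST are skipped
    # outright via unittest.skip and never run, while entries in the *_XFAILLIST
    # tables still run under unittest.expectedFailure and must fail. A dtypes
    # value of None marks an entry as applying to every dtype.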
    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in EMPTY_OPS_SKIPLIST:
            addDecorator(op, DecorateInfo(
                         unittest.skip("Skipping empty ops."),
                         dtypes=EMPTY_OPS_SKIPLIST[key]))
        if key in SKIPLIST:
            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
            if key in xfaillist:
                addDecorator(op, DecorateInfo(
                             unittest.expectedFailure,
                             dtypes=xfaillist[key]))

        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))

        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))

        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST[key]))

        # If the op is not supported for complex types, expect it to fail
        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))

        yield op
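
# Illustrative sketch (not used by the tests): the modifier above turns the
# xfail tables into decorators keyed by `op.name + op.variant_test_name`. A
# hypothetical one-off modifier for a single op would follow the same pattern:
def _example_single_op_xfail_modifier(ops):
    # Hypothetical helper, for documentation only.
    for op in ops:
        if op.name + op.variant_test_name == 'histc':
            op.decorators = list(op.decorators) if op.decorators is not None else []
            op.decorators.append(DecorateInfo(unittest.expectedFailure, dtypes=[torch.float16]))
        yield op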

def mps_ops_error_inputs_modifier(ops):
    # Error input samples do not take a dtype argument.
    XFAILLIST = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        if op.error_inputs_func is None:
            continue
        key = op.name + op.variant_test_name
        if key in XFAILLIST:
            addDecorator(op, DecorateInfo(unittest.expectedFailure))
        yield op

# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
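
# Worked example of the version parsing above (assuming macOS 14.4.1):
# platform.mac_ver()[0] == '14.4.1', the first two components join to '14.4',
# and product_version == 14.4. On non-macOS builds mac_ver()[0] is '', so the
# `or -1` fallback yields product_version == -1.0.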

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'

def skipMPSMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_mps_memory_leak_check', True):
            fn._do_mps_memory_leak_check = not condition
        return fn
    return dec

class MpsMemoryLeakCheck:
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase

    def __enter__(self):
        # Performs a gc if required (i.e. if any memory is currently held)
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquires caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()
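
    # Usage sketch (illustrative only): the check wraps a test body as a
    # context manager, e.g.
    #
    #     with MpsMemoryLeakCheck(self, name="my_test"):
    #         torch.ones(16, device="mps")  # any MPS work
    #
    # __exit__ below then compares allocator and driver statistics against the
    # snapshot taken in __enter__.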
    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return

        # Compares caching allocator before/after statistics
        # An increase in allocated memory is a discrepancy indicating a possible memory leak
        discrepancy_detected = False
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > self.caching_allocator_before:
            discrepancy_detected = True

        # Short-circuits if no discrepancy detected
        if not discrepancy_detected:
            return

        # Validates that the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        discrepancy_detected = True
        # Query memory multiple times to ensure the leak was not transient
        for n in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = False
            driver_discrepancy = False

            if caching_allocator_mem_allocated > self.caching_allocator_before:
                caching_allocator_discrepancy = True

            if driver_mem_allocated > self.driver_before:
                driver_discrepancy = True

            if not (caching_allocator_discrepancy or driver_discrepancy):
                # Leak was a false positive, exit the loop
                discrepancy_detected = False
                break

        if caching_allocator_discrepancy and not driver_discrepancy:
            # Just raise a warning if the leak is not validated by the driver API
            msg = ("MPS caching allocator reports a memory leak not "
                   f"verified by the driver API in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
" 1202*da0073e9SAndroid Build Coastguard Worker f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.") 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(msg) 1205*da0073e9SAndroid Build Coastguard Worker 1206*da0073e9SAndroid Build Coastguard Workerclass TestAutocastMPS(TestCase): 1207*da0073e9SAndroid Build Coastguard Worker 1208*da0073e9SAndroid Build Coastguard Worker def test_matmul_autocast(self): 1209*da0073e9SAndroid Build Coastguard Worker autocast_tensor_A = torch.rand((8, 8), device="mps") 1210*da0073e9SAndroid Build Coastguard Worker autocast_tensor_B = torch.rand((8, 8), device="mps") 1211*da0073e9SAndroid Build Coastguard Worker tensor_A = autocast_tensor_A.clone().detach() 1212*da0073e9SAndroid Build Coastguard Worker tensor_B = autocast_tensor_B.clone().detach() 1213*da0073e9SAndroid Build Coastguard Worker autocast_output_tensor = torch.empty(8, 8) 1214*da0073e9SAndroid Build Coastguard Worker output_tensor = autocast_output_tensor.clone().detach() 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type="mps"): 1217*da0073e9SAndroid Build Coastguard Worker autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B) 1218*da0073e9SAndroid Build Coastguard Worker autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor) 1219*da0073e9SAndroid Build Coastguard Worker 1220*da0073e9SAndroid Build Coastguard Worker output_tensor = torch.mm(tensor_A, tensor_B) 1221*da0073e9SAndroid Build Coastguard Worker output_tensor = torch.mm(tensor_A, output_tensor) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16") 1224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(autocast_output_tensor, 1225*da0073e9SAndroid Build Coastguard Worker output_tensor.to(torch.float16), 1226*da0073e9SAndroid Build Coastguard Worker f"Autocast & non-autocast tensors did not match, \ 1227*da0073e9SAndroid Build Coastguard Worker got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}") 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker# Expand TestCase class with Memory Leak Detection on MPS device 1230*da0073e9SAndroid Build Coastguard Workerclass TestCaseMPS(TestCase): 1231*da0073e9SAndroid Build Coastguard Worker _do_mps_memory_leak_check = True 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker def __init__(self, method_name='runTest'): 1234*da0073e9SAndroid Build Coastguard Worker super().__init__(method_name) 1235*da0073e9SAndroid Build Coastguard Worker test_method = getattr(self, method_name, None) 1236*da0073e9SAndroid Build Coastguard Worker if test_method is not None: 1237*da0073e9SAndroid Build Coastguard Worker # Wraps the tested method if we should do MPS memory check. 

# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        if test_method is not None:
            # Wraps the tested method if we should do an MPS memory check.
            if TEST_MPS_MEM_LEAK_CHECK:
                if self._do_mps_memory_leak_check:
                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        name = self.id() if name is None else name
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        test_method = getattr(self, method_name)
        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))

    # Checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
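
# A minimal sketch of how the pieces above combine (assuming the environment
# variable PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1 is set): every test method on a
# TestCaseMPS subclass is wrapped so that it effectively runs as
#
#     with self.assertLeaksNoMpsTensors():
#         original_test_method()
#
# and a leak that is confirmed by the driver API raises a RuntimeError.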

class TestMemoryLeak(TestCaseMPS):
    def test_mps_memory_leak_detection(self):
        l = []

        @self.wrap_with_mps_memory_check
        def no_leak():
            pass

        # Trigger an intentional memory leak
        @self.wrap_with_mps_memory_check
        def leak_gpu0():
            # Increased to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        no_leak()

        # Check that a RuntimeError for the memory leak was raised, which
        # confirms that memory leak detection worked.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            leak_gpu0()

    def test_copy_cast_no_leak(self):

        def step(x):
            x = x.to(device='cpu', dtype=torch.float32)
            x = x.to(device='mps', dtype=torch.float16)

        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
        # Warm up / prebuild MPS shaders (otherwise the check fails on 13.2)
        step(a)
        torch.mps.empty_cache()
        driver_before = torch.mps.driver_allocated_memory()
        step(a)
        torch.mps.empty_cache()
        driver_after = torch.mps.driver_allocated_memory()
        self.assertEqual(driver_before, driver_after, f"Detected {driver_after - driver_before} bytes leak of GPU memory")
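
# PixelShuffle rearranges (N, C * r**2, H, W) into (N, C, H * r, W * r) and
# PixelUnshuffle inverts it. A minimal shape sketch (assuming r == 2),
# independent of the tests below:
#
#     x = torch.arange(16.).reshape(1, 4, 2, 2)
#     y = torch.nn.functional.pixel_shuffle(x, 2)  # y.shape == (1, 1, 4, 4)
#     assert torch.nn.functional.pixel_unshuffle(y, 2).equal(x)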

class TestPixelShuffle(TestCaseMPS):
    def test_pixel_shuffle_unshuffle(self):
        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
                                                 upscale_factor=None, is_contiguous=True):

            def generate_input():
                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
                height = random.randint(5, 10)
                width = random.randint(5, 10)

                if num_input_dims == 1:
                    input = torch.rand(channels, requires_grad=True, device='mps')
                    assert is_contiguous
                elif num_input_dims == 2:
                    input = torch.rand(width, height, requires_grad=True, device='mps').T
                    if is_contiguous:
                        input = input.contiguous()
                else:
                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
                    input = input.transpose(-1, -2)
                    if is_contiguous:
                        input = input.contiguous()

                if not is_contiguous and len(input.reshape(-1)) > 0:
                    assert not input.is_contiguous()

                input = input.detach().clone()
                input.requires_grad = True
                return input

            # Function to imperatively ensure pixels are shuffled to the correct locations.
            # Used to validate the batch operations in pixel_shuffle.
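            # Worked example of the index math below (taking upscale_factor == 2):
            # output[..., c=0, h=3, w=1] must equal input[..., 3, 1, 0], since
            # channel_idx = 2 * (3 % 2) + (1 % 2) + 0 * 4 = 3, height_idx = 3 // 2 = 1,
            # and weight_idx = 1 // 2 = 0.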
            def _verify_pixel_shuffle(input, output, upscale_factor):
                for c in range(output.size(-3)):
                    for h in range(output.size(-2)):
                        for w in range(output.size(-1)):
                            height_idx = h // upscale_factor
                            weight_idx = w // upscale_factor
                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
                                          (c * upscale_factor ** 2)
                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])

            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
            input = generate_input()

            ps = nn.PixelShuffle(upscale_factor)
            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                output = ps(input)
                _verify_pixel_shuffle(input, output, upscale_factor)
                output.backward(output.data)
                self.assertEqual(input.data, input.grad.data)

                # Ensure unshuffle properly inverts shuffle.
                unshuffle_output = pus(output)
                self.assertEqual(input, unshuffle_output)
            else:
                self.assertRaises(RuntimeError, lambda: ps(input))

        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
                                                    downscale_factor=None):
            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
            channels = random.randint(1, 4)
            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

            if num_input_dims == 1:
                input = torch.rand(channels, requires_grad=True, device='mps')
            elif num_input_dims == 2:
                input = torch.rand(height, width, requires_grad=True, device='mps')
            else:
                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')

            pus = nn.PixelUnshuffle(downscale_factor)
            self.assertRaises(RuntimeError, lambda: pus(input))

        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
            # For 1D - 2D, this is an error case.
            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
            for is_contiguous in is_contiguous_check:
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
                )

            # Error cases for pixel_unshuffle.
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

        def test_pixel_shuffle_unshuffle_1D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)

        def test_pixel_shuffle_unshuffle_2D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)

        def test_pixel_shuffle_unshuffle_3D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)

        def test_pixel_shuffle_unshuffle_4D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)

        def test_pixel_shuffle_unshuffle_5D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)

        test_pixel_shuffle_unshuffle_1D()
        test_pixel_shuffle_unshuffle_2D()
        test_pixel_shuffle_unshuffle_3D()
        test_pixel_shuffle_unshuffle_4D()
        test_pixel_shuffle_unshuffle_5D()

class MPSReluTest(TestCaseMPS):
    def _npRelu(self, np_features):
        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)

    def testNpRelu(self):
        torch.testing.assert_close(
            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
            self._npRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]])))

    def _testRelu(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)
        # In-place ReLU modifies the initial input, which should then match the output of ReLU
        self.assertEqual(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")

    def testNumbersGPU(self):
        for t in [np.float16, np.float32]:
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testRelu(np.array([]).astype(t), device="mps")
            self._testReluInPlace(np.array([]).astype(t), device="mps")

class MatmulTest(TestCaseMPS):
    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        if expand_tensor_1_shape:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
        else:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps")

        if expand_tensor_2_shape:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
        else:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps")

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))
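
# For reference, torch.matmul broadcasts batch dimensions in the cases above:
# (10, 3, 4) @ (4,) -> (10, 3) and (10, 3, 4) @ (4, 5) -> (10, 3, 5); each MPS
# result is compared against the CPU result of the same expression.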

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            assert not mps_x.is_contiguous()

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        assert cpu_x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        for t in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    self._testLeakyRelu(shape,
                                        dtype=t,
                                        negative_slope=0.2,
                                        contiguous=contiguous)


class TestAvgPool(TestCaseMPS):
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)
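
    # Added commentary (not from the original suite): F.unfold extracts every
    # kernel_size window as a column, so summing over dim=1 yields a sum-pool
    # with stride == kernel_size. For example:
    #   x = torch.arange(16.).view(1, 1, 4, 4)
    #   F.unfold(x, kernel_size=2, stride=2).shape == (1, 4, 4)  # 4 values per window, 4 windows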

    def _sum_pool3d(self, x, kernel_size):
        # Because unfold does not support 3D sliding windows, split the tensor
        # into depth slices and sum-pool each slice separately
        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # _sum_pool2d assumes a (1, 1, n, m) view, so unsqueeze twice
        split_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in split_x]
        joined_x = torch.cat(split_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
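
    # Added note (not from the original suite): gh-36977 tracked avg_pool2d
    # producing NaNs with ceil_mode=True when the final pooling window falls
    # entirely into the implicit padding; the checks above assert that no NaNs
    # escape on either the CPU or the MPS path.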


class TestMPS(TestCaseMPS):
    def test_exp(self, device="mps", dtype=torch.float):
        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
            b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
            a = torch.tensor(v, dtype=dtype, device="mps") * b
            self.compare_with_numpy(torch.exp, np.exp, a)

    def test_conv_raises_error(self, device='mps', dtype=torch.float):
        conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')

        x = torch.ones([1, 1, 3])
        with self.assertRaises(NotImplementedError):
            conv(x.to("mps"))

    def test_triu_inf(self, device="mps", dtype=torch.float):
        for diag in [-1, 0, 1]:
            mask = torch.full((3, 6, 6), float("-inf"))
            mask_mps = mask.clone().detach().to('mps')
            cpu_ref = torch.triu(mask, diagonal=diag)
            mps_out = torch.triu(mask_mps, diagonal=diag)
            self.assertEqual(cpu_ref, mps_out)

    def test_exp1(self, device="mps", dtype=torch.float):
        input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
        output = torch.exp(input)
        output_cpu = torch.exp(input.cpu())
        # If the exponentWithTensor: MPS call is used on an M1 running 14.5, the test will fail with
        # Mismatched elements: 3 / 4 (75.0%)
        # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
        # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
        self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)

    def test_exp_strided_output(self):
        x = torch.rand((256, 10), device='mps')
        x_cpu = x.to("cpu")

        x = x.permute(1, 0)
        x_cpu = x_cpu.permute(1, 0)

        res = x.exp()
        res_cpu = x_cpu.exp()
        self.assertEqual(res, res_cpu)
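
    # Added commentary (not from the original suite): permuting x before the
    # call makes both the input and the result views non-contiguous, so this
    # also covers the MPS code path that has to write exp() results into a
    # strided output rather than a fresh contiguous buffer.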

    def _testLeakyRelu(self, np_features, negative_slope, device):
        cpu_x = torch.from_numpy(np_features).requires_grad_()
        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')
        cpu_leaky_relu.backward(gradient=cpu_grad)
        mps_leaky_relu.backward(gradient=mps_grad)
        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))

    def testNumbersGPU(self):
        for t in [np.float32]:
            self._testLeakyRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                negative_slope=0.1,
                device="mps")

    def test_fill(self):

        def helper(val, shape, dtype):
            tensor = torch.zeros(shape, device='mps', dtype=dtype)
            tensor_mps = tensor.fill_(val)

            tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
            tensor_cpu = tensor_0.fill_(val)

            self.assertEqual(tensor_mps, tensor_cpu)

        helper(0, [1024], torch.float32)
        helper(0.2, [2, 3], torch.float32)
        helper(0.2 + 0.5j, [2, 3], torch.complex64)
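
    # Added note (not from the original suite): in the test below, tensor[:][1]
    # and tensor[:, 9] are views with a non-zero storage offset, so fill_ must
    # write through the view into the correct slice of the shared storage (see
    # the regression link inside the test).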

    def test_fill_storage_offset(self):
        shape = [2, 10]
        val = 0.2
        tensor = torch.ones(shape, device="mps")
        tensor_mps = tensor[:][1].fill_(val)
        tensor_0 = torch.ones(shape, device="cpu")
        tensor_cpu = tensor_0[:][1].fill_(val)

        self.assertEqual(tensor_mps, tensor_cpu)
        self.assertEqual(tensor, tensor_0)

        shape = [1, 10]
        val = 0.0
        tensor = torch.ones(shape, device="mps")
        val_tensor_mps = torch.tensor(val, device="mps")
        tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
        # Regression test for https://github.com/pytorch/pytorch/issues/114692
        tensor[:, 5].fill_(val_tensor_mps)
        tensor_0 = torch.ones(shape, device="cpu")
        val_tensor_cpu = torch.tensor(val, device="cpu")
        tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
        tensor_0[:, 5].fill_(val_tensor_cpu)

        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
        self.assertEqual(tensor.to(device="cpu"), tensor_0)

    def test_cdist_large(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(100, 10, device=device)
            y = torch.randn(100, 10, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual)

    def test_cdist_large_batch(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(4, 3, 100, 10, device=device)
            y = torch.randn(4, 3, 100, 10, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertEqual(expected, actual)
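
    # Added note (not from the original suite): the compute_mode strings toggle
    # whether p=2 distances use the matrix-multiply identity
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x @ y.T, which is faster but
    # numerically looser; that is why the p=2 checks in test_cdist_norm below
    # use a widened atol.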

    def test_cdist_non_contiguous(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(5, 7, device=device).mT
            y = torch.randn(5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(7, 5, device=device)
            y = torch.randn(5, 3, device=device).t()
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertTrue(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(5, 7, device=device).t()
            y = torch.randn(3, 5, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(y.is_contiguous())
            self.assertEqual(expected, actual)

    def test_cdist_non_contiguous_batch(self, device="mps"):
        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(7, 2, 7, 5, device=device)
            y = torch.randn(7, 2, 5, 3, device=device).mT
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertTrue(x.is_contiguous())
            self.assertFalse(y.is_contiguous())
            self.assertEqual(expected, actual)

            x = torch.randn(4, 5, 7, device=device).mT
            y = torch.randn(4, 3, 5, device=device)
            actual = torch.cdist(x, y, p=2, compute_mode=cm)
            expected = self._brute_cdist(x, y, p=2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(y.is_contiguous())
            self.assertEqual(expected, actual)

    def test_cdist_euclidean_large(self, device="mps"):
        def _test_euclidean_large_cdist(sizex, sizey=None):
            if sizey is None:
                sizey = sizex
            x = torch.randn(sizex, device=device, dtype=torch.float)
            y = torch.randn(sizey, device=device, dtype=torch.float)
            eps = 1e-6
            # to avoid extremum
            x = x - (((x - y) < eps).float() * 2 * eps)
            x.requires_grad = True
            y.requires_grad = True
            dist = torch.cdist(x, y, p=2)
            # Do a backward pass to check that it is valid for large
            # matrices
            loss = dist.sum()
            loss.backward()

        _test_euclidean_large_cdist((2000, 5))

    def test_cdist_same_inputs(self, device="mps"):
        # Test to detect issues in the cdist gradient calculation
        # when the distances are 0
        sizex = (1, 27, 32)
        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
            x = torch.randn(sizex, device=device, dtype=torch.float)
            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
            y = x.clone()
            eps = 1e-6
            x.requires_grad = True
            d = torch.cdist(x, y)
            d.backward(dist_grad)
            # Check that the backward pass does not contain invalid
            # values such as nan or inf
            assert torch.isfinite(x.grad).all()

    def _brute_cdist(self, x, y, p=2):
        r1 = x.shape[-2]
        r2 = y.shape[-2]
        if r1 == 0 or r2 == 0:
            return torch.empty(r1, r2, device=x.device)
        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
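
    # Added commentary (not from the original suite): in _brute_cdist above,
    # x[..., None, :] has shape (..., r1, 1, m) and y[..., None, :, :] has
    # shape (..., 1, r2, m), so the subtraction broadcasts to (..., r1, r2, m)
    # and the norm over the last dim produces the (r1, r2) distance matrix.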

    def test_cdist_norm(self, device="mps"):
        for r1 in [3, 4]:
            for m in [2, 3]:
                for r2 in [4, 6]:
                    for p in [0, 1, 1.5, 2.5, float('inf')]:
                        x = torch.randn(r1, m, device=device)
                        y = torch.randn(r2, m, device=device)
                        if p == 2:
                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                                expected = self._brute_cdist(x, y, p=2)
                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
                        else:
                            actual = torch.cdist(x, y, p=p)
                            expected = self._brute_cdist(x, y, p=p)
                            self.assertEqual(expected, actual)

    def test_cdist_norm_batch(self, device="mps"):
        for r1 in [3, 4]:
            for m in [2, 3]:
                for r2 in [4, 6]:
                    for p in [0, 3, 1.5, 2.5, float('inf')]:
                        x = torch.randn(2, 3, 6, r1, m, device=device)
                        y = torch.randn(2, 3, 6, r2, m, device=device)
                        if p == 2:
                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                                expected = self._brute_cdist(x, y, p=2)
                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
                        else:
                            actual = torch.cdist(x, y, p=p)
                            expected = self._brute_cdist(x, y, p=p)
                            self.assertEqual(expected, actual)

    def test_mm(self):
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.mm(B, C).cpu()
        torch.testing.assert_close(D, torch.full((5, 5), 6.0))

    def test_linalg_cross(self):
        def helper(dtype):
            device = "mps"
            if dtype is torch.int32 or dtype is torch.int64:
                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
            else:
                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
            x_cpu = x.to("cpu")
            y_cpu = y.to("cpu")
            res1 = torch.linalg.cross(x, y, dim=1)
            res2 = torch.tensor((), dtype=dtype, device=device)
            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
            torch.linalg.cross(x, y, dim=1, out=res2)
            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
            self.assertEqual(res1, res2)
            self.assertEqual(res1, res1_cpu)
            self.assertEqual(res2, res2_cpu)

            # test for broadcastable inputs
            if dtype is torch.int32 or dtype is torch.int64:
                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
            else:
                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
            x_cpu = x.to("cpu")
            y_cpu = y.to("cpu")
            res1 = torch.linalg.cross(x, y, dim=1)
            res2 = torch.tensor((), dtype=dtype, device=device)
            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
            torch.linalg.cross(x, y, dim=1, out=res2)
            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
            self.assertEqual(res1, res2)
            self.assertEqual(res1, res1_cpu)
            self.assertEqual(res2, res2_cpu)

        for dtype in [torch.int32, torch.int64, torch.float32]:
            helper(dtype)

    def test_cross(self):
        a = torch.randn(4, 3, device="mps")
        b = torch.randn(4, 3, device="mps")
        a_cpu = a.to("cpu")
        b_cpu = b.to("cpu")
        res = torch.cross(a, b, dim=1)
        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
        self.assertEqual(res, res_cpu)
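
    # Added note (not from the original suite): dim is passed explicitly here
    # because torch.cross without dim defaults to the first dimension of size 3
    # (a deprecated behavior), whereas torch.linalg.cross defaults to dim=-1.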

    def test_addmm(self):
        A = torch.ones(5, 5).to("mps")
        B = torch.ones(5, 6).to("mps")
        C = torch.ones(6, 5).to("mps")
        D = torch.addmm(A, B, C).to("cpu")
        torch.testing.assert_close(D, torch.full((5, 5), 7.0))

    def test_bmm(self):
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
        output_mps = torch.bmm(batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    @xfailIf(product_version < 15.0)
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_large_bmm(self, dtype):
        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
        output_mps = torch.bmm(batch1, batch2)

        # Use a lower-precision comparison for FP16
        tol = 1e-2 if dtype == torch.float16 else None
        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_addr(self):
        A = torch.ones(5, 10).to("mps")
        B = torch.ones(5).to("mps")
        C = torch.ones(10).to("mps")
        D = torch.addr(A, B, C).to("cpu")
        torch.testing.assert_close(D, torch.full((5, 10), 2.0))

    def test_trace(self):
        M_cpu = torch.randn(3, 3)
        M_mps = M_cpu.detach().clone().to("mps")

        output_cpu = torch.trace(M_cpu)
        output_mps = torch.trace(M_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_addbmm(self):
        M_cpu = torch.randn(3, 5)
        batch1_cpu = torch.randn(10, 3, 4)
        batch2_cpu = torch.randn(10, 4, 5)

        M_mps = M_cpu.detach().clone().to("mps")
        batch1_mps = batch1_cpu.detach().clone().to("mps")
        batch2_mps = batch2_cpu.detach().clone().to("mps")

        output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
        output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_baddbmm(self):
        def helper(input_shape, batch1_shape, batch2_shape):
            M_cpu = torch.randn(input_shape)
            batch1_cpu = torch.randn(batch1_shape)
            batch2_cpu = torch.randn(batch2_shape)
            alpha = 1.2
            beta = 0.8

            M_mps = M_cpu.detach().clone().to("mps")
            batch1_mps = batch1_cpu.detach().clone().to("mps")
            batch2_mps = batch2_cpu.detach().clone().to("mps")

            output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
            output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)

            self.assertEqual(output_cpu, output_mps)
            self.assertEqual(output_cpu.size(), output_mps.size())

        helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
        helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))

    def test_local_scalar_dense_mps(self):
        x_cpu = torch.randn(1)
        y_mps = x_cpu.to("mps")
        torch.testing.assert_close(x_cpu.item(), y_mps.item())
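
    # Added commentary (not from the original suite): Tensor.item() performs a
    # synchronizing device-to-host copy, so this test exercises the MPS
    # _local_scalar_dense path that the test name refers to.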

    def test_linear_1d_weight(self):
        device = 'cpu'
        projected = torch.rand([8]).to(device)
        x = torch.rand([1, 2, 2, 8]).to(device)
        x_mps = x.to('mps')
        projected_mps = projected.to('mps')
        linear = F.linear(x, projected)
        linear_mps = F.linear(x_mps, projected_mps)

        self.assertEqual(linear, linear_mps)

        projected = torch.rand([1, 8]).to(device)
        x = torch.rand([1, 2, 2, 8]).to(device)
        x_mps = x.to('mps')
        projected_mps = projected.to('mps')
        linear = F.linear(x, projected)
        linear_mps = F.linear(x_mps, projected_mps)

        self.assertEqual(linear, linear_mps)

    def test_linear_bias(self):
        def helper(bias_shape):
            device = "cpu"
            x = torch.randn(2, 2, 2, 64, device=device)
            linear = torch.nn.Linear(64, 4, device=device)
            linear.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device=device))
            y = linear(x)
            device = "mps"
            x_mps = x.to(device)
            linear.to(device)
            y_mps = linear(x_mps)
            self.assertEqual(y, y_mps)

        helper(())
        helper((2, 4))

    def test_linear_errors(self):
        # Mixed CPU<->MPS tensors
        size = (3, 3)

        # Unsupported dtypes
        with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
            torch.nn.functional.linear(torch.rand(size, device='mps'),
                                       torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))

        # Weights on wrong device
        with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
            torch.nn.functional.linear(torch.rand(size, device='mps'),
                                       torch.rand(size, device='cpu'))

        # Input on wrong device
        with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
            torch.nn.functional.linear(torch.rand(size, device='cpu'),
                                       torch.rand(size, device='mps'))

    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

        # Use the same weights and bias as the ones from the cpu
        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

        if bias:
            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

        linear_mps_input = torch.randn(shape).to('mps')
        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

        if backward_pass:
            linear_mps_input = linear_mps_input.requires_grad_()
            linear_cpu_input = linear_cpu_input.requires_grad_()

        linear_cpu_output = cpu_linear(linear_cpu_input)
        linear_mps_output = mps_linear(linear_mps_input)

        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

        if backward_pass:
            cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
            grad = cpu_grad.detach().to('mps').requires_grad_()

            linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
            linear_mps_output.backward(gradient=grad, create_graph=True)

            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
            if bias:
                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

            # test gradgrad
            x_grad_out = torch.rand_like(linear_cpu_input)
            x_grad_out_mps = x_grad_out.to("mps")
            w_grad_out = torch.rand_like(cpu_linear.weight)
            w_grad_out_mps = w_grad_out.to("mps")

            linear_cpu_input.grad.detach().zero_()
            linear_mps_input.grad.detach().zero_()
            cpu_linear.weight.grad.detach().zero_()
            mps_linear.weight.grad.detach().zero_()
            if bias:
                b_grad_out = torch.rand_like(cpu_linear.bias)
                b_grad_out_mps = b_grad_out.to("mps")
                cpu_linear.bias.grad.detach().zero_()
                mps_linear.bias.grad.detach().zero_()

            linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
            linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
            cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
            mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
            if bias:
                cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
                mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)

            self.assertEqual(cpu_grad.grad, grad.grad)
            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
            if bias:
                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
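
    # Added note (not from the original suite): the gradgrad section above
    # works because the first backward runs with create_graph=True, which
    # records the backward computation itself; calling .grad.backward(...)
    # afterwards then differentiates through that recorded backward pass,
    # i.e. an explicit second-order gradient check.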

    def test_linear1D(self):
        self._linear_helper(in_features=2, out_features=3, shape=[2], bias=True, backward_pass=False)

    def test_linear1D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=[2], bias=True, backward_pass=True)

    def test_linear2D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=False)

    def test_linear2D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=True)

    def test_linear2D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=False)

    def test_linear2D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=True)

    def test_linear3D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=False)

    def test_linear3D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=True)

    def test_linear3D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=False)

    def test_linear3D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=True)

    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # Check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d, validate_args=False)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf)

        # Check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()
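
    # Added note (not from the original suite): the gradient assertions above
    # follow from the reparameterization used by Uniform.rsample,
    # u = low + rand * (high - low), so du/dlow = 1 - rand and du/dhigh = rand
    # for the same underlying draw, which is replayed by saving and restoring
    # the RNG state.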

    def test_randperm(self, device="mps"):
        rng_device = None
        for n in (5, 100, 50000, 100000):
            for dtype in (torch.long, torch.half, torch.float):
                if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
                    continue
                if n > 256 and dtype == torch.bfloat16:
                    continue
                with torch.random.fork_rng(devices=rng_device):
                    res1 = torch.randperm(n, dtype=dtype, device=device)
                    res2 = torch.empty(0, dtype=dtype, device=device)
                    torch.randperm(n, out=res2, dtype=dtype, device=device)
                    self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))

        # Default type is long
        for n in (100, 10000):
            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

        # randperm of 0 elements is an empty tensor
        res1 = torch.randperm(0)
        res2 = torch.tensor(5, dtype=dtype, device=device)
        torch.randperm(0, out=res2)
        self.assertEqual(res1.numel(), 0)
        self.assertEqual(res2.numel(), 0)

        # Test non-contiguous tensors
        for n in (4, 5, 6, 10, 20):
            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
            self.assertFalse(non_contiguous_tensor.is_contiguous())
            with torch.random.fork_rng(devices=rng_device):
                res = torch.randperm(n, dtype=torch.long, device=device)
                torch.randperm(n, out=non_contiguous_tensor)
            self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
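
    # Added note (not from the original suite): torch.random.fork_rng saves and
    # restores the generator state around each sampling block above, so the
    # randperm draws do not perturb the RNG state observed by other tests.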
return_indices=return_indices) 2248*da0073e9SAndroid Build Coastguard Worker 2249*da0073e9SAndroid Build Coastguard Worker if (return_indices is False): 2250*da0073e9SAndroid Build Coastguard Worker y = pool(x) 2251*da0073e9SAndroid Build Coastguard Worker ref_y = pool(cpu_x) 2252*da0073e9SAndroid Build Coastguard Worker 2253*da0073e9SAndroid Build Coastguard Worker cpu_grad = torch.ones_like(ref_y) 2254*da0073e9SAndroid Build Coastguard Worker grad = cpu_grad.to('mps') 2255*da0073e9SAndroid Build Coastguard Worker 2256*da0073e9SAndroid Build Coastguard Worker y.backward(gradient=grad) 2257*da0073e9SAndroid Build Coastguard Worker ref_y.backward(gradient=cpu_grad) 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, ref_y) 2260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, cpu_x.grad) 2261*da0073e9SAndroid Build Coastguard Worker else: 2262*da0073e9SAndroid Build Coastguard Worker y, idx = pool(x) 2263*da0073e9SAndroid Build Coastguard Worker ref_y, ref_idx = pool(cpu_x) 2264*da0073e9SAndroid Build Coastguard Worker 2265*da0073e9SAndroid Build Coastguard Worker cpu_grad = torch.ones_like(ref_y) 2266*da0073e9SAndroid Build Coastguard Worker grad = cpu_grad.to('mps') 2267*da0073e9SAndroid Build Coastguard Worker 2268*da0073e9SAndroid Build Coastguard Worker y.backward(gradient=grad) 2269*da0073e9SAndroid Build Coastguard Worker ref_y.backward(gradient=cpu_grad) 2270*da0073e9SAndroid Build Coastguard Worker 2271*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, ref_y) 2272*da0073e9SAndroid Build Coastguard Worker self.assertEqual(idx, ref_idx) 2273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, cpu_x.grad) 2274*da0073e9SAndroid Build Coastguard Worker 2275*da0073e9SAndroid Build Coastguard Worker # Test with no batch dimension 2276*da0073e9SAndroid Build Coastguard Worker helper((8, 4, 4), ks=2) 2277*da0073e9SAndroid Build Coastguard Worker helper((2, 8, 4, 4), ks=2) 2278*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 32, 32), ks=4) 2279*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 1, 4), ks=(1, 4)) # test for max_pool1d 2280*da0073e9SAndroid Build Coastguard Worker # Test padding 2281*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 32, 32), ks=4, padding=1) 2282*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1)) # test for max_pool1d 2283*da0073e9SAndroid Build Coastguard Worker # Test dilation 2284*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 32, 32), ks=4, dilation=2) 2285*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2)) # test for max_pool1d 2286*da0073e9SAndroid Build Coastguard Worker # Test ceil mode 2287*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 32, 32), ks=4, ceil_mode=True) 2288*da0073e9SAndroid Build Coastguard Worker helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True) # test for max_pool1d 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker # Test return indices 2291*da0073e9SAndroid Build Coastguard Worker for test_ties in [False, True]: 2292*da0073e9SAndroid Build Coastguard Worker # Test with no batch dimension 2293*da0073e9SAndroid Build Coastguard Worker helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties) 2294*da0073e9SAndroid Build Coastguard Worker helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties) 2295*da0073e9SAndroid Build Coastguard Worker 
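
    # Editor's sketch, not part of the original suite: the `test_ties` cases
    # above feed a constant input, so every pooling window is tied and the
    # returned indices only match if MPS resolves ties the same way as CPU.
    def test_max_pool2d_tie_indices_sketch(self):
        pool = torch.nn.MaxPool2d(kernel_size=2, return_indices=True)
        cpu_x = torch.ones(1, 1, 4, 4)  # every 2x2 window is an all-ones tie
        _, cpu_idx = pool(cpu_x)
        _, idx = pool(cpu_x.to('mps'))
        self.assertEqual(idx, cpu_idx)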

    def test_adaptive_avg_pool2d_output_size_one(self):
        def helper(size, memory_format):
            x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
            if memory_format == 'non_contiguous':
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()    # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        helper((2, 3, 6, 6), torch.contiguous_format)
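
    # Editor's sketch, not part of the original suite: adaptive average pooling
    # to (1, 1) is just a global mean over the spatial dims, which is the
    # reference the test above builds by hand.
    def test_adaptive_avg_pool2d_global_mean_sketch(self):
        x = torch.randn(2, 3, 6, 6, device='mps')
        out = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
        self.assertEqual(out, x.mean((-1, -2), keepdim=True))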

    def test_masked_scatter(self):
        def helper(shape):
            x_mps = torch.randn(shape, device="mps")
            x_cpu = x_mps.detach().clone().cpu()

            mask_mps = torch.rand(shape, device="mps") < 0.6
            mask_cpu = mask_mps.detach().clone().cpu()

            y_mps = torch.randn(shape, device="mps")
            y_cpu = y_mps.detach().clone().cpu()

            y_mps.masked_scatter_(mask_mps, x_mps)
            y_cpu.masked_scatter_(mask_cpu, x_cpu)

            self.assertEqual(y_mps, y_cpu)
        helper([2, 5])
        helper([10, 10])
        helper([5, 10, 3])
        helper([10, 5, 10, 3])
        helper([10, 5, 10, 3, 20])
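
    # Editor's sketch, not part of the original suite: masked_scatter_ fills the
    # True positions of the mask with the source's elements taken in row-major
    # order; that ordering is the semantics being compared against CPU above.
    def test_masked_scatter_order_sketch(self):
        y = torch.zeros(4, device="mps")
        mask = torch.tensor([True, False, True, False], device="mps")
        src = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps")
        y.masked_scatter_(mask, src)  # first True gets src[0], second gets src[1]
        self.assertEqual(y, torch.tensor([1.0, 0.0, 2.0, 0.0]))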

    def test_masked_fill(self):
        device = "mps"
        dtype = torch.float32
        mask_dtype = torch.bool
        num_dest = 10

        dst = torch.zeros(num_dest, dtype=dtype, device=device)
        mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
        val = random.random()
        dst2 = torch.zeros(num_dest, dtype=dtype)
        mask_cpu = mask.to("cpu")

        dst.masked_fill_(mask, val)
        for i in range(num_dest):
            if mask_cpu[i]:
                dst2[i] = val
        self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)

    def test_masked_fill__non_contiguous(self):
        shape = (3, 5)

        x_mps = torch.randn(shape, device="mps")
        x_cpu = x_mps.detach().clone().cpu()
        mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
        mask_cpu = mask_mps.detach().clone().cpu()

        x_mps_strided = x_mps.T
        x_cpu_strided = x_cpu.T

        x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
        x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))

        self.assertEqual(x_mps_strided, x_cpu_strided)
        self.assertFalse((x_mps_strided == float("-inf")).any())

    def test_nhwc_operation(self):
        def helper(shape, channels_last=False):
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            # This passes
            self.assertEqual(x, cpu_x)

        helper((2, 2, 2, 2), True)

    # Test forward batch norm
    def test_batch_norm(self):
        def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
                   track_running_stats=True, test_module=False):
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            mean_shape = [shape[1]]
            cpu_running_mean = None
            cpu_running_var = None
            running_mean = None
            running_var = None
            if track_running_stats:
                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
                var_arr = 32 * np.random.random_sample(size=mean_shape)
                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
                running_mean = cpu_running_mean.detach().clone().to('mps')
                running_var = cpu_running_var.detach().clone().to('mps')

            weight = None
            cpu_weight = None
            bias = None
            cpu_bias = None
            if wts:
                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            if not test_module:
                y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                                   weight=weight,
                                                   bias=bias,
                                                   training=training,
                                                   momentum=momentum, eps=eps)
                ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                       weight=cpu_weight,
                                                       bias=cpu_bias,
                                                       training=training,
                                                       momentum=momentum, eps=eps)
            else:
                if len(shape) == 3:
                    batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif len(shape) == 4:
                    batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')
                elif len(shape) == 5:
                    batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='cpu')
                    mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                            eps=eps,
                                                            momentum=momentum,
                                                            affine=wts,
                                                            track_running_stats=track_running_stats,
                                                            device='mps')

                if track_running_stats:
                    batchnorm_op.running_mean = cpu_running_mean
                    batchnorm_op.running_var = cpu_running_var
                    mps_batchnorm_op.running_mean = running_mean
                    mps_batchnorm_op.running_var = running_var
                if wts:
                    batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                    batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                    mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                    mps_batchnorm_op.bias = torch.nn.Parameter(bias)

                ref_y = batchnorm_op(cpu_x)
                y = mps_batchnorm_op(x)

            self.assertEqual(y, ref_y)
            if not test_module:
                self.assertEqual(running_mean, cpu_running_mean)
                self.assertEqual(running_var, cpu_running_var)
            else:
                self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
                self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')
            ref_y.backward(gradient=cpu_grad)
            y.backward(gradient=grad)

            self.assertEqual(x.grad, cpu_x.grad)
            if wts:
                if not test_module:
                    self.assertEqual(weight.grad, cpu_weight.grad)
                    self.assertEqual(bias.grad, cpu_bias.grad)
                else:
                    self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
                    self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)

        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
            for test_module in [False, True]:
                for track_running_stats in [True, False]:
                    for channels_last in [False]:
                        if channels_last and len(shape) != 4:
                            continue
                        # Running stats must be tracked in eval mode
                        if track_running_stats:
                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
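
    # Editor's sketch, not part of the original suite: in eval mode batch_norm
    # is a pure affine map driven by the running statistics,
    #     y = (x - running_mean) / sqrt(running_var + eps) * weight + bias,
    # which is what the training=False cases above compare against CPU.
    def test_batch_norm_eval_formula_sketch(self):
        x = torch.randn(2, 3, 4, 4, device='mps')
        mean = torch.randn(3, device='mps')
        var = torch.rand(3, device='mps') + 0.1
        y = torch.nn.functional.batch_norm(x, mean, var, training=False, eps=1e-5)
        ref = (x - mean.view(1, 3, 1, 1)) / torch.sqrt(var.view(1, 3, 1, 1) + 1e-5)
        self.assertEqual(y, ref)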

    def test_batch_norm_backward(self):
        inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
        x = torch.nn.BatchNorm2d(8).to("mps")
        y = torch.nn.BatchNorm2d(8).to("mps")
        y.weight.requires_grad = False
        y.bias.requires_grad = False
        outputs = y(x(inputs))
        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
        outputs.sum().backward()

    # Regression test for https://github.com/pytorch/pytorch/issues/133520
    def test_batch_norm_slices(self):
        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')

        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
        x_mps = x_cpu.to('mps')

        res_cpu = bn_cpu(x_cpu[5:])
        res_mps = bn_mps(x_mps[5:])

        self.assertEqual(res_cpu, res_mps)

    def test_layer_norm_backward(self):
        inputs = torch.rand(4, 4, device="mps", requires_grad=True)
        x = torch.nn.LayerNorm(4).to("mps")
        y = torch.nn.LayerNorm(4).to("mps")
        y.weight.requires_grad = False
        y.bias.requires_grad = False
        outputs = y(x(inputs))
        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
        outputs.sum().backward()

    def test_norm(self):
        a = torch.arange(9, dtype=torch.float, device="mps") - 4
        b = a.reshape((3, 3))

        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
        b_cpu = a_cpu.reshape((3, 3))

        res = torch.norm(a)
        res_cpu = torch.norm(a_cpu)
        self.assertEqual(res, res_cpu)

        res = torch.norm(b)
        res_cpu = torch.norm(b_cpu)
        self.assertEqual(res, res_cpu)

        res = torch.norm(a, float('inf'))
        res_cpu = torch.norm(a_cpu, float('inf'))
        self.assertEqual(res, res_cpu)

        res = torch.norm(b, float('inf'))
        res_cpu = torch.norm(b_cpu, float('inf'))
        self.assertEqual(res, res_cpu)

        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu")

        res = torch.norm(c, dim=0)
        res_cpu = torch.norm(c_cpu, dim=0)
        self.assertEqual(res, res_cpu)

        res = torch.norm(c, dim=1)
        res_cpu = torch.norm(c_cpu, dim=1)
        self.assertEqual(res, res_cpu)

        res = torch.norm(c, p=1, dim=1)
        res_cpu = torch.norm(c_cpu, p=1, dim=1)
        self.assertEqual(res, res_cpu)

        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)

        res = torch.norm(d, dim=(1, 2))
        res_cpu = torch.norm(d_cpu, dim=(1, 2))
        self.assertEqual(res, res_cpu)

        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
        self.assertEqual(res, res_cpu)
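
    # Editor's sketch, not part of the original suite: torch.norm defaults to
    # the Frobenius (2-)norm and ord=inf is max(abs(x)), the two identities the
    # comparisons above exercise.
    def test_norm_identities_sketch(self):
        a = torch.arange(9, dtype=torch.float, device="mps") - 4
        self.assertEqual(torch.norm(a), a.pow(2).sum().sqrt())
        self.assertEqual(torch.norm(a, float('inf')), a.abs().max())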

    def test_linalg_vector_norm(self):
        x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
        x_cpu = x_mps.detach().clone().cpu()

        res_mps = torch.linalg.vector_norm(x_mps, ord=0)
        res_cpu = torch.linalg.vector_norm(x_cpu, ord=0)
        self.assertEqual(res_mps, res_cpu)

        a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
        a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4

        B_mps = a_mps.reshape(3, 3, 3)
        B_cpu = a_cpu.reshape(3, 3, 3)

        res_mps = torch.linalg.vector_norm(a_mps, ord=3.5)
        res_cpu = torch.linalg.vector_norm(a_cpu, ord=3.5)
        self.assertEqual(res_mps, res_cpu)

        res_mps = torch.linalg.vector_norm(B_mps, ord=3.5)
        res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5)
        self.assertEqual(res_mps, res_cpu)

        for dim in range(B_mps.dim()):
            res_mps = torch.linalg.vector_norm(B_mps, ord=3.5, dim=dim)
            res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5, dim=dim)
            self.assertEqual(res_mps, res_cpu)
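
    # Editor's sketch, not part of the original suite: ord=0 counts the
    # non-zero entries, so the [0, 0, 0, 2, 3] input above must map to 2.
    def test_vector_norm_ord_zero_sketch(self):
        x = torch.tensor([0.0, 0.0, 0.0, 2.0, 3.0], device="mps")
        self.assertEqual(torch.linalg.vector_norm(x, ord=0), torch.tensor(2.0))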

    def test_layer_norm(self):
        # TODO: Test non-contiguous
        def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
            mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
            cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            if elementwise_affine:
                cpu_op.weight = torch.nn.Parameter(cpu_wt)
                mps_op.weight = torch.nn.Parameter(wt)
                cpu_op.bias = torch.nn.Parameter(cpu_bias)
                mps_op.bias = torch.nn.Parameter(bias)

            cpu_result = cpu_op(cpu_x)
            result = mps_op(x)

            cpu_grad = torch.randn(cpu_result.shape)
            grad = cpu_grad.to('mps')

            cpu_result.backward(cpu_grad)
            result.backward(grad)

            self.assertEqual(result, cpu_result)
            self.assertEqual(x.grad, cpu_x.grad)
            if elementwise_affine:
                self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
                self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)

        for elementwise_affine in [True, False]:
            helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine)
            helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
            helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)

        # Regression test for https://github.com/pytorch/pytorch/issues/96113
        torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
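
    # Editor's sketch, not part of the original suite: LayerNorm normalizes
    # over the trailing `normalized_shape` dims; without affine parameters it
    # reduces to (x - mean) / sqrt(var + eps) with a biased variance estimate.
    def test_layer_norm_formula_sketch(self):
        x = torch.randn(2, 3, 4, 5, device='mps')
        y = torch.nn.functional.layer_norm(x, (4, 5), eps=1e-5)
        mu = x.mean(dim=(-2, -1), keepdim=True)
        var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)
        self.assertEqual(y, (x - mu) / torch.sqrt(var + 1e-5))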

    @xfailIf(product_version < 14.0)
    def test_ifft(self):
        # See: https://github.com/pytorch/pytorch/issues/124096
        device = torch.device("mps")

        N = 64
        signal = torch.rand(N, device=device)
        fft_result = torch.fft.rfft(signal)
        ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0])

        # Expecting the inverse transform to yield the original signal
        self.assertEqual(ifft_result, signal)

    # Regression test for https://github.com/pytorch/pytorch/issues/135223
    def test_fftfreq(self):
        freq_cpu = torch.fft.fftfreq(10**4, device='cpu')
        freq_mps = torch.fft.fftfreq(10**4, device='mps')
        self.assertEqual(freq_cpu, freq_mps)

    def test_instance_norm(self):
        def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
            np.random.seed(332)
            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            mean_shape = [shape[1]]
            cpu_running_mean = None
            cpu_running_var = None
            running_mean = None
            running_var = None
            if track_running_stats:
                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
                var_arr = 32 * np.random.random_sample(size=mean_shape)
                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
                running_mean = cpu_running_mean.detach().clone().to('mps')
                running_var = cpu_running_var.detach().clone().to('mps')

            weight = None
            cpu_weight = None
            bias = None
            cpu_bias = None
            if wts:
                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            if not test_module:
                ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                          weight=cpu_weight,
                                                          bias=cpu_bias,
                                                          momentum=momentum, eps=eps)
                y = torch.nn.functional.instance_norm(x, running_mean, running_var,
                                                      weight=weight,
                                                      bias=bias,
                                                      momentum=momentum, eps=eps)
            else:
                if len(shape) == 3:
                    instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='cpu')
                    mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
                                                                  eps=eps,
                                                                  momentum=momentum,
                                                                  affine=wts,
                                                                  track_running_stats=track_running_stats,
                                                                  device='mps')
                elif len(shape) == 4:
                    instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='cpu')
                    mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
                                                                  eps=eps,
                                                                  momentum=momentum,
                                                                  affine=wts,
                                                                  track_running_stats=track_running_stats,
                                                                  device='mps')
                elif len(shape) == 5:
                    instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
                                                              eps=eps,
                                                              momentum=momentum,
                                                              affine=wts,
                                                              track_running_stats=track_running_stats,
                                                              device='cpu')
                    mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
                                                                  eps=eps,
                                                                  momentum=momentum,
                                                                  affine=wts,
                                                                  track_running_stats=track_running_stats,
                                                                  device='mps')

                if track_running_stats:
                    instancenorm_op.running_mean = cpu_running_mean
                    instancenorm_op.running_var = cpu_running_var
                    mps_instancenorm_op.running_mean = running_mean
                    mps_instancenorm_op.running_var = running_var
                if wts:
                    instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
                    instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
                    mps_instancenorm_op.weight = torch.nn.Parameter(weight)
                    mps_instancenorm_op.bias = torch.nn.Parameter(bias)

                ref_y = instancenorm_op(cpu_x)
                y = mps_instancenorm_op(x)

            self.assertEqual(y, ref_y)
            if not test_module:
                self.assertEqual(running_mean, cpu_running_mean)
                self.assertEqual(running_var, cpu_running_var)
            else:
                self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
                self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')
            ref_y.backward(gradient=cpu_grad)
            y.backward(gradient=grad)

            self.assertEqual(x.grad, cpu_x.grad)
            if wts:
                if not test_module:
                    self.assertEqual(weight.grad, cpu_weight.grad)
                    self.assertEqual(bias.grad, cpu_bias.grad)
                else:
                    self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
                    self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)

        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
            for test_module in [False, True]:
                for track_running_stats in [True, False]:
                    for channels_last in [False]:
                        if channels_last and len(shape) != 4:
                            continue
                        # Running stats must be tracked in eval mode
                        if track_running_stats:
                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                            helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
                                   track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
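
    # Editor's sketch, not part of the original suite: with input statistics
    # (the default), instance_norm standardizes every (sample, channel) slice
    # over its spatial dims, which is the behaviour compared against CPU above.
    def test_instance_norm_formula_sketch(self):
        x = torch.randn(2, 3, 8, 8, device='mps')
        y = torch.nn.functional.instance_norm(x, eps=1e-5)
        mu = x.mean(dim=(-2, -1), keepdim=True)
        var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)
        self.assertEqual(y, (x - mu) / torch.sqrt(var + 1e-5))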

    def test_weight_norm(self):
        def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
            cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
            norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)

            cpu_out = cpu_norm(cpu_x)
            out = norm(x)

            self.assertEqual(cpu_out, out)

            cpu_grad = torch.randn(cpu_out.shape)
            grad = cpu_grad.to('mps')
            cpu_out.backward(gradient=cpu_grad)
            out.backward(gradient=grad)

            self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad)
            self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad)

            self.assertEqual(x.grad, cpu_x.grad)

        def helper(dim, layer='linear', dtype=torch.float32):
            # linear layer
            if layer == 'linear':
                cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

                cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
                weight = cpu_weight.detach().clone().to('mps').requires_grad_()

                cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

                cpu_linear = torch.nn.Linear(5, 10, device='cpu')
                linear = torch.nn.Linear(5, 10, device='mps')

                with torch.no_grad():
                    cpu_linear.weight.copy_(cpu_weight)
                    cpu_linear.bias.copy_(cpu_bias)
                    linear.weight.copy_(weight)
                    linear.bias.copy_(bias)
                validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)

            # conv layer
            if layer == 'conv':
                cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

                cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
                conv = torch.nn.Conv2d(3, 3, 3, device='mps')

                with torch.no_grad():
                    conv.weight.copy_(cpu_conv.weight)
                    conv.bias.copy_(cpu_conv.bias)

                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

            # conv3d layer
            if layer == 'conv3d':
                cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

                cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
                conv = torch.nn.Conv3d(3, 3, 3, device='mps')

                with torch.no_grad():
                    conv.weight.copy_(cpu_conv.weight)
                    conv.bias.copy_(cpu_conv.bias)

                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

        helper(0, layer='linear')
        helper(1, layer='linear')
        helper(-1, layer='linear')

        helper(0, layer='conv')
        helper(1, layer='conv')
        helper(2, layer='conv')
        helper(3, layer='conv')
        helper(-1, layer='conv')

        if product_version >= 13.2:
            # Conv3d is only available from MacOS 13 onwards
            helper(0, layer='conv3d')
            helper(1, layer='conv3d')
            helper(2, layer='conv3d')
            helper(3, layer='conv3d')
            helper(4, layer='conv3d')
            helper(-1, layer='conv3d')
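
    # Editor's sketch, not part of the original suite: weight_norm
    # re-parameterizes the weight as w = g * v / ||v|| along `dim`; in the
    # parametrization checked above, original0 is the magnitude g and
    # original1 the direction v (an assumption based on the registration
    # order of the parametrization).
    def test_weight_norm_decomposition_sketch(self):
        linear = torch.nn.utils.parametrizations.weight_norm(
            torch.nn.Linear(5, 10, device='mps'), dim=0)
        g = linear.parametrizations.weight.original0
        v = linear.parametrizations.weight.original1
        self.assertEqual(linear.weight, g * v / v.norm(dim=1, keepdim=True))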
    # Test conv2d
    def test_conv2d_unit(self):
        def helper(input_shape, wt_shape,
                   stride=1, padding=0,
                   dilation=1, groups=1,
                   bias_shape=None):

            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
            wt = cpu_wt.detach().clone().to('mps').requires_grad_()

            cpu_bias = None
            bias = None

            if bias_shape is not None:
                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride,
                                           padding=padding, dilation=dilation, groups=groups)
            ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride,
                                               padding=padding, dilation=dilation, groups=groups)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
            if bias_shape is not None:
                self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

        N = 1
        C_in = 3
        C_out = 64
        H = 64
        W = 64
        kH = 4
        kW = 4
        stride = 2
        padding = 1

        helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding)

        N = 4
        C_in = 16
        H = 32
        W = 32

        C_out = 8
        kH = 3
        kW = 3

        # Each configuration is run twice back to back, presumably so the second
        # invocation also exercises the cached-graph path.
        for groups in [1, 2, 4]:
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)

            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)

            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)

            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
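    # For a grouped conv2d the weight tensor has shape (C_out, C_in // groups, kH, kW):
    # each of the `groups` groups convolves C_in // groups input channels into
    # C_out // groups output channels, which is why the helpers above divide C_in by
    # `groups` when building wt_shape.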
    # Test conv transpose 2d
    def test_conv_transpose2d(self):
        def helper(input_shape, wt_shape,
                   stride=1, padding=0,
                   output_padding=0,
                   dilation=1, groups=1,
                   bias_shape=None):

            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
            wt = cpu_wt.detach().clone().to('mps').requires_grad_()

            cpu_bias = None
            bias = None

            if bias_shape is not None:
                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
                bias = cpu_bias.detach().clone().to('mps').requires_grad_()

            y = torch.nn.functional.conv_transpose2d(
                x, wt, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
            ref_y = torch.nn.functional.conv_transpose2d(
                cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding,
                output_padding=output_padding, groups=groups, dilation=dilation)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)

            # bias gradients are not checked here yet:
            # if bias_shape is not None:
            #     self.assertEqual(bias.grad, cpu_bias.grad)

        N = 4
        C_in = 2
        H = 32
        W = 32

        C_out = 8
        groups = 1
        kH = 3
        kW = 3

        for stride in [1, 2, 3]:
            for padding in [0, 1, 2]:
                for output_padding in [0, 1, 2]:
                    for dilation in [1, 2]:
                        # conv_transpose2d requires output_padding < max(stride, dilation)
                        if output_padding >= stride or output_padding >= dilation:
                            continue
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)

                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)
                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                               padding=padding, output_padding=output_padding, dilation=dilation)
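    # For reference, conv_transpose2d produces
    #   H_out = (H_in - 1) * stride - 2 * padding + dilation * (kH - 1) + output_padding + 1
    # so `output_padding` only disambiguates the output size and must stay smaller
    # than either `stride` or `dilation`.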
    # Test sigmoid
    def test_sigmoid(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sigmoid_op = torch.nn.Sigmoid()

            y = sigmoid_op(x)
            ref_y = sigmoid_op(cpu_x)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5))
        helper((2, 3, 4))
        helper((2, 8, 4, 5))

    # Test tanh
    def test_tanh(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            tanh_op = torch.nn.Tanh()

            y = tanh_op(x)
            ref_y = tanh_op(cpu_x)

            cpu_grad = torch.ones_like(ref_y)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5))
        helper((2, 3, 4))
        helper((2, 8, 4, 5))
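    # nn.Threshold(threshold, value) computes y = x if x > threshold else value,
    # e.g. nn.Threshold(0.1, 20)(torch.tensor([0.05, 0.5])) -> tensor([20.0, 0.5]).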
    def test_threshold(self):
        def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
            m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

            input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

            output_cpu = m(input_cpu)
            output_mps = m(input_mps)

            cpu_grad = torch.ones_like(output_cpu)
            mps_grad = cpu_grad.to('mps')

            self.assertEqual(output_cpu, output_mps)

            if requires_grad:
                output_cpu.backward(gradient=cpu_grad)
                output_mps.backward(gradient=mps_grad)

                self.assertEqual(input_cpu.grad, input_mps.grad)

        helper(threshold=0.1, value=20, num_elems=2)
        helper(threshold=-0.1, value=10, num_elems=10)
        helper(threshold=0.5, value=-15, num_elems=100)
        helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)

    # Test pow
    def test_pow(self):
        def helper(shape):
            # aten::pow.Tensor_Tensor
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')
            z = torch.pow(x, y)
            ref_z = torch.pow(cpu_x, cpu_y)

            self.assertEqual(z, ref_z)

            # aten::pow.Tensor_Scalar
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')
            exp = random.random()
            z = torch.pow(x, exp)
            ref_z = torch.pow(cpu_x, exp)

            self.assertEqual(z, ref_z)

            # aten::pow.Scalar
            x = random.random()
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')
            z = torch.pow(x, y)
            ref_z = torch.pow(x, cpu_y)

            self.assertEqual(z, ref_z)

        helper((2, 8, 4, 5))
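    # torch.addcmul computes out = input + value * tensor1 * tensor2 (elementwise),
    # so the helpers below only need to compare the fused op against the CPU result.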
    # Test addcmul
    def test_addcmul(self):
        def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
            def rand_helper(dtype):
                if dtype.is_floating_point:
                    return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)

            cpu_x = rand_helper(xtype)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = rand_helper(ytype if ytype is not None else xtype)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = rand_helper(ztype if ztype is not None else xtype)
            z = cpu_z.detach().clone().to('mps')

            y = torch.addcmul(x, y, z, value=value)
            ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value)

            self.assertEqual(y, ref_y)

        helper((2, 3, 4, 5), 0.1)
        helper((2, 8, 4, 5), 0.1)
        helper((2, 3, 4, 5), 0.2)
        helper((2, 8, 4, 5), 0.2)
        # Integral types
        helper((2, 2), 1.0, xtype=torch.int32)
        helper((2, 2), 2.0, xtype=torch.int16)

        # Mixed types
        helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
        helper((3, 2), 1.0, ytype=torch.float16)
        helper((2, 3), 1.0, ztype=torch.float16)
        helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
        helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
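    # torch.addcdiv computes out = input + value * tensor1 / tensor2 (elementwise);
    # the divisor is clamped away from zero below to keep the comparison stable.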
    # Test addcdiv
    def test_addcdiv(self):
        def helper(shape, value):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            # clamp to avoid division by 0
            cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
            cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            mps_z = cpu_z.detach().clone().to('mps')
            mps_out = cpu_out.detach().clone().to('mps')

            result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
            result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
            self.assertEqual(result_div_mps, result_div_cpu)
            # test .out variant
            self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu)

        helper((2, 3, 4, 5), 0.1)
        helper((2, 8, 4, 5), 0.2)
        helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally

    def test_addcdiv_transpose(self):
        # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
        # Tests all contiguous/transposed combinations of the three input tensors
        def helper(shape, value):
            shape_t = shape[::-1]
            for i in range(2):
                for j in range(2):
                    for k in range(2):
                        x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t()
                        y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t()
                        z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t()

                        x_mps = x.detach().clone().to(device="mps")
                        y_mps = y.detach().clone().to(device="mps")
                        z_mps = z.detach().clone().to(device="mps")

                        result_cpu = x.addcdiv_(y, z, value=value)
                        result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)
                        result_mps_out = result_cpu.detach().clone().to('mps')
                        torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)

                        self.assertEqual(result_cpu, result_mps)
                        self.assertEqual(result_cpu, result_mps_out)

        helper((2, 3), 1.0)
        helper((2, 3), 0.2)
        helper((100, 300), 1.0)
        helper((100, 300), 0.2)
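    # torch.matmul broadcasts a 1-D left operand against a batched 3-D right operand:
    # with A of shape (16,) and F of shape (16, 16, 16), A @ F treats A as (1, 16),
    # batch-multiplies, and returns shape (16, 16). The next test only checks that
    # this doesn't crash the MPS buffer-size validation and matches the CPU result.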
    def test_buffer_size_match(self):
        # this test shouldn't cause any crash
        size = 16
        cpu_A = torch.rand(size, device='cpu')
        cpu_F = torch.rand(size, size, size, device='cpu')

        mps_A = cpu_A.to('mps')
        mps_F = cpu_F.to('mps')
        self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F)

    def test_transpose_inplace(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_x.transpose_(0, 1)
        mps_x.transpose_(0, 1)
        self.assertEqual(cpu_x, mps_x.to('cpu'))

    def test_expand_cpu_to_mps_copy(self):
        # https://github.com/pytorch/pytorch/issues/78642

        x = torch.tensor(1).expand([10]).to("mps")
        x_cpu = torch.tensor(1).expand([10])

        self.assertEqual(x_cpu, x.cpu())

    def test_cpu_to_strided_mps_copy(self):
        # https://github.com/pytorch/pytorch/issues/86975

        a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
        b1 = torch.Tensor([-1, -1])
        a1[1:, 1] = b1

        a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
        b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
        a2[1:, 1] = b2

        self.assertEqual(a1, a2)
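    # The view/slice tests that follow target how MPS materializes non-contiguous
    # tensors: slicing yields views with non-zero storage offsets and odd strides,
    # which have historically triggered device-specific copy bugs (see the issues
    # referenced in the individual tests).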
    def test_view_slice_reshape(self):
        x = torch.randn([1, 4, 4], device="mps")
        y = x[0, :1, 1:]

        x_cpu = x.to("cpu")
        y_cpu = x_cpu[0, :1, 1:]

        r = y + 1
        r_cpu = y_cpu + 1
        self.assertEqual(r, r_cpu)

    def test_slice_reshape(self):
        x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
        x_cpu = x.detach().clone().to("cpu")

        x = x[:, 3:].view(2, 3, 4, 1)
        x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1)
        self.assertEqual(x, x_cpu)

        x = x + 2
        x_cpu = x_cpu + 2
        self.assertEqual(x, x_cpu)

    def test_reshape_storage_offset(self):
        # https://github.com/pytorch/pytorch/issues/95883
        B = 4
        T = 1

        lin_cpu = nn.Linear(10, 256)
        lin_mps = nn.Linear(10, 256, device="mps")

        # Use the same weights and bias as the ones from the cpu
        lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
        x_cpu = x_mps.detach().clone().cpu().requires_grad_()
        x_mps = lin_mps(x_mps)
        x_cpu = lin_cpu(x_cpu)

        self.assertEqual(x_mps.shape, (B, T, 256))
        self.assertEqual(x_cpu.shape, (B, T, 256))

        cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
        cls_token_cpu = cls_token_mps.detach().clone().cpu()
        x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
        x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

        x_mps = x_mps.transpose(0, 1)
        x_cpu = x_cpu.transpose(0, 1)

        target_mps = torch.rand_like(x_mps)
        target_cpu = target_mps.detach().clone().cpu()
        loss_mps = F.mse_loss(x_mps, target_mps)
        loss_cpu = F.mse_loss(x_cpu, target_cpu)
        self.assertEqual(loss_mps, loss_cpu)

        loss_mps.backward()
        loss_cpu.backward()
        self.assertEqual(x_mps.grad, x_cpu.grad)

    def test_stack_storage_offset(self):
        # https://github.com/pytorch/pytorch/issues/87856
        x_cpu = torch.tensor([[1, 2]])
        x_mps = x_cpu.detach().clone().to("mps")

        y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
        y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)

        self.assertEqual(y_cpu, y_mps)

        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
        t_cpu = t_mps.detach().cpu()

        x_mps = t_mps[2:]
        y_mps = t_mps[:2]

        x_cpu = t_cpu[2:]
        y_cpu = t_cpu[:2]

        res_mps = torch.stack((y_mps, x_mps), dim=-1)
        res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)

        self.assertEqual(res_mps, res_cpu)

    def test_unsafe_chunk(self):
        # https://github.com/pytorch/pytorch/issues/91065
        a = torch.rand(5, dtype=torch.float32, device="cpu")
        ret = a.unsafe_chunk(4, 0)
        y = ret[0] * ret[2]
        a_mps = a.to("mps")
        ret_mps = a_mps.unsafe_chunk(4, 0)
        y_mps = ret_mps[0] * ret_mps[2]
        self.assertEqual(y, y_mps)
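    # The next test exercises copy-with-cast (uint8 -> bool) on slices that have a
    # non-zero storage offset, which may take a different path on the MPS backend
    # than a plain contiguous copy.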
    def test_slice_casting(self):
        # generate random binary numbers
        cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
        mps_in = cpu_in.detach().clone().to("mps")
        # check copy_cast(uint8 -> bool) on tensors with storage offset
        cpu_out = cpu_in[:, :, 11:12, :12].to(torch.bool)
        mps_out = mps_in[:, :, 11:12, :12].to(torch.bool)
        self.assertEqual(cpu_out, mps_out)

    def test_slice_reshape_contg_view(self):
        x_mps = torch.randn(1, 4800, 2, device="mps")
        x_cpu = x_mps.detach().clone().cpu()

        r_mps = x_mps + 2
        r_cpu = x_cpu + 2

        self.assertEqual(r_mps, r_cpu)

    def test_contiguous_slice_2d(self):
        def helper(shape):
            for i in range(0, shape[0]):
                for j in range(0, shape[1]):
                    t_mps = torch.randn(shape, device="mps")
                    t_cpu = t_mps.detach().clone().cpu()

                    y_mps = t_mps[i:, :j]
                    y_cpu = t_cpu[i:, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    y_mps = t_mps[i:, j]
                    y_cpu = t_cpu[i:, j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    y_mps = t_mps[i, :j]
                    y_cpu = t_cpu[i, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    y_mps = t_mps[:i, :j]
                    y_cpu = t_cpu[:i, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    y_mps = t_mps[:i, j]
                    y_cpu = t_cpu[:i, j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    y_mps = t_mps[:i, j:]
                    y_cpu = t_cpu[:i, j:]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

        l = []
        for N in range(1, 3):
            l.append(N)
            for C in range(1, 3):
                l.append(C)
                helper(l)
                for D in range(1, 3):
                    l.append(D)
                    helper(l)
                    for H in range(1, 3):
                        l.append(H)
                        helper(l)
                        for W in range(1, 3):
                            l.append(W)
                            helper(l)
                            l.pop()
                        l.pop()
                    l.pop()
                l.pop()
            l.pop()

        helper([9, 15, 4])
        helper([9, 3, 2])
        helper([3, 4, 18, 22])
        helper([3, 4, 18, 22, 150])

    def test_contiguous_slice_3d(self):
        x = torch.randn(2, 3, 3, device="mps")
        x_cpu = x.detach().clone().cpu()
        x = x[:1]
        x_cpu = x_cpu[:1]
        out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
        out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
        self.assertEqual(out, out_cpu)

    def test_view_slice(self):
        # https://github.com/pytorch/pytorch/issues/83995
        NUM_SAMPLES = 60
        s = (0, 1)

        X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
        X_mps = X.detach().clone().to("mps")

        idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
        pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
        idx_mps = idx.to("mps")
        pts_mps = pts.to("mps")
        pts[:, s] = idx
        pts_mps[:, s] = idx_mps

        actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
        actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")

        for i in range(NUM_SAMPLES):
            for j in range(X.shape[1]):
                actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
                actual_pts[i, j] = X[pts[i, j], j]
                self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
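    # copy_ into a strided view of another tensor must not corrupt the source:
    # below, `tensor` is copied into every other column of a wider buffer and is
    # expected to compare equal to a clone taken beforehand.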
    def test_slice_scatter(self):
        shape = (4, 4)
        tensor = torch.randint(10, shape, device="mps")
        tensor_before = tensor.clone()
        torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor)
        torch.testing.assert_close(tensor, tensor_before)

    def test_slice(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps', dtype=torch.float)

        cpu_slice1 = cpu_x[:2, :]
        mps_slice1 = mps_x[:2, :]
        self.assertEqual(cpu_slice1, mps_slice1)

        cpu_slice2 = cpu_x[:, :1]
        mps_slice2 = mps_x[:, :1]
        self.assertEqual(cpu_slice2, mps_slice2)

        cpu_slice3 = cpu_x[1:2, :]
        mps_slice3 = mps_x[1:2, :]
        self.assertEqual(cpu_slice3, mps_slice3.to('cpu'))

        cpu_slice4 = cpu_x[1, :]
        mps_slice4 = mps_x[1, :].to('cpu')
        self.assertEqual(cpu_slice4, mps_slice4)

    @parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
    def test_slice_view_api(self, torch_type: torch.dtype):
        def helper(x_tensor, y_func, z_func, r_func=None):
            x_mps = x_tensor.detach().clone().to("mps")

            y = y_func(x_tensor)
            y_mps = y_func(x_mps)
            self.assertEqual(y, y_mps)

            z = z_func(y)
            z_mps = z_func(y_mps)
            self.assertEqual(z, z_mps)
            self.assertEqual(z.storage_offset(), z_mps.storage_offset())

            if r_func:
                r = r_func(z)
                r_mps = r_func(z_mps)
                self.assertEqual(r, r_mps)

        # Skip bfloat16 before MacOS15
        if not (product_version < 15.0 and torch_type == torch.bfloat16):
            # Tests for previously encountered MPS bugs
            helper(
                torch.randn(4, 4, dtype=torch_type),
                lambda x: x[1],
                lambda y: y.reshape(2, 2),
                lambda z: z + 1
            )
            helper(
                torch.randn(2, 4, dtype=torch_type),
                lambda x: x[1],
                lambda y: y + torch.ones(4, device=y.device)
            )
            helper(
                torch.randn(4, 6, dtype=torch_type),
                lambda x: x[1],
                lambda y: y.reshape(3, 2).t(),
                lambda z: z + 1
            )
            helper(
                torch.arange(4, dtype=torch_type).resize(1, 2, 2),
                lambda x: x.permute(2, 0, 1),
                lambda y: y + 1
            )
            helper(
                torch.randn(4, 8, dtype=torch_type),
                lambda x: x.transpose(0, 1).reshape(-1),
                lambda y: y[:2],
                lambda z: z + 1
            )
            helper(
                torch.randn(1, dtype=torch_type),
                lambda x: x.expand(2, 3),
                lambda y: y + torch.ones(2, 3, device=y.device)
            )

    def test_slice_reshape_contiguous(self):
        x = torch.randn(4, 4)
        x_mps = x.detach().clone().to("mps")

        y = x[1]
        y_mps = x_mps[1]
        self.assertEqual(y, y_mps)

        z = y.reshape(2, 2)
        z_mps = y_mps.reshape(2, 2)
        self.assertEqual(z, z_mps)
        self.assertEqual(z.storage_offset(), z_mps.storage_offset())
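    # Iterating over a tensor yields 0-dim views that share storage with the parent
    # (each with its own storage offset), so the "scalar" tests below still go
    # through MPS view handling rather than plain Python numbers.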
    def test_scalar_from_slice_unary(self):
        # https://github.com/pytorch/pytorch/issues/82543
        tensor_list = torch.tensor([1.0, 1.2], device="mps")

        for scalar in tensor_list:
            r_mps = torch.ceil(scalar)
            r_cpu = torch.ceil(scalar.to("cpu"))
            self.assertEqual(r_mps.cpu(), r_cpu)

    def test_scalar_from_slice_binary(self):
        # https://github.com/pytorch/pytorch/issues/82543
        def helper(binary_op):
            tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")

            for scalar in tensor_list:
                r_mps = binary_op(scalar, 1.0)
                r_cpu = binary_op(scalar.cpu(), 1.0)
                self.assertEqual(r_mps.cpu(), r_cpu)

        helper(torch.sub)
        helper(torch.add)
        helper(torch.not_equal)
        helper(torch.eq)

    def test_slice_contiguous_view(self):
        # https://github.com/pytorch/pytorch/issues/77750

        def helper(operator):
            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

            # contiguous view
            x_mps = t_mps[2:]  # 3, 4
            y_mps = t_mps[:2]  # 1, 2

            x_cpu = t_cpu[2:]
            y_cpu = t_cpu[:2]

            res_mps = res_cpu = None
            if operator == "<=":
                res_mps = x_mps <= y_mps
                res_cpu = x_cpu <= y_cpu
            elif operator == "<":
                res_mps = x_mps < y_mps
                res_cpu = x_cpu < y_cpu
            elif operator == ">=":
                res_mps = x_mps >= y_mps
                res_cpu = x_cpu >= y_cpu
            elif operator == ">":
                res_mps = x_mps > y_mps
                res_cpu = x_cpu > y_cpu
            elif operator == "==":
                res_mps = x_mps == y_mps
                res_cpu = x_cpu == y_cpu
            elif operator == "!=":
                res_mps = x_mps != y_mps
                res_cpu = x_cpu != y_cpu
            elif operator == "stack":
                res_mps = torch.stack((y_mps, x_mps), dim=-1)
                res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)

            self.assertEqual(res_mps, res_cpu)

        for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
            helper(op)

    def test_slice_of_slice(self):
        x = torch.tensor([0.5, 0.5], device="cpu")
        x_mps = torch.tensor([0.5, 0.5], device="mps")

        tensor = x[1][None]
        tensor_mps = x_mps[1][None]

        res = tensor.ne(0)
        res_mps = tensor_mps.ne(0)

        self.assertEqual(res, res_mps)

    def test_index_storage_offset(self):
        # https://github.com/pytorch/pytorch/issues/78107

        a = torch.tensor([8.2670e-01, -1.0293e+00])
        b_cpu = a[0]
        c_cpu = a[1]

        # both 'b' and 'c' are views of 'a'
        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
        # when copying from 'cpu' to 'mps', c's storage offset of 1 needs to be
        # taken into account, otherwise it ends up with the same value as 'b'
        b = b_cpu.to('mps')
        c = c_cpu.to('mps')

        res_mps = b > c
        res_cpu = b_cpu > c_cpu
        self.assertEqual(res_mps, res_cpu)

        res_mps = c > b
        res_cpu = c_cpu > b_cpu
        self.assertEqual(res_mps, res_cpu)
    def test_flatten(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_flatten1 = cpu_x.flatten()
        mps_flatten1 = mps_x.flatten().to('cpu')
        self.assertEqual(cpu_flatten1, mps_flatten1)

        cpu_flatten2 = cpu_x.flatten(start_dim=1)
        mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu')
        self.assertEqual(cpu_flatten2, mps_flatten2)

        cpu_flatten3 = cpu_x.flatten(end_dim=1)
        mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu')
        self.assertEqual(cpu_flatten3, mps_flatten3)

    # Test repeat
    def test_repeat(self):
        def helper(shape, repeats):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = x.repeat(repeats)
            ref_y = cpu_x.repeat(repeats)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5), (2, 3, 4, 5))
        helper((2, 3, 4), (4, 3, 2, 5, 7, 2))
        helper((3, 4, 5), (2, 3, 4, 5))
        helper((3, 4, 5), (2, 2, 2))
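    # repeat tiles the whole tensor, while repeat_interleave repeats individual
    # elements: torch.tensor([1, 2]).repeat(2) -> [1, 2, 1, 2], but
    # torch.tensor([1, 2]).repeat_interleave(2) -> [1, 1, 2, 2].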
    def test_torch_repeat_interleave(self, device="mps"):
        y = torch.tensor([[1, 2], [3, 4]], device=device)
        # exercise single argument function signature
        temp = y.repeat_interleave(2)
        self.assertEqual(torch.Size([8]), temp.size())

        for dtype in [torch.int, torch.long]:
            lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
            output_size = torch.sum(lengths)
            a = torch.repeat_interleave(
                y,
                lengths,
                dim=0,
            )
            self.assertEqual(a.dtype, y.dtype)
            self.assertEqual(a.size(), torch.Size([3, 2]))

            a_with_output = torch.repeat_interleave(
                y,
                lengths,
                dim=0,
                output_size=output_size,
            )
            self.assertEqual(a_with_output.dtype, y.dtype)
            self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
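    # scalar, tensor and per-row repeats, plus the argument-validation error cases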
    def test_repeat_interleave(self, device="mps"):
        x = torch.tensor([0, 1, 2, 3], device=device)
        expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
        # Prior to macOS 13.3, input of dtype=torch.int64 returns dtype=torch.int32
        self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=product_version >= 13.3)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.arange(4.0, device=device))

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))

        y = torch.tensor([[1, 2], [3, 4]], device=device)

        y1_v1 = torch.repeat_interleave(y, 2)
        y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
        y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
        self.assertEqual(y1_v1, y1_expect)
        self.assertEqual(y1_v2, y1_expect)
        self.assertEqual(y1_v3, y1_expect)

        y2 = torch.repeat_interleave(y, 3, dim=1)
        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
                                  [3, 3, 3, 4, 4, 4]], device=device)
        self.assertEqual(y2, y2_expect)

        y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
        y3_expect = torch.tensor([[1, 2],
                                  [3, 4],
                                  [3, 4]], device=device)
        self.assertEqual(y3, y3_expect)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)

        # test zero sized dimension
        x = torch.zeros((5, 0), device=device)
        y = torch.repeat_interleave(x, repeats=3, dim=1)
        self.assertEqual(y, x.new_zeros(5, 0, device=device))

        x = torch.tensor([], dtype=torch.int64, device=device)
        y = torch.repeat_interleave(x, x)
        self.assertEqual(y, x)

    def test_repeat_interleave_simple(self):
        def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
            x = torch.randn(shape, dtype=dtype, device="mps")
            x_cpu = x.detach().clone().cpu()

            num_repeats_cpu = num_repeats.detach().clone().cpu()

            repeats = torch.repeat_interleave(x, num_repeats, dim)
            repeats_cpu = torch.repeat_interleave(x_cpu, num_repeats_cpu, dim)

            self.assertEqual(repeats, repeats_cpu)
        helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
        helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
        helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
        helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
        helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
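    # count_nonzero should match the CPU over all elements, a single dim, and a tuple of dims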
    def test_count_nonzero(self):
        def helper(dtype):
            n = [
                [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
                [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
            ]
            cpu_x = torch.tensor(n, dtype=dtype)
            mps_x = torch.tensor(n, dtype=dtype).to('mps')

            # All non-zeros
            self.assertEqual(
                torch.count_nonzero(cpu_x),
                torch.count_nonzero(mps_x)
            )

            # dim=1
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=1),
                torch.count_nonzero(mps_x, dim=1)
            )

            # dim=(0, 1)
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=(0, 1)),
                torch.count_nonzero(mps_x, dim=(0, 1))
            )
        helper(torch.int32)
        helper(torch.int64)
        helper(torch.float16)
        helper(torch.float32)
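    # Shared helper: feeds an empty input through a module and checks that backward yields all-zero gradients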
    def _test_module_empty_input(self, module, inp, check_size=True):
        inp.requires_grad_(True)
        out = module(inp)
        gO = torch.rand_like(out)
        out.backward(gO)
        if check_size:
            self.assertEqual(out.size(), inp.size())
        for p in module.parameters():
            if p.requires_grad:
                self.assertEqual(p.grad, torch.zeros_like(p.grad))
        self.assertEqual(inp.grad, torch.zeros_like(inp))

    # Test dtype casting, with and without simultaneous device change
    def test_to(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
        self.assertEqual(cpu_x.float(), mps_x.float().cpu())

        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                         torch.tensor(1.0))
        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
        # Cast int8 and uint8 to float and compare results
        # See https://github.com/pytorch/pytorch/issues/80009 for more details
        cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
        cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
        for x_cpu in [cpu_byte, cpu_char]:
            x_mps = x_cpu.to('mps')
            self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))

    def test_setitem_scalar(self) -> None:
        device = 'mps'
        for dtype in [torch.int32, torch.float32, torch.int64]:
            for i in range(3, 6):
                for j in range(3, 6):
                    t = torch.zeros(i, j, dtype=dtype, device=device)
                    self.assertEqual(t.sum(), 0)
                    t[1, 1] = 1
                    t[2, 1] = j
                    t[1, 2] = i
                    self.assertEqual(t[1, 1], 1)
                    self.assertEqual(t[1, 2], i)
                    self.assertEqual(t[2, 1], j)
                    self.assertEqual(t.sum(), 1 + i + j)
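    # Takes a strided view of an already-strided MPS tensor, then copies it to the CPU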
    def test_stride_of_strides(self) -> None:
        x = torch.rand(32, 1, device='mps')
        y = x.as_strided(size=(32, 2), stride=(1, 0))
        # Copying a strided view of an already-strided tensor to the CPU used to crash
        # with a "buffer is not large enough." assert
        # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
        z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
        self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)

    def test_type_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.type(to_dtype)
            res_mps = a_mps.type(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)

    def test_to_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.to(to_dtype)
            res_mps = a_mps.to(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
        helper([9.0, 3.0, 5.0, 4.0], torch.float)
        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
        helper([9.0, 3.0, 5.0, 4.0], torch.short)
        helper([9.0, 3.0, 5.0, 4.0], torch.half)
        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)
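    # Each slice below is a view into one large buffer; copying it to MPS must honor the view's storage offset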
    def test_storage_offset_greater_than_src_nbytes(self):
        # https://github.com/pytorch/pytorch/issues/80844
        n_tensors = 100
        n_tensor_elems = 784
        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

        tensor_list = []
        for i in range(0, n_tensors - 1):
            # create a list of contiguous view tensors (view tensor created by the slice op)
            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
            tensor_list.append(t)

        for i in range(0, n_tensors - 1):
            t = tensor_list[i].view(1, n_tensor_elems)
            t_mps = t.to("mps")
            self.assertEqual(t, t_mps.cpu(), f"i={i}")

    # See https://github.com/pytorch/pytorch/issues/82427
    # and https://github.com/pytorch/pytorch/issues/83692
    def test_full_bugs(self):
        # Test should not crash
        x = torch.full((3, 3), True, device='mps')
        # torch.full should work for uint8
        y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
        y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
        self.assertEqual(y_mps, y_cpu)

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    # See https://github.com/pytorch/pytorch/issues/84995
    def test_div_bugs(self):
        for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
            if dtype != torch.int64:
                x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
                y = torch.div(x, 101, rounding_mode=mode)
                self.assertEqual(y.sum(), 0)

    # See https://github.com/pytorch/pytorch/issues/82663
    def test_bool_expand(self):
        x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
        y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
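    # Same check as test_bool_expand, for int8 tensors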
    def test_int_expand(self):
        x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
        y = torch.tensor([0, 1], dtype=torch.int8, device='mps')
        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))

    # A unary op on an empty tensor should return a tensor of the same size
    def test_empty_neg(self):
        x = torch.tensor([[]], device='mps')
        y = -x
        self.assertEqual(x, y)

    def _test_unique_scalar_empty(self, dtype, device, f):
        # test scalar
        x = torch.tensor(0, dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([0], dtype=dtype, device=device)
        expected_inverse = torch.tensor(0, device=device)
        expected_counts = torch.tensor([1], device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

        # test zero sized tensor
        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([], dtype=dtype, device=device)
        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
        expected_counts = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)
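    # Shared helper: checks the structure of the tuple returned by unique() against the
    # expected values, for every combination of return_inverse and return_counts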
    def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
        def ensure_tuple(x):
            if isinstance(x, torch.Tensor):
                return (x,)
            return x

        for return_inverse in [True, False]:
            for return_counts in [True, False]:
                # test with expected
                ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                self.assertEqual(expected_unique, ret[0])
                if return_inverse:
                    self.assertEqual(expected_inverse, ret[1])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    self.assertEqual(expected_counts, ret[count_index])

                # tests per-element unique on a higher rank tensor.
                y = x.view(additional_shape)
                y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
                self.assertEqual(expected_unique, y_unique)
                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
                self.assertEqual(expected_counts, y_counts)

    def test_unique_all_dtypes(self, device="mps"):
        def helper(dtype):
            def ensure_tuple(x):
                if isinstance(x, torch.Tensor):
                    return (x,)
                return x

            if dtype is torch.bool:
                x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
                expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
                expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
                expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
            else:
                x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
                expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
                expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
                expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

            # test sorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
                lambda x, **kwargs: x.unique(sorted=True, **kwargs),
            )
            x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
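            # run the checks on both the contiguous tensor and a strided view of it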
            xs = (x, x_sliced)
            for f, x in product(fs, xs):
                self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
                self._test_unique_scalar_empty(dtype, device, f)

            # test unsorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
                lambda x, **kwargs: x.unique(sorted=False, **kwargs)
            )
            for f, x in product(fs, xs):
                self._test_unique_scalar_empty(dtype, device, f)
                for return_inverse, return_counts in product((True, False), repeat=2):
                    ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                    self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                    x_list = x.tolist()
                    x_unique_list = ret[0].tolist()
                    self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
                    if return_inverse:
                        x_inverse_list = ret[1].tolist()
                        for i, j in enumerate(x_inverse_list):
                            self.assertEqual(x_list[i], x_unique_list[j])
                    if return_counts:
                        count_index = 1 + int(return_inverse)
                        x_counts_list = ret[count_index].tolist()
                        for i, j in zip(x_unique_list, x_counts_list):
                            count = 0
                            for k in x_list:
                                if k == i:
                                    count += 1
                            self.assertEqual(j, count)
        [helper(dtype) for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]]

    def test_unique(self):
        def helper(x, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
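        # cover all return_inverse/return_counts combinations, including 1-element and empty inputs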
        helper(torch.tensor([1, 2, 4, 2, 1]), False, False)
        helper(torch.randint(3, (10, )), False, False)
        helper(torch.randint(3, (10, )), True, False)
        helper(torch.randint(3, (10, )), False, True)
        helper(torch.randint(3, (10, )), True, True)
        helper(torch.randint(3, (1, )), True, True)
        helper(torch.randint(3, (0, )), True, True)
        # Regression test for https://github.com/pytorch/pytorch/issues/104879
        x = torch.arange(2, device="mps")
        self.assertEqual(x.reshape(1, 1, 2).unique(), x)

    def test_unique_consecutive(self):
        def helper(x, dim, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
        helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
        helper(torch.randint(3, (10, )), 0, False, False)
        helper(torch.randint(3, (10, )), 0, True, False)
        helper(torch.randint(3, (10, )), 0, False, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (1, )), 0, True, True)
        helper(torch.randint(3, (0, )), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True)
        helper(torch.randint(2, (20, 2)), 0, True, True)
        helper(torch.randint(2, (1, 2)), 0, True, True)
        helper(torch.randint(2, (0, 2)), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True)
        helper(torch.randint(2, (2, 20)), 1, True, True)
        helper(torch.randint(2, (2, 1)), 1, True, True)
        helper(torch.randint(2, (2, 0)), 1, True, True)
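    # Concatenating non-contiguous (channels_last) slices must match the CPU for every dim and dtype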
    # See https://github.com/pytorch/pytorch/issues/85675
    def test_cat_non_contiguous(self):
        def rotate_subset(data, dim):
            x1 = data[:, :, :2, :]
            x2 = data[:, :, 2:, :]
            self.assertFalse(x1.is_contiguous())
            self.assertFalse(x2.is_contiguous())
            return torch.concat((x1, x2), dim=dim)
        for dtype in MPS_DTYPES:
            if dtype == torch.bool:
                continue
            data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
            data = data.to(memory_format=torch.channels_last)
            mps_data = data.to("mps")
            self.assertEqual(data, mps_data)
            for dim in range(data.dim()):
                cpu_result = rotate_subset(data, dim)
                mps_result = rotate_subset(mps_data, dim)
                self.assertEqual(cpu_result, mps_result.to("cpu"))
                # TODO: enable memory format test
                # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())

    # See https://github.com/pytorch/pytorch/issues/85967
    def test_from_numpy_non_contiguous(self):
        a = np.arange(9).reshape(3, 3)[:, :2]
        t_cpu = torch.tensor(a, device="cpu")
        t_mps = torch.tensor(a, device="mps")
        self.assertEqual(t_cpu, t_mps.to("cpu"))
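    # Copies of permuted and sliced (non-contiguous) tensors between CPU and MPS must preserve element order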
    # See https://github.com/pytorch/pytorch/issues/86954
    def test_copy_non_contiguous(self):
        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
        self.assertFalse(x.is_contiguous())
        y = x.to('mps')
        self.assertFalse(y.is_contiguous())
        self.assertEqual(x, y.to('cpu'))

        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
        y = x.to('mps')
        self.assertEqual(x, y.to('cpu'))

        x = torch.full((4, 4, 4, 4), 13, device="cpu")
        y = torch.full((4, 4, 4, 4), 13, device="mps")
        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
        x.permute(3, 2, 1, 0)[1::, ::2] = z
        # As y is on MPS and z on CPU, this dispatches to a copy operator
        y.permute(3, 2, 1, 0)[1::, ::2] = z
        self.assertEqual(x, y.to('cpu'))

    # See https://github.com/pytorch/pytorch/issues/95417
    def test_copy_storage_offset(self):
        x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
        x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
        update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
        update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
        x_cpu[2:4] = update_cpu
        x_mps[2:4] = update_mps  # implicit type casting and copy
        self.assertEqual(x_cpu, x_mps)

        x_cpu[2:4] = update_mps  # implicit device moving and copy
        self.assertEqual(x_cpu, x_mps)

    def test_copy_broadcasting(self):
        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
            cpu_result = cpu_dst.copy_(cpu_src)
            mps_src = cpu_src.to("mps")
            mps_dst = cpu_dst.to("mps")
            mps_result = mps_dst.copy_(mps_src)
            self.assertEqual(cpu_result, mps_result)

        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]

        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
            helper((2, 1), (2, 3), src_dtype, dst_dtype)
            helper((2, 1), (2, 2), src_dtype, dst_dtype)
            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((3,), (2, 3), src_dtype, dst_dtype)
            helper((2,), (2, 2), src_dtype, dst_dtype)
            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
        # Regression test for https://github.com/pytorch/pytorch/issues/107867
        self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)

    # See https://github.com/pytorch/pytorch/pull/84742
    # and https://github.com/pytorch/pytorch/pull/78319
    def test_binops_dtype_precedence(self):
        # Test dtype precedence (casting order) in binary operations by comparing to the CPU result
        # Example values for all dtypes supported on the MPS backend
        sample_vals = {
            torch.bool: [False, True],
            torch.int16: [-15, 0, 1, 10],
            torch.int32: [-376, 0, 1, 13],
            torch.int64: [-8, 0, 1, 77],
            torch.float16: [-234.5, 0.0, 1.0, 2.0],
            torch.float32: [-1.0, 0.0, 0.1, 111.99],
        }
        # Test all combinations of dtypes, operations, dimensionality
        for dtype1, dtype2, binop in itertools.product(
                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
            # bool minus bool is generally unsupported, so skip
            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
                continue
            full_shape = (10,)
            for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
                # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                #       (torch.tensor(val2, dtype=dtype2, device='mps')))
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                #       (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                    (torch.tensor(val2, dtype=dtype2, device='cpu')))
                # Test tensors created with torch.full
                x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
                y1 = torch.tensor(val2, dtype=dtype2, device='mps')
                x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu')
                y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
                x3 = torch.tensor(val1, dtype=dtype1, device='mps')
                y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
                x4 = torch.tensor(val1, dtype=dtype1, device='cpu')
                y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                    (torch.full(full_shape, val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                    (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
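    # nansum treats NaNs as zero; integer dtypes cannot hold NaN, so there it must match a plain sum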
    def test_nansum(self):
        def helper(dtype, noncontiguous, dim):
            zero_cpu = torch.zeros((), dtype=dtype)

            # Randomly scale the values
            scale = random.randint(10, 100)
            x_cpu: torch.Tensor = make_tensor(
                (5, 5), dtype=dtype, device='cpu',
                low=-scale, high=scale, noncontiguous=noncontiguous)

            if dtype.is_floating_point:
                nan_mask_cpu = x_cpu < (0.2 * scale)
                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
                x_cpu[nan_mask_cpu] = np.nan
            else:
                x_no_nan_cpu = x_cpu

            x_mps = x_cpu.to('mps')
            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
            expect_out_cpu = torch.empty(0, dtype=dtype)
            dim_kwargs = {"dim": dim} if dim is not None else {}
            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
            # Sanity check on CPU
            self.assertEqual(expect, actual_cpu)

            # Test MPS
            actual_mps = torch.nansum(x_mps, **dim_kwargs)
            # Test out= variant
            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
            self.assertEqual(expect, actual_mps)
            self.assertEqual(expect_out_cpu, actual_out_mps)

        args = itertools.product(
            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
            (True, False),  # noncontiguous
            (0, 1, None),  # dim
        )

        for dtype, noncontiguous, dim in args:
            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
                helper(dtype, noncontiguous, dim)

    def test_cumsum_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumsum(0, dtype=dtype)
            a_cpu = t_cpu.cumsum(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]
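        # int64 cumsum requires macOS 13.3+; on older versions MPS raises, and we check the exact message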
        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
                             " Support has been added in macOS 13.3")

    def test_cumsum_bool(self):
        a = torch.ones(2**16, dtype=torch.bool)
        t_cpu = a.cumsum(0)
        t_mps = a.to("mps").cumsum(0)

        self.assertEqual(t_cpu, t_mps)

    def test_cumsum_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumsum(-1)
            y = x.cumsum(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_cumprod_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumprod(0, dtype=dtype)
            a_cpu = t_cpu.cumprod(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
                             + " Support has been added in macOS 13.3")
    def test_cumprod_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumprod(-1)
            y = x.cumprod(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_median_int16(self):
        def helper(shape, dtype):
            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            median_result = torch.median(x)
            median_result_cpu = torch.median(cpu_x)
            self.assertEqual(median_result, median_result_cpu)

        helper((2, 8, 4, 5), torch.int16)

    def test_activation_checkpoint_does_not_error(self):
        from torch.utils.checkpoint import checkpoint

        for use_reentrant in (True, False):
            a = torch.tensor(1., device="mps", requires_grad=True)

            def fn(x):
                return x.sin().cos().exp()

            out = checkpoint(fn, a, use_reentrant=use_reentrant)
            out.backward()
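    # as_strided with explicit sizes, strides and storage offsets must create the same views as on the CPU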

    def test_unfold(self):
        x = torch.arange(1., 8)
        x_mps = torch.arange(1., 8, device="mps")

        y = x.unfold(0, 2, 1)
        y_mps = x_mps.unfold(0, 2, 1)

        self.assertEqual(y, y_mps)

    def test_unfold_all_devices_and_dtypes(self):
        supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
        for dt in supported_dtypes:
            x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)

    def test_unfold_scalars(self):
        x = torch.tensor(0.5, device="mps")
        # unfold on a 0-dimensional tensor should always return a 1-dimensional
        # tensor of shape [size] (i.e., the second parameter to unfold)

        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
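
    # Illustrative sketch, not part of the upstream suite: unfold(dim, size, step)
    # yields (len - size) // step + 1 overlapping windows along dim.
    def test_unfold_window_count_sketch(self):
        x = torch.arange(10., device="mps")
        windows = x.unfold(0, 3, 2)
        self.assertEqual(windows.shape, (4, 3))  # (10 - 3) // 2 + 1 == 4
        self.assertEqual(windows[1].cpu(), torch.tensor([2., 3., 4.]))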

    def test_bincount_simple(self):
        input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
        input_cpu = input.to("cpu")
        weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
        weights_cpu = weights.to("cpu")

        x = torch.bincount(input)
        x_cpu = torch.bincount(input_cpu)
        self.assertEqual(x, x_cpu)

        y = input.bincount(weights)
        y_cpu = input_cpu.bincount(weights_cpu)
        self.assertEqual(y, y_cpu)

    def test_bincount_reduction(self):
        device = "mps"
        # negative input throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
        # n-d input, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
        # minlength < 0 throws
        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
            torch.bincount(torch.tensor([1, 3], device=device),
                           torch.tensor([.2, .2], device=device),
                           minlength=-1)
        # n-d weights, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
        # input and weights dim mismatch
        with self.assertRaisesRegex(RuntimeError, 'same length'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))
        # 1-d input with no elements and default minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
                         torch.zeros(0, dtype=torch.long, device=device))
        # 1-d input with no elements and specified minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
                         torch.zeros(10, dtype=torch.long, device=device))

        # test tensor method without weights
        long_counts = torch.tensor(
            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
        self.assertEqual(
            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
            long_counts)
        # test avoiding overflow for uint8 (#76979)
        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
        self.assertEqual(count_uint8, count_int16)
        # test minlength functionality
        int_counts = torch.bincount(
            torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
        self.assertEqual(
            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
            int_counts)
        # test weights
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([.1, .2, .3, .4, .5], device=device))
        self.assertEqual(
            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
        self.assertEqual(
            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts)
        # test non-contiguous inputs and weights
        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
        for i in [0, 1]:
            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
        # inputs are non-contiguous but weights are contiguous
        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
        # inputs and weights are non-contiguous
        self.assertEqual(
            inputs[:, 1].bincount(weights[:, 1]),
            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
        # weights are non-contiguous but inputs are contiguous
        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

        # test bincount on non-contiguous slices
        all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

        all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

        # test large number of bins - global memory use
        big_exp = torch.zeros(100, device=device)
        big_exp[-1] = 50.0
        big_w = torch.tensor([.5] * 100, device=device)
        big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
        self.assertEqual(big_exp, big_out)
        # test large input size
        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
        big_exp[1] = 10
        big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
        self.assertEqual(big_exp, big_out)

    def test_bincount(self):
        device = "mps"
        input_size = (5000,)
        w = torch.randn(input_size, dtype=torch.float, device=device)
        w_cpu = w.cpu()

        t = torch.randint(50, input_size, dtype=torch.int8, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(500, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(2000, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.zeros([10], dtype=torch.int32, device=device)
        t[0] = 35488
        counted = t.bincount(minlength=65536)
        self.assertEqual(torch.sum(counted), 10)
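
    # Illustrative sketch, not part of the upstream suite: bincount returns
    # max(input) + 1 bins (or minlength, whichever is larger); with weights it
    # sums the weights that fall into each bin instead of counting.
    def test_bincount_semantics_sketch(self):
        idx = torch.tensor([0, 2, 2], dtype=torch.int32, device="mps")
        self.assertEqual(idx.bincount().cpu(), torch.tensor([1, 0, 2]))
        w = torch.tensor([0.5, 1.0, 2.0], device="mps")
        self.assertEqual(idx.bincount(w).cpu(), torch.tensor([0.5, 0.0, 3.0]))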

    def test_sum_backward(self):
        def helper(n, c):
            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
            cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            all_sum.backward()
            all_sum_cpu.backward()
            self.assertEqual(all_sum, all_sum_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper(3, 3)

    # L1 loss
    def test_l1_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.L1Loss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # use 2 so that grad_output > 1 in the backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached-graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    # Mean Squared Error
    def test_mse_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.MSELoss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # use 2 so that grad_output > 1 in the backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached-graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    def test_mse_loss_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124621
        lf = nn.MSELoss(reduction='none')
        model_cpu = nn.Sequential(
            nn.Conv1d(3, 3, 1),
        )
        model_mps = copy.deepcopy(model_cpu).to("mps")

        x = torch.randn(128, 10, 3)
        x = x.permute(0, 2, 1)

        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
        x_mps = x_mps.permute(0, 2, 1)

        y = model_cpu(x)
        y_mps = model_mps(x_mps)

        y = y.permute(0, 2, 1)[:, :5, :]
        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

        y_hat = torch.randn(128, 5, 3)
        y_hat_mps = y_hat.detach().clone().to("mps")

        loss = lf(y, y_hat)
        loss_mps = lf(y_mps, y_hat_mps)
        self.assertEqual(loss, loss_mps)
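
    # Illustrative sketch, not part of the upstream suite: backward() accepts an
    # explicit grad_output (the loss tests above pass 2 or 0.6 so that it is not
    # the default 1); the input gradient is then grad_output * dL/dx.
    def test_backward_with_explicit_grad_output_sketch(self):
        x = torch.tensor([1., -2.], device='mps', requires_grad=True)
        y = x * 3
        y.backward(gradient=torch.full_like(y, 2.))
        self.assertEqual(x.grad.cpu(), torch.tensor([6., 6.]))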

    # Binary Cross Entropy
    def test_bce_loss_simple(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.BCELoss(reduction=reduction)

            # input and target must be within [0..1]
            input_t = np.random.random_sample(size=shape).astype(np.float32)
            target_t = np.random.random_sample(size=shape).astype(np.float32)
            inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # use 0.6 so that grad_output != 1
                outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached-graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')
        helper([1, 1, 32, 32], 'mean')
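
    # Illustrative sketch, not part of the upstream suite: elementwise BCE is
    # -(t * log(p) + (1 - t) * log(1 - p)); for probabilities away from 0 and 1
    # the MPS result should match this formula directly.
    def test_bce_matches_formula_sketch(self):
        p = torch.tensor([0.25, 0.75], device='mps')
        t = torch.tensor([1.0, 0.0], device='mps')
        manual = -(t * torch.log(p) + (1 - t) * torch.log(1 - p))
        self.assertEqual(nn.BCELoss(reduction='none')(p, t), manual)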

    def test_bce_loss_always_nonnegative(self):
        target = torch.ones(5, device='mps')
        input = torch.ones(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

        target = torch.zeros(5, device='mps')
        input = torch.zeros(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

    def test_bce_loss_size_mismatch(self):
        bceloss = nn.BCELoss()
        a = torch.rand(25, device='mps')
        b = torch.rand(25, 1, device='mps')
        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
            bceloss(a, b)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
        x_size = 1024
        y_size = 256
        target = torch.rand(x_size, y_size, device='mps')

        for reduction in ['none', 'mean', 'sum']:
            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
            output_logits = output_sig.clone().detach()

            output_sig.requires_grad = True
            output_logits.requires_grad = True
            weight = torch.rand(y_size, device='mps')

            loss_sig = nn.BCELoss(weight, reduction=reduction)(
                torch.sigmoid(output_sig), target
            )
            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
                output_logits, target
            )

            self.assertEqual(loss_logits, loss_sig)

            if reduction == 'none':
                grad = torch.rand(x_size, y_size, device='mps')
                loss_sig.backward(grad)
                loss_logits.backward(grad)
            else:
                loss_sig.backward()
                loss_logits.backward()

            self.assertEqual(output_sig.grad, output_logits.grad)

    def test_bce_with_logits_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        self.assertEqual(output.grad, expected_grad)

    def test_bce_with_logits_broadcasts_weights(self):
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.ones(64, 4, device='mps')

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
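
    # Illustrative sketch, not part of the upstream suite: for BCEWithLogitsLoss
    # with 'sum' reduction, dL/dx = sigmoid(x) - target, which is where the 0.5
    # expected_grad in the zero-logit tests comes from.
    def test_bce_with_logits_grad_is_sigmoid_minus_target_sketch(self):
        x = torch.zeros(3, requires_grad=True, device='mps')
        z = torch.full((3,), 0.25, device='mps')
        nn.BCEWithLogitsLoss(reduction='sum')(x, z).backward()
        self.assertEqual(x.grad.cpu(), torch.full((3,), 0.25))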

    def test_bce_with_logits_broadcasts_pos_weights(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

        pos_weight1 = pos_weight.expand(1, 4)
        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

        pos_weight2 = pos_weight.expand(64, 4)
        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

        self.assertEqual(out1, out2)
        self.assertEqual(out1, out3)

    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        pos_weight = torch.ones(3, 1, device='mps')
        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        grad = output.grad
        self.assertEqual(grad, expected_grad)

    def test_bce_with_logits_stability(self):
        output = torch.tensor([0., -120.], device='mps')
        target = torch.tensor([0., 1.], device='mps')
        pos_weight = torch.tensor([1., 1.], device='mps')

        out1 = nn.BCEWithLogitsLoss()(output, target)
        self.assertTrue(torch.isfinite(out1).all().item())

        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
        self.assertTrue(torch.isfinite(out2).all().item())
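
    # Illustrative sketch, not part of the upstream suite: BCEWithLogitsLoss is
    # stable because it folds the sigmoid into the loss via the log-sum-exp
    # identity loss(x, z) = max(x, 0) - x*z + log1p(exp(-|x|)), so exp() never
    # sees a large positive argument.
    def test_bce_with_logits_matches_stable_formula_sketch(self):
        x = torch.tensor([0., -120.], device='mps')
        z = torch.tensor([0., 1.], device='mps')
        manual = x.clamp(min=0) - x * z + torch.log1p(torch.exp(-x.abs()))
        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(x, z), manual)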

    def test_bce_loss_broadcasts_weights(self):
        sigmoid = nn.Sigmoid()
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

    def test_cross_entropy_loss(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/116095
        loss = nn.CrossEntropyLoss()
        pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
        target = torch.ones(3, dtype=torch.long, device='mps')
        output = loss(pred, target)
        output.backward()

    def test_log_softmax(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
        mps_log_softmax = F.log_softmax(mps_x, dim=0)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))

    def test_log_softmax_large_numbers(self):
        values = [
            [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
            [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
        ]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
        mps_log_softmax = F.log_softmax(mps_x, dim=-1)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
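
    # Illustrative sketch, not part of the upstream suite: log_softmax can be
    # evaluated stably as (x - m) - log(sum(exp(x - m))) with m = x.max(), which
    # is why the huge magnitudes above stay finite; the MPS kernel should agree
    # with that form.
    def test_log_softmax_matches_manual_stable_form_sketch(self):
        x = torch.tensor([10.0, 1000.0, 100000.0], device='mps')
        m = x.max()
        manual = (x - m) - torch.log(torch.exp(x - m).sum())
        self.assertEqual(F.log_softmax(x, dim=0), manual)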

    def test_eq(self):
        values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
        mps_x = torch.tensor(values1, device='mps')
        mps_y = torch.tensor(values2, device='mps')
        cpu_x = torch.tensor(values1, device='cpu')
        cpu_y = torch.tensor(values2, device='cpu')
        result_mps = torch.eq(mps_x, mps_y)
        result_cpu = torch.eq(cpu_x, cpu_y)

        self.assertEqual(result_cpu, result_mps.to('cpu'))

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_signed_vs_unsigned_comparison(self):
        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
        # when comparing signed with unsigned we should always cast to unsigned
        self.assertEqual(cpu_x == -1, mps_x == -1)
        self.assertEqual(cpu_x > -1, mps_x > -1)
        self.assertEqual(cpu_x < -1, mps_x < -1)
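
    # Illustrative sketch, not part of the upstream suite, assuming the scalar
    # is cast to the tensor's dtype as the comment above states: -1 wraps to 255
    # in uint8, so `x == -1` is True exactly where the tensor holds 255.
    def test_uint8_scalar_comparison_wraps_sketch(self):
        x = torch.tensor((255, 2, 3), dtype=torch.uint8, device='mps')
        self.assertEqual((x == -1).cpu(), torch.tensor([True, False, False]))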

    def test_eq_int64(self):
        values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
        values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
        mps_x = torch.tensor(values1, device='mps')
        mps_y = torch.tensor(values2, device='mps')
        cpu_x = torch.tensor(values1, device='cpu')
        cpu_y = torch.tensor(values2, device='cpu')
        result_mps = torch.eq(mps_x, mps_y)
        result_cpu = torch.eq(cpu_x, cpu_y)

        self.assertEqual(result_cpu, result_mps.to('cpu'))

    def test_ne(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.ne(mps_x, mps_y)
            result_cpu = torch.ne(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ne_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.ne(mps_x, 0.0)
            result_cpu = torch.ne(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_lt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.lt(mps_x, mps_y)
            result_cpu = torch.lt(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_lt_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.lt(mps_x, 0.0)
            result_cpu = torch.lt(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_le(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.le(mps_x, mps_y)
            result_cpu = torch.le(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_le_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.le(mps_x, 0.0)
            result_cpu = torch.le(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ge(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.ge(mps_x, mps_y)
            result_cpu = torch.ge(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ge_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.ge(mps_x, 0.0)
            result_cpu = torch.ge(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_gt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.gt(mps_x, mps_y)
            result_cpu = torch.gt(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_gt_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.gt(mps_x, 0.0)
            result_cpu = torch.gt(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_argmax(self):
        # https://github.com/pytorch/pytorch/issues/98191
        cpu_tensor = torch.tensor([[0, 1], [2, 1], [1, 0]])
        res_cpu = torch.argmax(cpu_tensor, dim=1)

        mps_tensor = cpu_tensor.to(torch.device('mps'))
        res_mps = torch.argmax(mps_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)

        # https://github.com/pytorch/pytorch/issues/92311
        mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
        cpu_tensor = mps_tensor.detach().clone().cpu()

        res_mps = torch.argmax(mps_tensor, dim=1)
        res_cpu = torch.argmax(cpu_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)
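
    # Illustrative sketch, not part of the upstream suite: with no dim argument,
    # argmax returns an index into the flattened tensor.
    def test_argmax_flattened_index_sketch(self):
        x = torch.tensor([[1., 9.], [3., 2.]], device='mps')
        self.assertEqual(x.argmax().item(), 1)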

    # Test forward argmin argmax
    def test_argmin_argmax(self):
        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
            if reduction_type == "max":
                arg_reduction_fn = torch.argmax
            else:
                arg_reduction_fn = torch.argmin

            if dtype not in [torch.float32, torch.bool]:
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif dtype == torch.bool:
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = arg_reduction_fn(x)
            ref_y = arg_reduction_fn(cpu_x)
            self.assertEqual(y, ref_y)

            y_0 = arg_reduction_fn(x, dim=0)
            refy_0 = arg_reduction_fn(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)

            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)

            y_1 = arg_reduction_fn(x, dim=1)
            refy_1 = arg_reduction_fn(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)

            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)

            y_2 = arg_reduction_fn(x, dim=2)
            refy_2 = arg_reduction_fn(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)

            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)

            y_3 = arg_reduction_fn(x, dim=3)
            refy_3 = arg_reduction_fn(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)

            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)

        helper(2, 8, 4, 4, "max", torch.float32)
        helper(2, 8, 4, 4, "max", torch.int32)
        helper(2, 8, 4, 4, "max", torch.float16)
        helper(2, 8, 4, 4, "max", torch.int64)
        helper(2, 8, 4, 4, "min", torch.float32)
        helper(2, 8, 4, 4, "min", torch.int32)
        helper(2, 8, 4, 4, "min", torch.float16)
        helper(2, 8, 4, 4, "min", torch.int64)
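
    # Illustrative sketch, not part of the upstream suite: keepdim=True retains
    # the reduced dimension with size 1, so the result still broadcasts against
    # the original tensor.
    def test_argmax_keepdim_shape_sketch(self):
        x = torch.randn(2, 8, 4, 4, device='mps')
        self.assertEqual(x.argmax(dim=1, keepdim=True).shape, (2, 1, 4, 4))
        self.assertEqual(x.argmax(dim=1).shape, (2, 4, 4))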

    @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
    def test_reduction_sum_max_long_val(self):
        x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
        x_cpu = x_mps.detach().clone().cpu()

        res_mps = torch.sum(x_mps)
        res_cpu = torch.sum(x_cpu)
        self.assertEqual(res_mps, res_cpu)

    # Test forward max
    # Note - don't test grad now
    def test_max_el(self):
        def helper(n, c, h, w, dtype=torch.float32):

    # Test forward max
    # Note - don't test grad now
    def test_max_el(self):
        def helper(n, c, h, w, dtype=torch.float32):
            if dtype not in [torch.float32, torch.bool]:
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif dtype == torch.bool:
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps')

            ref_y = torch.max(cpu_x)
            y = torch.max(x)
            self.assertEqual(y, ref_y)

            for dim in [0, 1, 2, 3]:
                for keepdim in [True, False]:
                    y, idx = torch.max(x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

            y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=0, out=(y_0, idx_0))
            refy_0, refidx_0 = torch.max(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
            refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=1, out=(y_1, idx_1))
            refy_1, refidx_1 = torch.max(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
            refy_1dim, refidx_1dim = torch.max(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=2, out=(y_2, idx_2))
            refy_2, refidx_2 = torch.max(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
            refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
            torch.max(x, dim=3, out=(y_3, idx_3))
            refy_3, refidx_3 = torch.max(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
            torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
            refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

        helper(2, 8, 4, 5, torch.float32)
        helper(2, 8, 4, 5, torch.int32)
        # helper(2, 8, 4, 5, torch.int64)
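
    # For reference: the dim overload of torch.max returns a (values, indices)
    # namedtuple, which is why the checks above unpack two results and pass a
    # two-tensor tuple to out=. A small sketch (not part of the suite):
    #
    #   t = torch.tensor([[1., 3.], [2., 0.]])
    #   vals, idxs = torch.max(t, dim=1)  # vals = [3., 2.], idxs = [1, 0]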

    def test_median(self):
        def helper_dtype_int32(n1, n2, n3):
            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
            mps_x = cpu_x.detach().clone().to('mps')

            result_cpu = torch.median(cpu_x)
            result_mps = torch.median(mps_x)

            self.assertEqual(result_cpu, result_mps)

            for dim in [0, 1, 2]:
                for keepdim in [True, False]:
                    y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

        def helper_dtype_float32(n1, n2, n3):
            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
            mps_x = cpu_x.detach().clone().to('mps')

            result_cpu = torch.median(cpu_x)
            result_mps = torch.median(mps_x)

            self.assertEqual(result_cpu, result_mps)

            for dim in [0, 1, 2]:
                for keepdim in [True, False]:
                    y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

        helper_dtype_int32(10, 10, 10)  # even number of elements
        helper_dtype_int32(3, 3, 3)  # odd number of elements
        helper_dtype_int32(1, 1, 1)
        helper_dtype_int32(1, 2, 3)
        helper_dtype_float32(10, 10, 10)
        helper_dtype_float32(3, 3, 3)
        helper_dtype_float32(1, 1, 1)
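
    # Note on semantics (from the torch.median documentation): for an even
    # number of elements the lower of the two middle values is returned, not
    # their mean, so an exact CPU/MPS comparison is well defined. For example:
    #
    #   torch.median(torch.tensor([1., 2., 3., 4.]))  # tensor(2.)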

    def test_any(self):
        def helper(shape):
            input_xs = []
            prod = 1

            for i in range(len(shape)):
                prod *= shape[i]
            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())

            for i, cpu_x in enumerate(input_xs):
                x = cpu_x.detach().clone().to('mps')
                y = torch.any(x)
                ref_y = torch.any(cpu_x)
                self.assertEqual(y, ref_y)

                y_0 = torch.any(x, dim=0)
                refy_0 = torch.any(cpu_x, dim=0)
                self.assertEqual(y_0, refy_0)

                y_0dim = torch.any(x, dim=0, keepdim=True)
                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_1 = torch.any(x, dim=1)
                refy_1 = torch.any(cpu_x, dim=1)
                self.assertEqual(y_1, refy_1)

                y_1dim = torch.any(x, dim=1, keepdim=True)
                refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
                self.assertEqual(y_1dim, refy_1dim)

                if len(shape) > 2:
                    y_2 = torch.any(x, dim=2)
                    refy_2 = torch.any(cpu_x, dim=2)
                    self.assertEqual(y_2, refy_2)

                    y_2dim = torch.any(x, dim=2, keepdim=True)
                    refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
                    self.assertEqual(y_2dim, refy_2dim)

                    y_3 = torch.any(x, dim=3)
                    refy_3 = torch.any(cpu_x, dim=3)
                    self.assertEqual(y_3, refy_3)

                    y_3dim = torch.any(x, dim=3, keepdim=True)
                    refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
                    self.assertEqual(y_3dim, refy_3dim)

        helper((1, 1, 1, 1))
        helper((1, 1, 3, 3))
        helper((7, 13))
        helper((2, 8, 4, 5))

    def test_reduction_ops_5D(self):
        def helper(fn, dim):
            shape = (1, 1, 2, 1, 1)
            x_cpu = fn(torch.zeros(shape), dim=dim)
            x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
            self.assertEqual(x_cpu, x_mps.cpu())

        for fn in [torch.any, torch.all]:
            for dim in range(0, 4):
                helper(fn, dim)

        # 6D tensor reductions
        # Regression test for https://github.com/pytorch/pytorch/issues/95538
        x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
        self.assertEqual(x.all(), x.cpu().all())
        for i in range(-5, 6):
            self.assertEqual(x.all(dim=i), x.cpu().all(dim=i))
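
    # Negative dims in the loop above are counted from the end, as everywhere
    # in torch: for the 6-D tensor here, dim=-1 is equivalent to dim=5. A tiny
    # illustration (not part of the suite):
    #
    #   x = torch.zeros(2, 3)
    #   torch.equal(x.all(dim=-1), x.all(dim=1))  # True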

    def test_all(self):
        def helper(shape):
            input_xs = []
            prod = 1

            for i in range(len(shape)):
                prod *= shape[i]
            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())

            for i, cpu_x in enumerate(input_xs):
                x = cpu_x.detach().clone().to('mps')
                y = torch.all(x)
                ref_y = torch.all(cpu_x)
                self.assertEqual(y, ref_y)

                y_0 = torch.all(x, dim=0)
                refy_0 = torch.all(cpu_x, dim=0)
                self.assertEqual(y_0, refy_0)

                y_0dim = torch.all(x, dim=0, keepdim=True)
                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_1 = torch.all(x, dim=1)
                refy_1 = torch.all(cpu_x, dim=1)
                self.assertEqual(y_1, refy_1)

                y_1dim = torch.all(x, dim=1, keepdim=True)
                refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
                self.assertEqual(y_1dim, refy_1dim)

                if len(shape) > 2:
                    y_2 = torch.all(x, dim=2)
                    refy_2 = torch.all(cpu_x, dim=2)
                    self.assertEqual(y_2, refy_2)

                    y_2dim = torch.all(x, dim=2, keepdim=True)
                    refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
                    self.assertEqual(y_2dim, refy_2dim)

                    y_3 = torch.all(x, dim=3)
                    refy_3 = torch.all(cpu_x, dim=3)
                    self.assertEqual(y_3, refy_3)

                    y_3dim = torch.all(x, dim=3, keepdim=True)
                    refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
                    self.assertEqual(y_3dim, refy_3dim)

        helper((1, 1, 1, 1))
        helper((1, 1, 3, 3))
        helper((7, 13))
        helper((2, 8, 4, 5))
        # Empty tensor
        x_cpu = torch.tensor([], dtype=torch.bool)
        x_mps = x_cpu.to("mps")
        self.assertEqual(x_cpu.all(), x_mps.all().cpu())
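
    # Reductions over empty tensors return the identity element of the
    # underlying operation, which is what the empty-tensor check above relies
    # on:
    #
    #   torch.tensor([], dtype=torch.bool).all()  # tensor(True)  - identity of AND
    #   torch.tensor([], dtype=torch.bool).any()  # tensor(False) - identity of OR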

    # Test forward min
    def test_min_el(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            y = torch.min(x)
            ref_y = torch.min(cpu_x)
            self.assertEqual(y, ref_y)

            y_0, idx_0 = torch.min(x, dim=0)
            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=0, out=(y_0, idx_0))
            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_1, idx_1 = torch.min(x, dim=1)
            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=1, out=(y_1, idx_1))
            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_2, idx_2 = torch.min(x, dim=2)
            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=2, out=(y_2, idx_2))
            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_3, idx_3 = torch.min(x, dim=3)
            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
            torch.min(x, dim=3, out=(y_3, idx_3))
            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
            torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

        helper(2, 8, 4, 5)

    # Test forward sum
    def test_sum(self):
        def helper(n, c, h, w, dtype=torch.float32):
            cpu_x = None
            x = None
            if dtype not in [torch.float32, torch.bool]:
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif dtype == torch.bool:
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            self.assertEqual(all_sum, all_sum_cpu)

            nil_dim_sum = torch.sum(x, dim=[])
            nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])

            self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)

            nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
            nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)

            self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)

            zero_dim_sum = torch.sum(x, dim=[0])
            zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])

            self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)

            zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
            zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)

            self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)

            zero_one_dim_sum = torch.sum(x, dim=[0, 1])
            zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])

            self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)

            zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
            zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)

            self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)

            two_three_dim_sum = torch.sum(x, dim=[2, 3])
            two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])

            self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)

            two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
            two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)

            self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)

        helper(2, 8, 4, 5)
        helper(2, 8, 4, 5, dtype=torch.int32)
        helper(2, 8, 4, 5, dtype=torch.int64)
        helper(2, 8, 4, 5, dtype=torch.bool)
        # Regression test for https://github.com/pytorch/pytorch/issues/136132
        x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
        self.assertEqual(x.numel(), 8)
        self.assertEqual(x.max().item(), 30.0)

    # Test forward prod
    def test_prod(self):
        def helper(shape, dtype=torch.float32):
            cpu_x = None
            x = None
            if dtype not in [torch.float32, torch.bool]:
                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif dtype == torch.bool:
                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_prod = torch.prod(x)
            all_prod_cpu = torch.prod(cpu_x)

            self.assertEqual(all_prod, all_prod_cpu)

            for dim in range(len(shape)):
                dim_prod = torch.prod(x, dim=dim)
                dim_prod_cpu = torch.prod(cpu_x, dim=dim)

                self.assertEqual(dim_prod, dim_prod_cpu)

                dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
                dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)

                self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)

        for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
            helper((2, 3), dtype)
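
    # Note on the prod inputs above: the integer path draws from randint(1, 6)
    # rather than randint(50) - presumably so that products stay small, cannot
    # hit zero, and remain exactly representable in every dtype being
    # compared. For instance:
    #
    #   torch.prod(torch.tensor([2, 3, 4]))  # tensor(24)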

    # Test forward mean
    def test_mean(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_mean = torch.mean(x)
            all_mean_cpu = torch.mean(cpu_x)

            self.assertEqual(all_mean, all_mean_cpu)

            nil_dim_mean = torch.mean(x, dim=[])
            nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])

            self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)

            nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
            nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)

            self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)

            zero_dim_mean = torch.mean(x, dim=[0])
            zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])

            self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)

            zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
            zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)

            self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)

            zero_one_dim_mean = torch.mean(x, dim=[0, 1])
            zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])

            self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)

            zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
            zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)

            self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)

            two_three_dim_mean = torch.mean(x, dim=[2, 3])
            two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])

            self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)

            two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
            two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)

            self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)

        helper(2, 8, 4, 5)
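
    # keepdim=True keeps the reduced axes with size 1, which makes the result
    # broadcastable against the input - handy for normalization. A small
    # sketch (not part of the suite):
    #
    #   t = torch.randn(2, 8, 4, 5)
    #   centered = t - t.mean(dim=[2, 3], keepdim=True)  # broadcasts back over h, w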

    # Test std
    def test_std(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            all_std = torch.std(x, unbiased=False)
            all_std_cpu = torch.std(cpu_x, unbiased=False)

            self.assertEqual(all_std, all_std_cpu)

            nil_dim_std = torch.std(x, dim=[], unbiased=False)
            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)

            self.assertEqual(nil_dim_std, nil_dim_std_cpu)

            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)

            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)

            zero_dim_std = torch.std(x, dim=[0], unbiased=False)
            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)

            self.assertEqual(zero_dim_std, zero_dim_std_cpu)

            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)

            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)

            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)

            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)

            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)

            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)

            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)

            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)

            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)

            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)

            all_std = torch.std(x, unbiased=True)
            all_std_cpu = torch.std(cpu_x, unbiased=True)

            self.assertEqual(all_std, all_std_cpu)

            nil_dim_std = torch.std(x, dim=[], unbiased=True)
            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)

            self.assertEqual(nil_dim_std, nil_dim_std_cpu)

            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)

            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)

            zero_dim_std = torch.std(x, dim=[0], unbiased=True)
            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)

            self.assertEqual(zero_dim_std, zero_dim_std_cpu)

            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)

            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)

            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)

            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)

            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)

            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)

            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)

            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)

            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)

            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)

        helper((4, 5, 6, 7))
        # verify that a change in the input shape does not break graph caching
        helper((9, 5, 6, 7))
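
    # unbiased=True applies Bessel's correction: squared deviations are
    # divided by N - 1 instead of N. For the all-elements case this means,
    # roughly (up to floating-point rounding; sketch, not part of the suite):
    #
    #   x.std(unbiased=True) == ((x - x.mean()) ** 2).sum().div(x.numel() - 1).sqrt()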

    # Test var
    def test_var_simple(self):
        def helper():
            shape = [2, 3, 4, 5]

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            for unbiased in [False, True]:
                for keepdim in [False, True]:
                    zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
                    zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)

                    all_var = torch.var(x, unbiased=unbiased)
                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)

                    self.assertEqual(all_var, all_var_cpu)

                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)

                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)

                    zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
                    zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)

                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)

        helper()

    # Test forward amax
    def test_amax(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amax(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)
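
    # Gradient semantics worth remembering here (per the torch.amax docs):
    # amax/amin distribute the incoming gradient evenly between all tied
    # extremal values, whereas max(dim)/min(dim) route it to a single index.
    # Sketch (not part of the suite):
    #
    #   t = torch.tensor([1., 1.], requires_grad=True)
    #   torch.amax(t, dim=0).backward()
    #   t.grad  # tensor([0.5000, 0.5000])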

    # Test forward amin
    def test_amin(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amin(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test minimum and maximum
    def test_minimum_maximum(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
            minimum_result_mps = torch.minimum(mps_x, mps_y)
            self.assertEqual(minimum_result_cpu, minimum_result_mps)

            maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
            maximum_result_mps = torch.maximum(mps_x, mps_y)
            self.assertEqual(maximum_result_cpu, maximum_result_mps)

        helper(1, 1, 4, 5)
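
    # torch.minimum/torch.maximum are the elementwise, NaN-propagating pair:
    # if either input is NaN at a position, the result is NaN there (use
    # torch.fmin/torch.fmax to ignore NaNs instead). Sketch, not part of the
    # suite:
    #
    #   a = torch.tensor([1., float('nan')])
    #   b = torch.tensor([2., 0.])
    #   torch.maximum(a, b)  # tensor([2., nan])
    #   torch.fmax(a, b)     # tensor([2., 0.])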

    def test_clamp_fp16_fp32(self):
        cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        dtype = torch.float16

        clamp_min_vals_mps = torch.ones(10, device="mps").to(dtype)
        clamp_max_vals_mps = torch.ones(10, device="mps").to(dtype) * 10
        clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)

        clamp_min_vals_cpu = torch.ones(10, device="cpu").to(dtype)
        clamp_max_vals_cpu = torch.ones(10, device="cpu").to(dtype) * 10
        clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)

        self.assertEqual(clamp_result_mps, clamp_result_cpu)

    def test_clamp_nan(self):
        t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
        t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")

        clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
        clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)

        self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)

        clamp_min_mps = torch.clamp(t_mps, min=-100)
        clamp_min_cpu = torch.clamp(t_cpu, min=-100)

        self.assertEqual(clamp_min_mps, clamp_min_cpu)

        clamp_max_mps = torch.clamp(t_mps, max=100)
        clamp_max_cpu = torch.clamp(t_cpu, max=100)

        self.assertEqual(clamp_max_mps, clamp_max_cpu)
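
    # Added illustrative sketch (test name and values are new): torch.clamp
    # propagates NaN rather than clamping it, so the NaN lane survives while
    # the finite lane is clipped to the max bound.
    def test_clamp_nan_propagation_sketch(self):
        t = torch.tensor([float('nan'), 200.0], device='mps')
        clamped = torch.clamp(t, min=0.0, max=100.0)
        self.assertTrue(torch.isnan(clamped[0]))
        self.assertEqual(clamped[1].item(), 100.0)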

    # Test clamp_min
    def test_clamp_min(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            min_t = cpu_min_t.detach().clone().to('mps')

            clamp_min_result = torch.clamp_min(x, min=5.0)
            clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)

            self.assertEqual(clamp_min_result, clamp_min_result_cpu)

            clamp_min_t_result = torch.clamp_min(x, min=min_t)
            clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)

            self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)

        helper(2, 8, 4, 5)

    # Test clamp_max
    def test_clamp_max(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            max_t = cpu_max_t.detach().clone().to('mps')

            clamp_max_result = torch.clamp_max(x, max=100.0)
            clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)

            self.assertEqual(clamp_max_result, clamp_max_result_cpu)

            clamp_max_t_result = torch.clamp_max(x, max=max_t)
            clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)

            self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)

        helper(2, 8, 4, 5)

    # Test clamp
    def test_clamp(self):
        def helper(n, c, h, w):
            upper_bound = 1000
            half_upper_bound = upper_bound / 2

            # x=[0..1000)
            x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
            cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            # min_t=[0..500)
            min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
            cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
            min_t = cpu_min_t.detach().clone().to('mps')

            # max_t=[500..1000), to ensure maxes are greater than mins
            max_arr = (half_upper_bound *
                       np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
            cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
            max_t = cpu_max_t.detach().clone().to('mps')

            # [200..600]: just an arbitrary range between [0..1000]
            clamp_result = torch.clamp(x, min=200.0, max=600.0)
            clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test optional scalar refs and cached graph keys by passing only max
            clamp_opt_result = torch.clamp(x, max=600.0)
            clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
            self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)

            clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
            clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
            self.assertEqual(clamp_t_result, clamp_t_result_cpu)

            # test optional tensor refs and cached graph keys by passing only max
            clamp_topt_result = torch.clamp(x, max=max_t)
            clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
            self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)

            # test strided x
            clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test strided x, min_t, max_t
            clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test strided min_t, max_t
            clamp_result = torch.clamp(
                x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
                min=min_t.movedim(0, -1),
                max=max_t.movedim(0, -1)
            )
            clamp_result_cpu = torch.clamp(
                cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
                min=cpu_min_t.movedim(0, -1),
                max=cpu_max_t.movedim(0, -1)
            )
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test inplace clamping
            x.clamp_(min=200.0, max=600.0)
            cpu_x.clamp_(min=200.0, max=600.0)
            self.assertEqual(cpu_x, x)

        helper(2, 8, 4, 5)

    def test_divmode(self):
        def helper(shape, rounding_mode):
            for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
                if ((rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) or
                        (rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16)) is False:
                    cpu_x = None
                    cpu_y = None
                    if (dtype in [torch.float32, torch.float16]):
                        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                        cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                    else:
                        # draw from [-10, 0) so integer divisors are never zero
                        cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                        cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

                    mps_x = cpu_x.detach().clone().to('mps')
                    mps_y = cpu_y.detach().clone().to('mps')

                    if (rounding_mode == "floor_divide"):
                        result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                        result_div_mps = torch.floor_divide(mps_x, mps_y)
                        self.assertEqual(result_div_mps, result_div_cpu)
                    else:
                        result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                        result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
                        self.assertEqual(result_div_mps, result_div_cpu)

        helper((2, 8, 4, 5), None)
        helper((2, 8, 4, 5), "floor")
        helper((2, 8, 4, 5), "trunc")
        helper((2, 8, 4, 5), "floor_divide")
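
    # Added illustrative sketch (test name and values are new): the rounding
    # modes exercised above only disagree for negative quotients, e.g.
    # -7 / 2 = -3.5 floors to -4 but truncates to -3; floor_divide matches
    # rounding_mode="floor".
    def test_div_rounding_mode_sketch(self):
        a = torch.tensor([-7.0], device='mps')
        b = torch.tensor([2.0], device='mps')
        self.assertEqual(torch.div(a, b).item(), -3.5)                         # true division
        self.assertEqual(torch.div(a, b, rounding_mode='floor').item(), -4.0)  # floor(-3.5)
        self.assertEqual(torch.div(a, b, rounding_mode='trunc').item(), -3.0)  # trunc(-3.5)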

    def test_rounding(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')

            result_floor_cpu = torch.floor(cpu_x)
            result_floor_mps = torch.floor(mps_x)
            self.assertEqual(result_floor_mps, result_floor_cpu)

            result_ceil_cpu = torch.ceil(cpu_x)
            result_ceil_mps = torch.ceil(mps_x)
            self.assertEqual(result_ceil_mps, result_ceil_cpu)

            result_trunc_cpu = torch.trunc(cpu_x)
            result_trunc_mps = torch.trunc(mps_x)
            self.assertEqual(result_trunc_mps, result_trunc_cpu)

            result_round_cpu = torch.round(cpu_x)
            result_round_mps = torch.round(mps_x)
            self.assertEqual(result_round_mps, result_round_cpu)

        helper((2, 6, 3, 5))
        helper((2, 8, 4, 5))

    def test_remainder(self):
        res_cpu = torch.remainder(
            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="cpu"), torch.tensor(2, device="cpu", dtype=torch.int32))
        res_mps = torch.remainder(
            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dtype=torch.int32))
        self.assertEqual(res_cpu, res_mps)

        res_cpu = torch.remainder(
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="cpu"), -1.5)
        res_mps = torch.remainder(
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
        self.assertEqual(res_cpu, res_mps)
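
    # Added illustrative sketch (test name and values are new; assumes fmod is
    # implemented on MPS): torch.remainder follows the sign of the divisor,
    # like Python's %, while torch.fmod follows the sign of the dividend — the
    # negative-divisor case above depends on the former.
    def test_remainder_sign_sketch(self):
        a = torch.tensor([-3.0], device='mps')
        self.assertEqual(torch.remainder(a, 2).item(), 1.0)  # -3 % 2 == 1
        self.assertEqual(torch.fmod(a, 2).item(), -1.0)      # sign of the dividend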

    def test_expand(self):
        def helper(n, c):
            values = [[1.0], [4.0], [7.0]]
            cpu_x = torch.tensor(values, device='cpu')
            x = cpu_x.detach().clone().to('mps')

            strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0))
            strided_mps = torch.as_strided(x, (3, 4), (1, 0))

            self.assertEqual(strided_mps, strided_cpu)

        helper(3, 1)
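
    # Added illustrative sketch (test name is new): the (3, 4) size with
    # (1, 0) strides used above re-reads one element per row across all four
    # columns, i.e. it is expand() spelled out through as_strided, with no copy.
    def test_as_strided_expand_sketch(self):
        x = torch.tensor([[1.0], [4.0], [7.0]], device='mps')
        self.assertEqual(torch.as_strided(x, (3, 4), (1, 0)), x.expand(3, 4))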

    def test_im2col(self):
        def helper(x):
            return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)
        x_cpu = torch.rand(1, 1, 200, 100)
        x = x_cpu.detach().clone().to('mps')
        self.assertEqual(helper(x_cpu), helper(x))
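
    # Added illustrative sketch (test name is new): every unfold output column
    # is one flattened kernel patch, and the column count per spatial dim is
    # the usual sliding-window formula floor((L + 2*pad - dil*(k - 1) - 1) / stride) + 1.
    def test_im2col_shape_sketch(self):
        x = torch.rand(1, 1, 200, 100, device='mps')
        out = torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)
        h_blocks = (200 + 2 * 5 - 2 * (10 - 1) - 1) // 3 + 1  # 64
        w_blocks = (100 + 2 * 5 - 2 * (15 - 1) - 1) // 3 + 1  # 28
        self.assertEqual(out.shape, torch.Size([1, 10 * 15, h_blocks * w_blocks]))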

    def test_select(self):
        def helper(n, c):
            cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1))
            strided_mps = torch.as_strided(x, (3, 1), (3, 1))
            self.assertEqual(strided_mps, strided_cpu)

            strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1))
            strided_mps = torch.as_strided(x, (1, 3), (3, 1))
            self.assertEqual(strided_mps, strided_cpu)

            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1)
            strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1)

            self.assertEqual(strided_mps, strided_cpu)

        helper(3, 3)

    def test_sort(self):
        for SIZE in (4, 2049):
            device = 'mps'
            x = torch.rand(4, SIZE, device=device)
            res1val, res1ind = torch.sort(x)

            res2val = torch.tensor((), device=device)
            res2ind = torch.tensor((), device=device, dtype=torch.long)
            torch.sort(x, out=(res2val, res2ind))
            self.assertEqual(res1val, res2val, atol=0, rtol=0)
            self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
            self.assertEqual(torch.argsort(x), res1ind)
            self.assertEqual(x.argsort(), res1ind)

            self.assertEqual(
                torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
                torch.tensor((10, 20, 30, 40, 50), device=device),
                atol=0, rtol=0
            )

    def test_upsample_nearest2d(self):
        def helper(N, C, H, W, memory_format):
            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                    requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
            inputCPU.retain_grad()
            inputMPS = inputCPU.detach().to('mps').requires_grad_()

            values = [1, 2, 5, 10, 40]

            for i in values:
                for j in values:
                    upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j))

                    outputCPU = upsample_nearest2d(inputCPU)
                    outputMPS = upsample_nearest2d(inputMPS)

                    self.assertEqual(outputCPU, outputMPS)
                    upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W))

                    outputCPU = upsample_nearest2d(inputCPU)
                    outputMPS = upsample_nearest2d(inputMPS)

                    self.assertEqual(outputCPU, outputMPS)

                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))

                    self.assertEqual(inputCPU.grad, inputMPS.grad)

        for memory_format in [torch.channels_last, torch.contiguous_format]:
            helper(1, 1, 4, 4, memory_format=memory_format)
            helper(7, 5, 3, 2, memory_format=memory_format)

    def test_upsample_bilinear2d(self):
        def helper(N, C, H, W):
            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                    requires_grad=True).reshape(N, C, H, W)
            inputCPU.retain_grad()
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

            values = [1, 2, 5, 10, 40]

            for i in values:
                for j in values:
                    upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j))

                    outputCPU = upsample_bilinear2d(inputCPU)
                    outputMPS = upsample_bilinear2d(inputMPS)

                    self.assertEqual(outputCPU, outputMPS)

                    upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W))

                    outputCPU = upsample_bilinear2d(inputCPU)
                    outputMPS = upsample_bilinear2d(inputMPS)

                    self.assertEqual(outputCPU, outputMPS)

                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))

                    self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper(1, 1, 4, 4)
        helper(7, 5, 3, 2)

    def test_interpolate(self):
        def helper(shape, output_size, scales, mode, align_corners=False):
            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            inputCPU.retain_grad()
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

            # align_corners is used for 2D interpolation only
            if (align_corners is True and len(shape) > 3 and mode == 'bilinear'):
                if scales is not None:
                    outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners)
                    outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners)
                else:
                    outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners)
                    outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners)
            elif scales is not None:
                outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode)
                outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode)
            else:
                outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode)
                outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode)

            self.assertEqual(outputCPU, outputMPS)

            # backward pass (0.6 chosen so that grad_output != 1)
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

        # 1D interpolation
        for mode in ['nearest', 'nearest-exact']:
            helper([2, 3, 4], [3], None, mode)  # downsample with size
            helper([2, 3, 4], [6], None, mode)  # upsample with size
            helper([2, 3, 4], None, [0.6], mode)  # downsample with scale factor
            helper([2, 3, 4], None, [1.7], mode)  # upsample with scale factor
        # 2D interpolation
        for mode in ['nearest', 'nearest-exact', 'bilinear']:
            helper([2, 3, 4, 5], [3, 4], None, mode)  # downsample with size
            helper([2, 3, 4, 5], [6, 7], None, mode)  # upsample with size
            helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample with scale factor
            helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample with scale factor
        # align_corners=True
        helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
        helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
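
    # Added illustrative sketch (test name and values are new): align_corners
    # only changes the source-coordinate mapping of the linear modes; with
    # align_corners=True the corner pixels are kept exactly and the interior
    # is a linear ramp between them.
    def test_interpolate_align_corners_sketch(self):
        x = torch.tensor([[[[0.0, 1.0]]]], device='mps')
        out = F.interpolate(x, size=(1, 4), mode='bilinear', align_corners=True)
        self.assertEqual(out, torch.tensor([[[[0.0, 1 / 3, 2 / 3, 1.0]]]], device='mps'))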

    # Test concat forward
    def test_cat1(self):
        def helper(shape_x, shape_y, shape_z):
            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            cat = torch.cat([x, y, z], dim=1)
            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)

            self.assertEqual(cat, cat_cpu)

        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
        helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
        helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5])
        helper([2, 2, 6, 5], [0], [2, 5, 6, 5])
        helper([0], [2, 3, 6, 5], [2, 5, 6, 5])
        helper([2, 3, 4, 5], [2, 5, 4, 5], [0])
        helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5])
        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5])

    # Test stack forward
    def test_stack(self):
        # All shapes must be same
        def helper(shape, dtype=torch.float32):

            x, cpu_x = None, None
            y, cpu_y = None, None
            z, cpu_z = None, None

            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
                cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
                y = cpu_y.detach().clone().to('mps')
                cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
                z = cpu_z.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
                cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                y = cpu_y.detach().clone().to('mps')
                cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                z = cpu_z.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()
                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                y = cpu_y.detach().clone().to('mps').requires_grad_()
                cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                z = cpu_z.detach().clone().to('mps').requires_grad_()

            stack = torch.stack([x, y, z], dim=1)
            stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)

            self.assertEqual(stack, stack_cpu)

        helper([2, 8, 4, 5])
        helper([2, 8, 4, 5], dtype=torch.float16)
        helper([2, 8, 4, 5], dtype=torch.int32)
        helper([2, 8, 4, 5], dtype=torch.int64)
        helper([2, 8, 4, 5], dtype=torch.bool)
        # Empty test - Currently failing! Empty tensor not handled!
        # helper([0, 2, 4, 5])
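
    # Added illustrative sketch (test name is new): stack inserts a fresh
    # dimension while cat grows an existing one, so three (2, 8, 4, 5) inputs
    # give (2, 3, 8, 4, 5) under stack at dim=1 but (2, 24, 4, 5) under cat.
    def test_stack_vs_cat_shape_sketch(self):
        ts = [torch.zeros(2, 8, 4, 5, device='mps') for _ in range(3)]
        self.assertEqual(torch.stack(ts, dim=1).shape, torch.Size([2, 3, 8, 4, 5]))
        self.assertEqual(torch.cat(ts, dim=1).shape, torch.Size([2, 24, 4, 5]))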

    # Test abs
    def test_abs(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            abs_result = torch.abs(x)
            abs_result_cpu = torch.abs(cpu_x)

            self.assertEqual(abs_result, abs_result_cpu)

        helper((2, 8, 4, 5))

    def test_log(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_result = torch.log(x)
            log_result_cpu = torch.log(cpu_x)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_log_ten(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_ten_result = torch.log10(x)
            log_ten_result_cpu = torch.log10(cpu_x)

            self.assertEqual(log_ten_result, log_ten_result_cpu)

        helper((2, 8, 4, 5))

    def test_log_two(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_two_result = torch.log2(x)
            log_two_result_cpu = torch.log2(cpu_x)

            self.assertEqual(log_two_result, log_two_result_cpu)

        helper((2, 8, 4, 5))

    def test_log1p(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_result = torch.log1p(x)
            log_result_cpu = torch.log1p(cpu_x)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))
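
    # Added illustrative sketch (test name and values are new): log1p exists
    # for accuracy near zero — in float32, 1 + 1e-10 rounds back to exactly
    # 1.0, so log(1 + x) collapses to 0 while log1p keeps the leading term.
    def test_log1p_precision_sketch(self):
        x = torch.tensor([1e-10], device='mps')
        self.assertEqual(torch.log(1 + x).item(), 0.0)  # 1 + 1e-10 == 1.0 in float32
        self.assertTrue(torch.log1p(x).item() > 0.0)    # log1p(x) ~= 1e-10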

    def test_logaddexp(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            log_result = torch.logaddexp(x, y)
            log_result_cpu = torch.logaddexp(cpu_x, cpu_y)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_logaddexp2(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            log_result = torch.logaddexp2(x, y)
            log_result_cpu = torch.logaddexp2(cpu_x, cpu_y)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))

    def test_logsumexp(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            log_result = torch.logsumexp(x, -1)
            log_result_cpu = torch.logsumexp(cpu_x, -1)

            self.assertEqual(log_result, log_result_cpu)

        helper((2, 8, 4, 5))
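
    # Added illustrative sketch (test name and values are new): logaddexp and
    # logsumexp are the numerically stable forms of log(sum(exp(...))); a
    # naive evaluation overflows for inputs like 1000, while the shifted form
    # returns max + log(n) for n equal entries.
    def test_logsumexp_stability_sketch(self):
        x = torch.tensor([1000.0, 1000.0], device='mps')
        self.assertEqual(torch.logsumexp(x, dim=0).item(), 1000.0 + math.log(2.0))
        self.assertTrue(torch.isinf(torch.exp(x).sum()))  # the naive path overflows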

    # Test concat forward
    def test_cat2(self):

        def helper1(shape_x, shape_y, shape_z, shape_w):
            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False)
            w = cpu_w.detach().clone().to('mps')

            cat = torch.cat([x, y, z, w], dim=1)
            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1)

            self.assertEqual(cat, cat_cpu)

        def helper(shape_x, shape_y, shape_z):
            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
            y = cpu_y.detach().clone().to('mps')

            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
            z = cpu_z.detach().clone().to('mps')

            cat = torch.cat([x, y, z], dim=1)
            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)

            self.assertEqual(cat, cat_cpu)

        helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
        # Empty test - Currently failing! Empty tensor not handled!
        # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])

    # Test isnan
    def test_isnan(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            nan_index = [random.randrange(0, shape[0])]
            # make a selected row NaN
            cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan')))
            x = cpu_x.detach().clone().to('mps')

            isnan_result = torch.isnan(x)
            isnan_result_cpu = torch.isnan(cpu_x)

            self.assertEqual(isnan_result, isnan_result_cpu)

        helper((8, 2, 4, 5))

    # Test reciprocal
    def test_reciprocal(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            reciprocal_result = torch.reciprocal(x)
            reciprocal_result_cpu = torch.reciprocal(cpu_x)

            cpu_grad = torch.ones_like(reciprocal_result_cpu)
            grad = cpu_grad.to('mps')

            reciprocal_result.backward(gradient=grad)
            reciprocal_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(reciprocal_result, reciprocal_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
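
    # Added illustrative sketch (test name and values are new): with an
    # implicit ones grad_output, the backward of reciprocal should reproduce
    # the closed-form derivative d(1/x)/dx = -1/x**2.
    def test_reciprocal_grad_sketch(self):
        x = torch.tensor(2.0, device='mps', requires_grad=True)
        torch.reciprocal(x).backward()
        self.assertEqual(x.grad.item(), -0.25)  # -1 / 2**2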

    # Test sqrt
    def test_sqrt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sqrt_result = torch.sqrt(x)
            sqrt_result_cpu = torch.sqrt(cpu_x)

            cpu_grad = torch.ones_like(sqrt_result_cpu)
            grad = cpu_grad.to('mps')

            sqrt_result.backward(gradient=grad)
            sqrt_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(sqrt_result, sqrt_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))

    # Test selu, elu, celu
    def test_elu(self):
        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()

            x = cpu_x.detach().clone().to('mps').requires_grad_(True)
            for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
                elu_result = activation_func(x)
                elu_result_cpu = activation_func(cpu_x)

                cpu_grad = torch.randn(elu_result_cpu.shape)
                grad = cpu_grad.to('mps')

                elu_result.backward(gradient=grad)
                elu_result_cpu.backward(gradient=cpu_grad)

                self.assertEqual(elu_result, elu_result_cpu)
                self.assertEqual(x.grad, cpu_x.grad)

        for memory_format in [torch.channels_last, torch.contiguous_format]:
            for shape in [(2, 8, 4, 5)]:
                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
                    helper(shape, alpha, memory_format)
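
    # Added illustrative sketch (test name and values are new): ELU computes
    # x for x > 0 and alpha * (exp(x) - 1) otherwise; CELU and SELU above are
    # variants with a rescaled exponent and fixed constants respectively.
    def test_elu_formula_sketch(self):
        x = torch.tensor([-1.0, 2.0], device='mps')
        expected = torch.tensor([2.0 * (math.exp(-1.0) - 1.0), 2.0], device='mps')
        self.assertEqual(F.elu(x, alpha=2.0), expected)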

    def test_elu_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124834
        elu_input = torch.randn(1, 1024, 500)
        alpha = float(1)
        inplace = False

        elu_input_noncontiguous = elu_input.transpose(1, 2)
        self.assertEqual(
            F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace),
            F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
        )

    # Test glu
    def test_glu(self):
        def helper(shape, dim=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            for activation_func in [torch.nn.GLU(dim=dim)]:
                glu_result = activation_func(x)
                glu_result_cpu = activation_func(cpu_x)

                cpu_grad = torch.randn(glu_result_cpu.shape)
                grad = cpu_grad.to('mps')

                glu_result.backward(gradient=grad)
                glu_result_cpu.backward(gradient=cpu_grad)

                self.assertEqual(glu_result, glu_result_cpu)
                self.assertEqual(x.grad, cpu_x.grad)

        for shape in [[4], (2, 4), (2, 8, 4, 6)]:
            for dim in range(len(shape)):
                helper(shape, dim)
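
    # Added illustrative sketch (test name is new): GLU halves the chosen
    # dimension and gates the first half with the sigmoid of the second,
    # glu(x) = a * sigmoid(b), which is why every tested dim size is even.
    def test_glu_formula_sketch(self):
        x = torch.randn(4, 6, device='mps')
        a, b = x.chunk(2, dim=1)
        self.assertEqual(F.glu(x, dim=1), a * torch.sigmoid(b))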

    # Test softplus
    def test_softplus(self):
        def helper(shape, beta, threshold, dtype):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)

            cpu_grad = torch.randn(softplus_result.shape)
            grad = cpu_grad.to('mps')

            softplus_result.backward(gradient=grad)
            softplus_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softplus_result, softplus_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        # Test empty shape too
        for shape, beta, threshold, dtype in product(
            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
            [0.5, 1, 2, 3, 4],
            [0.5, 20, 30, 40, 50],
            [torch.float16, torch.float32]
        ):
            helper(shape, beta, threshold, dtype)
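
    # Added illustrative sketch (test name and values are new): once beta * x
    # exceeds the threshold, Softplus switches to the identity for numerical
    # stability, so a large input passes through unchanged.
    def test_softplus_threshold_sketch(self):
        sp = torch.nn.Softplus(beta=1, threshold=20)
        x = torch.tensor([25.0], device='mps')
        self.assertEqual(sp(x), x)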
        # Test empty shape too
        for shape in [[], (2, 3), (2, 8, 4, 5)]:
            helper(shape)

    def test_cast_mps_to_cpu(self):
        def helper(src_dtype, dst_dtype):
            input = torch.rand((1, 3, 128, 128), dtype=src_dtype)
            input_cast_mps = input.to('mps')
            input_cast_cpu = input_cast_mps.to('cpu', dtype=dst_dtype)

            # needs to match the initial Tensor
            self.assertEqual(input_cast_cpu, input.to(dtype=dst_dtype))
        helper(torch.half, torch.float)
        helper(torch.float, torch.half)

    def test_cast_mps_to_mps(self):
        def helper(src_dtype, dst_dtype):
            input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
            input_mps = input_cpu.to('mps')
            output_mps = input_mps.to(dtype=dst_dtype)
            output_cpu = input_cpu.to(dtype=dst_dtype)
            self.assertEqual(output_mps.cpu(), output_cpu)
        helper(torch.half, torch.float)
        helper(torch.float, torch.half)
        helper(torch.half, torch.long)
        helper(torch.float, torch.int)

    def test_avg_pool2d_count_include_pad(self):
        cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()
        pool = torch.nn.AvgPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), ceil_mode=True, count_include_pad=True)
        ref_y = pool(cpu_x)
        y = pool(x)
        self.assertEqual(y, ref_y)
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)
        self.assertEqual(x.grad, cpu_x.grad)

    # Test adaptive avg pool2d - when the input size is a multiple of output size
    # Not testing for channels last right now
    def test_adaptive_avg_pool2d_simple(self):
        def helper(input_shape, out_shape, channels_last):
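            # Runs AdaptiveAvgPool2d forward and backward on CPU and MPS for one
            # input/output shape pair; channels_last additionally exercises the
            # NHWC memory layout.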
            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
            avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)

            cpu_grad = torch.randn(avg_result_cpu.shape)
            grad = cpu_grad.to('mps')

            avg_result.backward(gradient=grad)
            avg_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(avg_result, avg_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 2, 4, 4), (2, 2), False)
        helper((2, 2, 9, 9), (3, 3), False)
        helper((2, 2, 9, 9), (9, 9), False)
        helper((2, 2, 16, 16), (2, 2), False)
        helper((2, 2, 16, 16), (2, 16), False)

        helper((2, 16, 16), (4, 4), False)

        # Output shape larger than input shape
        helper((2, 2, 4, 4), (8, 8), False)
        helper((2, 2, 2, 2), (4, 4), False)
        helper((2, 2, 3, 3), (9, 9), False)
        helper((2, 2, 2, 2), (16, 16), False)
        helper((2, 2, 2, 16), (16, 16), False)

        helper((2, 4, 4), (16, 16), False)

        # Output sizes that don't evenly divide the input may raise on MPS;
        # tolerate either behavior.
        try:
            helper((2, 2, 3, 3), (7, 7), False)
        except Exception:
            pass

    # Test adaptive max pool2d - when the input size is a multiple of output size
    # Not testing for channels last right now
    def test_adaptive_max_pool2d_simple(self):
        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
            cpu_x = None
            if dtype in [torch.float16, torch.float32]:
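                # randn only produces floating-point values; integral dtypes take
                # the randint branch below.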
                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
            else:
                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
            if channels_last:
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            max_result, max_indices = None, None
            max_result_cpu, max_indices_cpu = None, None

            if return_indices:
                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
            else:
                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)

            cpu_grad = torch.randn(max_result_cpu.shape)
            grad = cpu_grad.to('mps')

            max_result.backward(gradient=grad)
            max_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(max_result, max_result_cpu)
            if return_indices:
                self.assertEqual(max_indices, max_indices_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dtype in [torch.float32]:
            for return_indices in [False, True]:
                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
                helper((2, 16, 16), (4, 4), return_indices, dtype)

    def test_gelu_simple(self):
        def helper(shape, dtype=torch.float, contiguous=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')
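            # The non-contiguous case below relies on transpose returning a
            # strided view (no copy), so the same storage backs both layouts.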
            if not contiguous and (0 not in shape and len(shape) >= 2):
                # Transposing will make the tensor non-contiguous
                cpu_x = cpu_x.transpose(0, 1)
                x = x.transpose(0, 1)
                assert not x.is_contiguous()

            cpu_x.requires_grad_()
            x.requires_grad_()

            gelu_result = torch.nn.GELU()(x)
            # GELU is not supported for half on CPU, so compute the CPU reference in float
            gelu_result_cpu = torch.nn.GELU()(cpu_x.to(torch.float))

            cpu_grad = torch.ones_like(gelu_result_cpu)
            grad = cpu_grad.to('mps')

            gelu_result.backward(gradient=grad)
            gelu_result_cpu.backward(gradient=cpu_grad)

            atol = 1e-5 if dtype == torch.float else 1e-2
            rtol = 1e-3 if dtype == torch.float else 1e-2
            self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)

            assert x.grad is not None  # Check that the grad is well-populated
            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

        # Test empty shape too
        for dtype in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    helper(shape, dtype, contiguous)
        # Test that gelu raises an error for integral types
        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))

    def test_mish_simple(self):
        def helper(shape, dtype=torch.float, contiguous=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            if not contiguous and (0 not in shape and len(shape) >= 2):
                # Transposing will make the tensor non-contiguous
                cpu_x = cpu_x.transpose(0, 1)
                x = x.transpose(0, 1)
                assert not x.is_contiguous()

            cpu_x.requires_grad_()
            x.requires_grad_()

            mish_result = torch.nn.Mish()(x)
            mish_result_cpu = torch.nn.Mish()(cpu_x)

            cpu_grad = torch.ones_like(mish_result_cpu)
            grad = cpu_grad.to('mps')

            mish_result.backward(gradient=grad)
            mish_result_cpu.backward(gradient=cpu_grad)

            atol = 1e-5 if dtype == torch.float else 1e-2
            rtol = 1e-3 if dtype == torch.float else 1e-2
            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)

            assert x.grad is not None  # Check that the grad is well-populated
            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

        # Test empty shape too
        for dtype in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    helper(shape, dtype, contiguous)

    def test_gelu(self):
        def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
            numpy_dtype = {
                torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
            }[dtype]
            devices = ['cpu']
            devices += ['mps']

            def _gelu_ref(X):
                return X * stats.norm.cdf(X)  # noqa: F821 ('stats' is undefined; helper is currently unused)

            for d in devices:
                X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
                res = X
                ref = X.to(numpy_dtype).cpu().detach().numpy()
                self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
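
        # Note: the X[:, ::2] slice above makes the input non-contiguous, so the
        # sweep below covers strided inputs of assorted sizes on both devices.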
        for n in [1, 5, 10]:
            for m in [1, 5, 10]:
                _test_gelu(n, m, torch.float32, True)
                _test_gelu(n, m, torch.float32, False)

        # Test multi threaded
        num_threads = torch.get_num_threads()
        torch.set_num_threads(4)
        try:
            _test_gelu(32, 32, torch.float32, False)
        finally:
            torch.set_num_threads(num_threads)

    def test_gelu_tanh(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            x = cpu_x.detach().clone().to('mps')

            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)

        helper((2, 8, 4, 5))

    # Test hardtanh
    def test_hardtanh(self):
        def helper(shape, min_val, max_val, inplace=False):
            cpu_x = None
            x = None

            if not inplace:
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')

            hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
            hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)

            self.assertEqual(hardtanh_result, hardtanh_result_cpu)

            if not inplace:
                cpu_grad = torch.randn(hardtanh_result_cpu.shape)
                grad = cpu_grad.to('mps')
                hardtanh_result.backward(gradient=grad)
                hardtanh_result_cpu.backward(gradient=cpu_grad)
                self.assertEqual(x.grad, cpu_x.grad)
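
        # Hardtanh clamps to [min_val, max_val]; its gradient is 1 inside the
        # band and 0 outside, so CPU and MPS gradients should match exactly.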
        # Test empty shape too
        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
            for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]):
                helper(shape, min_val, max_val)
                helper(shape, min_val, max_val, inplace=True)

    def test_hardswish(self):
        def helper(shape, inplace=False, requires_grad=True):
            m = nn.Hardswish(inplace=inplace)

            input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

            if inplace and requires_grad:  # check that both raise runtime error
                self.assertRaises(RuntimeError, lambda: m(input_cpu))
                self.assertRaises(RuntimeError, lambda: m(input_mps))
                return

            output_cpu = m(input_cpu)
            output_mps = m(input_mps)

            cpu_grad = torch.ones_like(output_cpu)
            mps_grad = cpu_grad.to('mps')

            self.assertEqual(output_cpu, output_mps)

            if requires_grad:
                output_cpu.backward(gradient=cpu_grad)
                output_mps.backward(gradient=mps_grad)

                self.assertEqual(input_cpu.grad, input_mps.grad)

        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
            helper(shape, inplace=False, requires_grad=False)
            helper(shape, inplace=True, requires_grad=False)
            helper(shape, inplace=False, requires_grad=True)
            helper(shape, inplace=True, requires_grad=True)

    def test_transpose_2D(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')
        mps_x1 = torch.tensor(values1, device='mps')
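
        # transpose returns a view; copying the MPS result to the CPU with
        # .to('cpu') materializes it for comparison.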
        cpu_transpose = torch.transpose(cpu_x, 0, 1)
        mps_transpose = torch.transpose(mps_x, 0, 1)
        self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))

    def test_transpose_3D(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
        self.assertEqual(cpu_transpose1, mps_transpose1)

        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
        self.assertEqual(cpu_transpose2, mps_transpose2)

        cpu_transpose3 = torch.transpose(cpu_x, 1, 2)
        mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu')
        self.assertEqual(cpu_transpose3, mps_transpose3)

    def test_transpose_4D(self):
        values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
                  [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
        self.assertEqual(cpu_transpose1, mps_transpose1)

        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
        self.assertEqual(cpu_transpose2, mps_transpose2)

        cpu_transpose3 = torch.transpose(cpu_x, 0, 3)
        mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu')
        self.assertEqual(cpu_transpose3, mps_transpose3)

        cpu_transpose4 = torch.transpose(cpu_x, 3, 1)
        mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu')
        self.assertEqual(cpu_transpose4, mps_transpose4)

        cpu_transpose5 = torch.transpose(cpu_x, 3, 2)
        mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu')
        self.assertEqual(cpu_transpose5, mps_transpose5)

        cpu_transpose6 = torch.transpose(cpu_x, 1, 2)
        mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu')
        self.assertEqual(cpu_transpose6, mps_transpose6)

    # Test sign
    def test_sign(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            sign_result = torch.sign(x)
            sign_result_cpu = torch.sign(cpu_x)

            cpu_grad = torch.ones_like(sign_result_cpu)
            grad = cpu_grad.to('mps')

            sign_result.backward(gradient=grad)
            sign_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(sign_result, sign_result_cpu)

        helper((2, 8, 4, 5))

    def test_signbit(self):
        def helper(shape, dtype):
            cpu_x = torch.randn(shape, device='cpu').to(dtype)
            x = cpu_x.clone().to('mps')

            signbit_result = torch.signbit(x)
            signbit_result_cpu = torch.signbit(cpu_x)

            self.assertEqual(signbit_result, signbit_result_cpu)

        helper((2, 8, 4, 5), torch.int)
        helper((2, 8, 4, 5), torch.float)
        helper((2, 8, 4, 5), torch.int64)

    # Test neg
    def test_neg(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
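            # The usual pattern: detach() drops the autograd history, clone()
            # makes an independent copy, and requires_grad_() turns the MPS copy
            # into its own leaf tensor so each device accumulates its own .grad.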
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            neg_result = torch.neg(x)
            neg_result_cpu = torch.neg(cpu_x)

            cpu_grad = torch.ones_like(neg_result_cpu)
            grad = cpu_grad.to('mps')

            neg_result.backward(gradient=grad)
            neg_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(neg_result, neg_result_cpu)

        helper((2, 8, 4, 5))

    def test_neg_strided_input(self):
        # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
        x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
        y = x.permute(1, 0, 2)[..., 1]
        z = y + y.neg()
        self.assertEqual(z.abs().max().item(), 0.0)

    # Test index add
    def test_index_add(self):
        def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
            cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
            source = cpu_source.detach().clone().to('mps')

            idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
            idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
            self.assertEqual(idx_result, idx_result_cpu)
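
        # index_add accumulates alpha-scaled slices of `source` into x along
        # `dim`: x.select(dim, index[i]) += alpha * source.select(dim, i), with
        # repeated indices (e.g. [0, 1, 0] below) accumulating into the same slice.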
        helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5)
        helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0)
        helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5)
        helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0)
        helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4)
        helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0)
        # test result dim=1
        helper((2,), 0, [1], (1,), 6.0)
        helper(2, 0, 1, 1, 6)
        # test float16
        helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)

    def test_index_64bit(self):
        """ Test that index operations work for 4Gb+ tensors """
        if product_version < 14.0:
            raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
        # Cleanup memory
        gc.collect()
        torch.mps.empty_cache()
        # Check that index operations work for 4+GB tensors
        x = torch.rand(16000, 67120, device="mps")
        self.assertGreater(x.element_size() * x.numel(), 2**32)
        idx = torch.arange(0, 2, device="mps")
        x_sampled = x[:, idx]
        self.assertEqual(x[:, 0], x_sampled[:, 0])
        # Reclaim memory after running the tests
        del x
        gc.collect()
        torch.mps.empty_cache()

    def test_mm_large(self):
        """ Test that MM works for matrices with index larger than 32K """
        x = torch.rand(10, 1, device="mps")
        y = torch.rand(1, 32769, device="mps")
        # This used to crash with:
        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)

        def compare_mm(m, n, k, dtype=torch.float):
            x = torch.rand(m, n, device="mps", dtype=dtype)
            y = torch.rand(n, k, device="mps", dtype=dtype)
            z = torch.mm(x, y).cpu()
            z_cpu = torch.mm(x.cpu(), y.cpu())
            self.assertEqual(z, z_cpu)

        # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
        compare_mm(1024, 1, 32769)
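        # 32769 == 2**15 + 1 deliberately crosses the 32K dimension boundary
        # where MPS apparently split the matmul incorrectly (see the issues
        # linked above and below).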
        # one more time, but with dimensions inverted
        # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
        compare_mm(32769, 1, 1025)

        if product_version >= 14.0:
            # Test bfloat16 mm
            compare_mm(1024, 1, 32769, torch.bfloat16)

    @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
    @unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
    def test_copy_large(self):
        """ Test that copy of 4Gb+ tensors works """
        x = torch.ones((2**30 + 11,), dtype=torch.float32)
        y = x.to(device="mps")
        self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps")))
        del y
        del x

    # Test flip
    def test_flip(self):
        def helper(shape, dims):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            flip_result = torch.flip(x, dims=dims)
            flip_result_cpu = torch.flip(cpu_x, dims=dims)

            self.assertEqual(flip_result, flip_result_cpu)

        helper((2, 8, 4, 5), [0])
        helper((8, 8, 4, 5), [0, 1])
        helper((2, 8, 4, 5), (0, 1, 2, 3))
        helper((2, 3, 3), (-1,))
        # empty dims
        helper((2, 8, 4, 5), [])
        # input.numel() == 1
        helper((1,), (0,))
        # input.numel() == 0
        helper((0,), (0,))
        # flipped dim has size 1, so the flip is a no-op
        helper((1, 3), [0])

    # Test index select
    def test_index_select(self):
        def helper(shape, dim, index, idx_dtype=torch.int32):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            idx_result = torch.index_select(x, dim=dim, index=idx)
            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

            self.assertEqual(idx_result, idx_result_cpu)

        helper((2, 8, 4, 5), 0, [1])
        helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6])
        helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6])
        helper((2, 8, 4, 5), 2, [3, 0, 1])
        helper((2, 8, 4, 5), 3, [2, 3, 0])
        helper((2, 3, 3), -1, [1, 2])
        helper((), 0, [0])
        helper((5), 0, [])

    def test_index_select_scalar(self):
        def helper(value, dim, index, idx_dtype=torch.int32):
            cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            idx_result = torch.index_select(x, dim=dim, index=idx)
            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

            self.assertEqual(idx_result, idx_result_cpu)

        helper(22, 0, [0])
        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
            helper(22, 0, [])

    def test_embedding_dense_backward(self):
        def helper(n, d, m, idx):
            embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
            embedding_weight = embeddingMPS.weight.detach().cpu()
            W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
            idx_MPS = torch.tensor(idx, device='mps')
            a_MPS = embeddingMPS.weight.clone() @ W_MPS.t()  # weight must be cloned for this to be differentiable
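            # a_MPS/b_MPS are non-leaf tensors, so retain_grad() is required for
            # their .grad to be kept after backward().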
            a_MPS.retain_grad()
            b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t()  # modifies weight in-place
            b_MPS.retain_grad()
            out_MPS = a_MPS.unsqueeze(0) + b_MPS
            loss_MPS = out_MPS.sigmoid().prod()
            loss_MPS.backward()

            embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=embedding_weight)
            W_CPU = W_MPS.to('cpu')
            idx_CPU = torch.tensor(idx)
            a_CPU = embeddingCPU.weight.clone() @ W_CPU.t()  # weight must be cloned for this to be differentiable
            a_CPU.retain_grad()
            b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t()  # modifies weight in-place
            b_CPU.retain_grad()
            out_CPU = a_CPU.unsqueeze(0) + b_CPU
            loss_CPU = out_CPU.sigmoid().prod()
            loss_CPU.backward()

            self.assertEqual(b_CPU.grad, b_MPS.grad)
            self.assertEqual(a_CPU.grad, a_MPS.grad)

        helper(3, 5, 7, [0, 1, 2])
        helper(3, 6, 7, [0, 1, 2])  # verify that changes in shape don't cause cached graph lookup problems
        helper(3, 5, 7, 2)  # test scalar index

    # Test pytorch gather
    def test_gather(self):
        def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from the range of the axis along which gathering is done
            idx_np = np.random.randint(0, shape[dim], idx_shape)

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            gather_result = torch.gather(x, dim=dim, index=idx)
            gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)

            cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')
            gather_result.backward(gradient=grad)
            gather_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(gather_result, gather_result_cpu)
            self.assertEqual(cpu_x.grad, x.grad)

        helper((6, 3, 3), 0, (3, 3, 3))
        helper((2, 3, 3, 3), 0, (10, 3, 3, 3))
        helper((2, 8, 4, 5), 0, (10, 8, 4, 5))
        helper((2, 8, 4, 5), 0, (10, 6, 3, 2))
        helper((8, 8, 4, 5), 0, (6, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (6, 7, 2, 3))
        helper((2, 8, 4, 5), 1, (2, 5, 3, 4))
        helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
        helper((2, 8, 4, 5), 3, (2, 5, 3, 12))

    # Test pytorch gather for scalar input
    def test_gather_scalar(self):
        idx_dtype = torch.int64
        cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        idx_np = [0]

        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        gather_result = torch.gather(x, dim=0, index=idx)
        gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)

        cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
        grad = cpu_grad.to('mps')
        gather_result.backward(gradient=grad)
        gather_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(gather_result, gather_result_cpu)
        self.assertEqual(cpu_x.grad, x.grad)

    # Test pytorch scatter_add and scatter
    def test_scatter_add(self):
        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from the range of the axis along which scattering is done
            idx_np = None
            if do_add:
                idx_np = np.random.randint(0, shape[dim], idx_shape)
            else:
                idx_np = np.array([[0, 1, 2],
                                   [1, 2, 3],
                                   [2, 3, 4],
                                   [3, 4, 5],
                                   [4, 5, 6]])

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = None
            scatter_result_cpu = None

            if do_add:
                scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src)
                scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
            else:
                scatter_result = torch.scatter(x, dim=dim, index=idx, src=src)
                scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)

            cpu_grad = None
            grad = None

            if idx_shape == src_shape:
                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
                grad = cpu_grad.to('mps')
                scatter_result.backward(gradient=grad)
                scatter_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(scatter_result, scatter_result_cpu)
            if idx_shape == src_shape:
                self.assertEqual(cpu_x.grad, x.grad)
                self.assertEqual(cpu_src.grad, src.grad)

        helper((2, 3), 0, (5, 3), (5, 3))
        helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
        helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2))
        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2))
        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5))

        helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5))
        helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2))
        helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3))
        helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3))

        helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8))
        helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6))
        helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6))

        # Test scatter src
        helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
        helper((10, 3), 0, (5, 3), (5, 8), do_add=False)

    # Test pytorch scatter_add and scatter for scalar input
    def test_scatter_add_scalar(self):
        def helper(idx_dtype=torch.int64, do_add=True):
            cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from the range of the axis along which scattering is done
            idx_np = [0]

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = None
            scatter_result_cpu = None

            if do_add:
                scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
                scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
            else:
                scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
                scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)

            cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')
            scatter_result.backward(gradient=grad)
            scatter_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(scatter_result, scatter_result_cpu)
            self.assertEqual(cpu_x.grad, x.grad)
            self.assertEqual(cpu_src.grad, src.grad)

        helper()
        helper(do_add=False)

    # Test pytorch scatter_reduce
    def test_scatter_reduce(self):
        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
            src = cpu_src.detach().clone().to('mps').requires_grad_()

            # Indices should be taken from the range of the axis along which gathering is done
            idx_np = np.random.randint(0, shape[dim], idx_shape)

            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
            idx = cpu_idx.detach().clone().to('mps')

            scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
            scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)

            self.assertEqual(scatter_result, scatter_result_cpu)

        # for reduce in ["sum", "prod", "amax", "amin"]:
        for reduce_type in ["add", "multiply"]:
            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)

            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)

            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)
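
    # Illustrative sketch (not collected by unittest since the name lacks the
    # 'test_' prefix): spells out the scatter_add semantics the tests above
    # exercise on a tiny, hand-checkable example. Assumes an MPS device is
    # available, like every test in this class.
    def _example_scatter_add_semantics(self):
        # For dim=0: out[idx[i][j]][j] += src[i][j]
        base = torch.zeros(3, 2, device='mps')
        src = torch.ones(2, 2, device='mps')
        idx = torch.tensor([[0, 1], [0, 1]], device='mps')
        out = torch.scatter_add(base, 0, idx, src)
        # Column 0 of src accumulates into row 0, column 1 into row 1:
        expected = torch.tensor([[2., 0.], [0., 2.], [0., 0.]], device='mps')
        self.assertEqual(out, expected)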

    def test_is_nonzero(self):
        self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
        self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
        self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
        self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))

    # Test triu
    def test_triu(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            triu_result = torch.triu(x, diag)
            triu_result_cpu = torch.triu(cpu_x, diag)

            cpu_grad = torch.randn(triu_result_cpu.shape)
            grad = cpu_grad.to('mps')

            triu_result.backward(gradient=grad)
            triu_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(triu_result, triu_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
        helper((2, 8, 4, 5), diag=1)
        helper((2, 8, 4, 5), diag=2)
        helper((2, 8, 4, 5), diag=3)
        helper((2, 8, 4, 5), diag=-1)
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # Test inverse
    def test_inverse(self):
        def helper(n):
            cpu_input = torch.randn(n, n, device='cpu')
            mps_input = cpu_input.to('mps')

            cpu_result = torch.linalg.inv(cpu_input)
            mps_result = torch.linalg.inv(mps_input)
            self.assertEqual(cpu_result, mps_result)

        helper(2)
        helper(6)
        helper(3)
        helper(8)
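
    # Illustrative sketch (not collected by unittest): the property behind the
    # CPU/MPS comparison above is the defining identity A @ inv(A) == I, checked
    # here up to floating-point error. A random Gaussian matrix is almost surely
    # invertible, but its conditioning varies, hence the loose tolerance.
    def _example_inverse_identity(self):
        a = torch.randn(4, 4, device='mps')
        a_inv = torch.linalg.inv(a)
        eye = torch.eye(4, device='mps')
        self.assertEqual(a @ a_inv, eye, rtol=1e-3, atol=1e-3)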

    # Test tril
    def test_tril(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            tril_result = torch.tril(x, diag)
            tril_result_cpu = torch.tril(cpu_x, diag)

            cpu_grad = torch.randn(tril_result_cpu.shape)
            grad = cpu_grad.to('mps')

            tril_result.backward(gradient=grad)
            tril_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(tril_result, tril_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 8, 4, 5))
        helper((2, 8, 4, 5), diag=1)
        helper((2, 8, 4, 5), diag=2)
        helper((2, 8, 4, 5), diag=3)
        helper((2, 8, 4, 5), diag=-1)
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # Test eye
    def test_eye(self):
        def helper(n, m, dtype):
            cpu_result = None
            result = None

            if (n == m):
                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
                result = torch.eye(n, dtype=dtype, device='mps')
            else:
                # Note: the non-square branch does not exercise dtype
                cpu_result = torch.eye(n, m, device='cpu')
                result = torch.eye(n, m, device='mps')

            self.assertEqual(result, cpu_result)

        for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
            helper(2, 2, dtype)
            helper(2, 3, dtype)
            helper(0, 2, dtype)
            helper(0, 0, dtype)
            helper(3, 8, dtype)
            helper(8, 3, dtype)

    # Test diag
    def test_diag(self):
        def helper(shape, diag=0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            diag_result = torch.diag(x, diag)
            diag_result_cpu = torch.diag(cpu_x, diag)

            # Backward is not exercised here; only the forward result is compared.
            # cpu_grad = torch.randn(diag_result_cpu.shape)
            # grad = cpu_grad.to('mps')

            # diag_result.backward(gradient=grad)
            # diag_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(diag_result, diag_result_cpu)
            # self.assertEqual(x.grad, cpu_x.grad)

        for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]:
            for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
                helper(shape, diag=diag)

    # Test linspace
    def test_linspace(self):
        def helper(start, end, steps, dtype=torch.float32):
            cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
            result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
            self.assertEqual(cpu_result, result)

        for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
            helper(2, 5, 10, dtype)
            helper(2, 2, 10, dtype)
            helper(5, 2, 10, dtype)
            helper(2, 2, 0, dtype)
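
    # Illustrative sketch (not collected by unittest): unlike arange below,
    # linspace includes both endpoints, which is why the reference above is
    # np.linspace rather than np.arange. Assumes an MPS device.
    def _example_linspace_endpoints(self):
        out = torch.linspace(0, 1, 5, device='mps')
        expected = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0], device='mps')
        self.assertEqual(out, expected)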

    # Test arange
    def test_arange(self):
        self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
        self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
        self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))

    def test_arange_empty(self):
        out_mps = torch.tensor([], device="mps")
        out_cpu = torch.tensor([], device="cpu")

        y_mps = torch.arange(0, 0, 1, out=out_mps)
        y_cpu = torch.arange(0, 0, 1, out=out_cpu)
        self.assertEqual(y_mps, y_cpu)

    # Test range
    def test_range(self):
        self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
        self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
        self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))

    # Test softmax
    def test_softmax(self):
        def helper(shape, dim, channels_last=False):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            if (channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            # Currently NOT testing backward for channels last
            cpu_grad = None
            grad = None

            if (not channels_last):
                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
                grad = cpu_grad.to('mps')

                softmax_result.backward(gradient=grad)
                softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            if (not channels_last):
                self.assertEqual(x.grad, cpu_x.grad)

        def helper2(dim):
            cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')

            softmax_result.backward(gradient=grad)
            softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper2(0)

        for channels_last in [False]:
            for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
                if (len(shape) != 4 and channels_last):
                    continue
                for dim in [0, 1, 2, 3, -1, -2, -3]:
                    helper(shape, dim, channels_last)

    def test_nan_to_num(self):
        inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
        outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0)
        outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0)
        self.assertEqual(outputMPS, outputCPU)

    # Test where
    def test_where(self):
        def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):

            cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
            cond = cpu_cond.detach().clone().to('mps')

            cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            y = cpu_y.detach().clone().to('mps').requires_grad_()

            cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
            out = torch.where(cond, x, y)

            cpu_grad = torch.randn(cpu_out.shape)
            grad = cpu_grad.to('mps')

            cpu_out.backward(gradient=cpu_grad)
            out.backward(gradient=grad)

            self.assertEqual(out, cpu_out)
            self.assertEqual(x.grad, cpu_x.grad)
            self.assertEqual(y.grad, cpu_y.grad)

        for shape in ([(0, 3), [], (2, 3), (9,)]):
            helper(shape, shape, shape)

        helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
        helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
        helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
        helper([], (1, 1, 4), (2, 3, 1))
        helper([], (2, 3, 4), [])
        helper((5, 2, 3), (2, 3), (2, 3))
        helper((2, 3), (5, 2, 3), (2, 3))
        helper((2, 3), (2, 3), (5, 2, 3))
        helper((2, 3), (5, 2, 3), (6, 5, 2, 3))
        # Test that the output is correctly resized
        # TODO: Remove me when out OpInfo testing is enabled on MPS
        output = torch.tensor(0.0, device="mps")
        cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
        inp = torch.rand(3, 3, device="mps")
        other = torch.rand(3, 3, device="mps")
        out = torch.where(cond, inp, other, out=output)
        self.assertEqual(id(out), id(output))
        self.assertEqual(out.shape, (3, 3))
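
    # Illustrative sketch (not collected by unittest): torch.where broadcasts
    # the condition and both branches to a common shape, which is what the
    # mixed shape triples above exercise. Assumes an MPS device.
    def _example_where_broadcasting(self):
        cond = torch.tensor([[True], [False]], device='mps')   # (2, 1)
        x = torch.zeros(2, 3, device='mps')                    # (2, 3)
        y = torch.ones(1, 3, device='mps')                     # (1, 3)
        out = torch.where(cond, x, y)                          # broadcasts to (2, 3)
        expected = torch.tensor([[0., 0., 0.], [1., 1., 1.]], device='mps')
        self.assertEqual(out, expected)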

    # Test normal
    def test_normal(self):
        def helper(shape, mean=0.0, std=1.0):
            mps_out = torch.normal(mean, std, shape, device='mps')

            mean_array = np.ones(shape)
            mean_array *= mean
            cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
            mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

            std_array = np.ones(shape)
            std_array *= std
            cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
            std_tensor = cpu_std_tensor.detach().clone().to('mps')

            # test out
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean, std_tensor, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std_tensor, out=mps_out)

            # test without out
            mps_out = torch.normal(mean_tensor, std)
            self.assertEqual(mps_out.size(), mean_tensor.size())

            mps_out = torch.normal(mean, std_tensor)
            self.assertEqual(mps_out.size(), std_tensor.size())

            inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
            mps_out = torch.normal(mean_tensor, std_tensor)
            self.assertEqual(mps_out.size(), inferred_shape)

        helper((2, 3, 4, 5, 6))
        helper((100, 100), 2.5, 1.2)

    def test_bernoulli(self):
        shape = (10, 10)
        all_ones = torch.ones(shape, device='mps')
        all_zeros = torch.zeros(shape, device='mps')

        prob_tensor = all_ones * 0.5
        # probability of drawing "1" is 0.5
        mps_out = torch.bernoulli(prob_tensor)
        # We can't reliably check the mean and std.
        # Just make sure we don't return constant values
        self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
        self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)

        # probability of drawing "1" is 0
        mps_out = torch.bernoulli(all_zeros)
        self.assertEqual(mps_out, all_zeros)

        # probability of drawing "1" is 1
        mps_out = torch.bernoulli(all_ones)
        self.assertEqual(mps_out, all_ones)

        # Check it works for different dtypes
        for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
            mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
            # Check that output is not all zeros or ones
            if product_version > 13.0:
                uniq = mps_out.unique()
                self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
            else:
                self.assertEqual(mps_out.min().item(), 0.)
                self.assertEqual(mps_out.max().item(), 1.)
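
    # Illustrative sketch (not collected by unittest): for Bernoulli(p) the
    # sample mean converges to p, so a large draw at p=0.5 should land well
    # inside a loose band (the std of the sample mean at n=10000 is 0.005,
    # so 0.45..0.55 is a ~10-sigma window). Assumes an MPS device.
    def _example_bernoulli_mean(self):
        samples = torch.bernoulli(torch.full((10000,), 0.5, device='mps'))
        self.assertTrue(0.45 < samples.mean().item() < 0.55)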

    def test_mps_generator(self):
        # explicit manual seeding by creating an MPS Generator
        g_mps = torch.Generator(device='mps')
        g_mps.manual_seed(999)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        g_mps.manual_seed(999)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps', generator=g_mps)
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)
        # save the generator's state (offset = 1) to restore it later
        g_state = g_mps.get_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)

        # mps_x was generated from the saved state, so use it as the reference mps_y
        mps_y = mps_x

        # restore the previously saved state, and the results should match again
        g_mps.set_state(g_state)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        self.assertEqual(mps_x, mps_y)

    @serialTest()
    def test_default_mps_generator(self):
        # manual seeding on the "default" MPS generator using
        # the global torch.manual_seed()
        torch.manual_seed(230)
        mps_x = torch.randn(5, device='mps')
        # manual seeding using torch.mps.manual_seed()
        # which should set the "default" MPS generator
        # like the global torch.manual_seed()
        torch.mps.manual_seed(230)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps')
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)

        # save the default generator's state (offset = 1) to restore it later
        g_state = torch.mps.get_rng_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps')
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)
        # since we called randn twice after seeding, the offset should be 2
        self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)

        # mps_x was generated from the saved state, so use it as the reference mps_y
        mps_y = mps_x

        # restore the previously saved state to the "default" MPS generator, and the results should match again
        torch.mps.set_rng_state(g_state)
        mps_x = torch.randn(5, device='mps')
        self.assertEqual(mps_x, mps_y)

    def test_device_synchronize(self):
        # just running some ops each followed by a synchronize to wait for
        # the MPS stream to finish running each of them
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device='mps', dtype=torch.float)

        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        torch.mps.synchronize()
        x = net1(x)
        torch.mps.synchronize()
        x.backward(torch.randn_like(x))
        torch.mps.synchronize()

    @serialTest()
    def test_mps_allocator_module(self):
        # first garbage collect and empty the cached blocks
        gc.collect()
        torch.mps.empty_cache()
        # measure memory allocations from MPSAllocator
        current_alloc_before = torch.mps.current_allocated_memory()
        # after garbage collection and emptying the cache the
        # current_allocated_memory must be zero
        self.assertEqual(current_alloc_before, 0)
        # measure total memory allocations from the Metal driver
        driver_alloc_before = torch.mps.driver_allocated_memory()
        # allocate a new 8M-element float32 tensor (32 MB) to force allocation of a new Metal Heap
        x = torch.ones(1024 * 1024 * 8, device="mps")
        # get memory allocations after allocating tensor x
        current_alloc_after = torch.mps.current_allocated_memory()
        driver_alloc_after = torch.mps.driver_allocated_memory()
        # current and driver memory allocations must have
        # grown at this point
        self.assertGreater(current_alloc_after, current_alloc_before)
        self.assertGreater(driver_alloc_after, driver_alloc_before)
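
    # Illustrative sketch (not collected by unittest): the expected footprint of
    # the tensor allocated above is 1024 * 1024 * 8 = 8M float32 elements at
    # 4 bytes each, i.e. 32 MB, so current_allocated_memory() should grow by at
    # least that much (the allocator may round up to its heap granularity).
    def _example_allocator_accounting(self):
        torch.mps.empty_cache()
        before = torch.mps.current_allocated_memory()
        x = torch.ones(1024 * 1024 * 8, device='mps')  # 32 MB of float32
        after = torch.mps.current_allocated_memory()
        self.assertGreaterEqual(after - before, 32 * 1024 * 1024)
        del x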

    def test_mps_allocator_stats(self):
        max_memory = torch.mps.recommended_max_memory()
        print(f"Recommended Max Memory : {max_memory / 1024 ** 3} GB")
        self.assertGreater(max_memory, 0)

    # to verify this test, run the Xcode Instruments "Metal System Trace" or "Logging" tool,
    # press record, then run this python test, and press stop. Next expand
    # the os_signposts->PyTorchMPS and check if events or intervals are logged
    # like this example:
    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
    def test_mps_profiler_module(self):
        with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
            # just running some ops to capture the OS Signposts traces for profiling
            net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
                .to(device='mps', dtype=torch.float)
            x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
            x = net1(x)

        torch.mps.profiler.start(mode="interval", wait_until_completed=True)
        # just running some ops to capture the OS Signposts traces for profiling
        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        x = net1(x)
        torch.mps.profiler.stop()

    def test_mps_event_module(self):
        startEvent = torch.mps.Event(enable_timing=True)
        startEvent.record()
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device='mps', dtype=torch.float)
        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        x = net1(x)
        endEvent = torch.mps.Event(enable_timing=True)
        endEvent.record()
        elapsedTime = startEvent.elapsed_time(endEvent)
        self.assertGreater(elapsedTime, 0.0)
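
    # Illustrative sketch (not collected by unittest): the event-timing pattern
    # above, with an explicit torch.mps.synchronize() before reading the elapsed
    # time so that both events have been recorded on the MPS stream. The matmul
    # is just an arbitrary workload to time.
    def _example_event_timing(self):
        start = torch.mps.Event(enable_timing=True)
        end = torch.mps.Event(enable_timing=True)
        a = torch.rand(1024, 1024, device='mps')
        start.record()
        b = a @ a  # some GPU work between the two events
        end.record()
        torch.mps.synchronize()
        self.assertGreater(start.elapsed_time(end), 0.0)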

    def test_jit_save_load(self):
        m = torch.nn.Module()
        m.x = torch.rand(3, 3, device='mps')
        buffer = io.BytesIO()
        torch.jit.save(torch.jit.script(m), buffer)
        buffer.seek(0)
        n = torch.jit.load(buffer)
        self.assertEqual(n.x, m.x)

    # Test random_, random_.to and random_.from
    def test_random(self):
        def helper(shape, low, high, dtype=torch.int32):

            mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')

            # We can't reliably check the mean and std.
            # Just make sure we don't return constant values
            self.assertNotEqual(mps_out.float().mean().item(), 0.)
            self.assertNotEqual(mps_out.float().std().item(), 0.)

        helper([100, 100], 0, 10)
        helper([100, 100], 23, 89)
        helper([100, 100], 23, 89, dtype=torch.float32)
        helper([100, 100], 23, 89, dtype=torch.int64)
        helper([100, 100], 0, 2, dtype=torch.bool)

        # Test random_
        for dtype in [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]:
            x = torch.empty(10, 10, dtype=dtype, device='mps')
            x.random_()
            self.assertNotEqual(x.max().item(), 0)

    # Test exponential
    def test_exponential(self):
        def helper(shape, lamda, dtype=torch.float32):

            mps_out = torch.zeros(shape, device='mps', dtype=dtype)
            mps_out.exponential_(lamda)

            # Compare the empirical moments with the theoretical ones
            # (mean 1/lamda, variance 1/lamda**2); only printed, not asserted.
            print(mps_out.to('cpu').float().mean(), 1 / lamda)
            print(mps_out.to('cpu').float().std() ** 2, 1 / (lamda**2))

        for dtype in [torch.float32, torch.float16]:
            helper([100, 100], 2, dtype)
            helper([100, 100], 1, dtype)
            helper([100, 100], 3, dtype)
            helper([100, 100], 0.5, dtype)
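
    # Illustrative sketch (not collected by unittest): Exponential(lamda) has
    # mean 1/lamda and variance 1/lamda**2, the theoretical values printed next
    # to the empirical ones above. At n=10000 the std of the sample mean is
    # (1/lamda)/sqrt(n) = 0.005, so a 0.05 band around 0.5 is a ~10-sigma window.
    def _example_exponential_moments(self):
        lamda = 2.0
        samples = torch.empty(10000, device='mps').exponential_(lamda)
        self.assertLess(abs(samples.mean().item() - 1 / lamda), 0.05)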

    def test_exponential_1(self):
        rate = torch.randn(5, 5).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

    # Test add/sub
    def test_add_sub(self):
        def helper(shape, alpha, op_name, inplace):
            if op_name == "add":
                op = torch.Tensor.add_ if inplace else torch.add
            elif op_name == "sub":
                op = torch.Tensor.sub_ if inplace else torch.sub

            for dtype in [torch.float16, torch.float32]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                mps_x = cpu_x.detach().clone().to('mps')

                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                mps_y = cpu_y.detach().clone().to('mps')

                cpu_out = op(cpu_x, cpu_y, alpha=alpha)
                mps_out = op(mps_x, mps_y, alpha=alpha)
                # fp16 isn't accurate when alpha is passed
                # TODO: remove or fix 'tol' when we fix problems with fp16
                tol = 2e-3 if dtype is torch.float16 else None
                self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)
                if not (cpu_y.shape != () and inplace):  # in-place output cannot be broadcasted
                    # create a scalar tensor
                    cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
                    mps_s = cpu_s.detach().clone().to('mps')
                    # primary tensor is scalar
                    self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))
                # create a scalar tensor
                cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
                mps_s = cpu_s.detach().clone().to('mps')
                # secondary tensor is scalar
                self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)

        for op_name, inplace in product(["add", "sub"], [True, False]):
            helper((), 0.0, op_name, inplace)
            helper((2, 8, 4, 5), 0.0, op_name, inplace)
            helper((2, 8, 4, 5), 0.1, op_name, inplace)
            helper((2, 8, 4, 5), 1.0, op_name, inplace)
            helper((2, 8, 3, 5), 0.1, op_name, inplace)
            helper((2, 8, 3, 5), 0.2, op_name, inplace)

    # Test add with scalar tensors
    def test_add_scalars(self):
        def helper(alpha):
            for dtype in [torch.float16, torch.float32]:
                cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')

                cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
                y = cpu_y.detach().clone().to('mps')

                cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
                out = torch.add(x, y, alpha=alpha)
                # fp16 isn't accurate when alpha is passed
                tol = 1e-3 if dtype is torch.float16 else None
                self.assertEqual(out, cpu_out, rtol=tol, atol=tol)

        helper(1.0)
        helper(0.0)
        helper(0.1)
        helper(0.2)

        # Test int32 tensor + int64 scalar add
        # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
        x = torch.ones(4, dtype=torch.int32, device='mps')
        self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
        self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
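
    # Illustrative sketch (not collected by unittest): the alpha argument
    # exercised above scales the second operand before the add/sub, i.e.
    # torch.add(x, y, alpha=a) == x + a * y. Assumes an MPS device.
    def _example_add_alpha(self):
        x = torch.tensor([1.0, 2.0], device='mps')
        y = torch.tensor([10.0, 20.0], device='mps')
        out = torch.add(x, y, alpha=0.5)  # -> [6., 12.]
        self.assertEqual(out, x + 0.5 * y)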

    def test_types_binary_op(self):
        # Float * Bool
        cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu")
        mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps")
        self.assertEqual(cpu_x, mps_x)
        # Float * Int64
        cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu")
        mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps")
        self.assertEqual(cpu_y, mps_y)
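
    # Illustrative sketch (not collected by unittest): the mixed-dtype products
    # above rely on standard type promotion; torch.result_type shows the dtype
    # the MPS kernels are expected to produce. Assumes an MPS device.
    def _example_binary_op_promotion(self):
        f = torch.arange(5, dtype=torch.float32, device='mps')
        b = torch.tensor([True, False, True, False, True], device='mps')
        i = torch.tensor([1, 0, 1, 0, 1], device='mps')
        self.assertEqual(torch.result_type(f, b), torch.float32)
        self.assertEqual(torch.result_type(f, i), torch.float32)
        self.assertEqual((f * b).dtype, torch.float32)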

    def test_unary_ops(self):
        def helper(shape, op):
            for dtypef in [torch.float32]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
                mps_x = cpu_x.detach().clone().to('mps')
                self.assertEqual(op(cpu_x), op(mps_x))

            for dtypei in [torch.int32, torch.int16]:
                cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False)
                mps_x = cpu_x.to('mps')
                self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)
            # test slice
            for dtypef in [torch.float32]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
                mps_x = cpu_x.detach().clone().to('mps')
                cpu_slice = cpu_x[:, ::2, :, :]
                mps_slice = mps_x[:, ::2, :, :]
                self.assertEqual(op(cpu_slice), op(mps_slice))
            # test view
            for dtypef in [torch.float32]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
                mps_x = cpu_x.detach().clone().to('mps')
                # create a view of the tensor by combining the last two dimensions
                combined_dim = shape[-1] * shape[-2]
                reshaped_dims = list(shape[:-2]) + [combined_dim]
                cpu_view = cpu_x.view(*reshaped_dims)
                mps_view = mps_x.view(*reshaped_dims)
                self.assertEqual(op(cpu_view), op(mps_view))

        helper((2, 8, 4, 5), torch.exp)
        helper((2, 8, 3, 5), torch.exp2)
        helper((2, 8, 3, 5), torch.expm1)
        helper((2, 8, 3, 5), torch.log)
        helper((2, 8, 3, 5), torch.cos)
        helper((2, 8, 3, 5), torch.erfinv)

    def test_non_dense_in_storage_unary_ops(self):
        def helper(op):
            for dtypef in [torch.float32]:
                cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False)
                mps_x = cpu_x.detach().clone().to('mps')
                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))

            for dtypei in [torch.int32, torch.int16, torch.int8]:
                cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False)
                mps_x = cpu_x.to('mps')
                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)

        helper(torch.exp)
        helper(torch.exp2)
        helper(torch.expm1)
        helper(torch.log)
        helper(torch.cos)
    def test_atan2(self):
        def helper(shape):
            input_cpu = torch.randn(shape)
            input_mps = input_cpu.detach().clone().to("mps")

            other_cpu = torch.randn(shape)
            other_mps = other_cpu.detach().clone().to("mps")

            atan2_cpu = torch.atan2(input_cpu, other_cpu)
            atan2_mps = torch.atan2(input_mps, other_mps)

            self.assertEqual(atan2_cpu, atan2_mps.to("cpu"))

        helper(4)
        helper(10000)
        helper((10000, 40))

    def test_multinomial(self):
        # Test with num_dist = 1
        def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
            cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
            prob_tensor = cpu_prob_tensor.detach().clone().to('mps')

            mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
            if not replacement:
                print(mps_out.to('cpu'))
            else:
                # Compare "real" with theoretical values
                print(mps_out.to('cpu').float().mean(), compare_mean)
                print(mps_out.to('cpu').float().std() ** 2, compare_var)

        # TODO: Add tests for data types
        helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
        helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
        helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
        helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
        helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
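    # The replacement=True branch above only prints sample statistics. An
    # assertion-based variant might look like this sketch; the tolerances are
    # illustrative guesses, not validated against the sampler's true variance:
    #
    #   sample = mps_out.to('cpu').float()
    #   self.assertEqual(sample.mean().item(), compare_mean, atol=0.1, rtol=0.05)
    #   self.assertEqual(sample.std().item() ** 2, compare_var, atol=0.2, rtol=0.1)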
    def test_cumsum_dim_check(self):
        x = torch.rand((3, 3), device="mps")
        self.assertEqual(x.cumsum(1), x.cumsum(-1))
        self.assertEqual(x.cumsum(0), x.cumsum(-2))
        self.assertRaises(IndexError, lambda: x.cumsum(2))
        self.assertRaises(IndexError, lambda: x.cumsum(-3))

    def test_cumprod_dim_check(self):
        x = torch.rand((3, 3), device="mps")
        self.assertEqual(x.cumprod(1), x.cumprod(-1))
        self.assertEqual(x.cumprod(0), x.cumprod(-2))
        self.assertRaises(IndexError, lambda: x.cumprod(2))
        self.assertRaises(IndexError, lambda: x.cumprod(-3))

class TestLogical(TestCaseMPS):
    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

    def test_logical_not(self):
        def helper(x):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.logical_not(x)
            result_cpu = torch.logical_not(cpu_x)

            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]))
        helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
        helper(self._wrap_tensor([True, True, False, False]))
        helper(self._wrap_tensor(1))
        helper(self._wrap_tensor(0))
        helper(self._wrap_tensor(True))
        helper(self._wrap_tensor(False))

    def test_logical_and(self):
        def helper(x, other):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            cpu_other = other
            other = cpu_other.detach().clone().to('mps')

            result = torch.logical_and(x, other)
            result_cpu = torch.logical_and(cpu_x, cpu_other)
            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
        helper(
            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
        )
        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
    def test_logical_or(self):
        def helper(x, other):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            cpu_other = other
            other = cpu_other.detach().clone().to('mps')

            result = torch.logical_or(x, other)
            result_cpu = torch.logical_or(cpu_x, cpu_other)

            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
        helper(
            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
        )
        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))

    def test_logical_xor(self):
        def helper(x, other):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            cpu_other = other
            other = cpu_other.detach().clone().to('mps')

            result = torch.logical_xor(x, other)
            result_cpu = torch.logical_xor(cpu_x, cpu_other)

            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
        helper(
            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
        )
        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
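    # The three logical binary tests above share one template; a parametrized
    # variant could collapse them (sketch, using the `parametrize` decorator
    # this file already imports):
    #
    #   @parametrize("op", [torch.logical_and, torch.logical_or, torch.logical_xor])
    #   def test_logical_binary(self, op):
    #       ...  # same helper body, with `op` in place of the fixed function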
    def test_min_max(self):
        def helper(dtype):
            for _ in range(10):
                if dtype == torch.float32 or dtype == torch.float16:
                    x = torch.randn((30, 15), device='mps', dtype=dtype)
                else:
                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
                x_cpu = x.to("cpu")

                y = x.max()
                y_cpu = x_cpu.max()
                self.assertEqual(y, y_cpu)

                z = x.min()
                z_cpu = x_cpu.min()
                self.assertEqual(z, z_cpu)

        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]

    def test_min_max_nan_propagation(self):
        def helper(dtype):
            cpu_x = torch.tensor([1.0, float("nan"), 3.0], dtype=dtype)
            mps_x = cpu_x.detach().clone().to('mps')

            cpu_max = torch.max(cpu_x)
            mps_max = torch.max(mps_x).to('cpu')

            cpu_amax = torch.amax(cpu_x)
            mps_amax = torch.amax(mps_x).to('cpu')

            cpu_min = torch.min(cpu_x)
            mps_min = torch.min(mps_x).to('cpu')

            cpu_amin = torch.amin(cpu_x)
            mps_amin = torch.amin(mps_x).to('cpu')

            self.assertEqual(cpu_max, mps_max)
            self.assertEqual(cpu_amax, mps_amax)
            self.assertEqual(cpu_min, mps_min)
            self.assertEqual(cpu_amin, mps_amin)
        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
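    # Unlike integer reductions, floating-point min/max on CPU and MPS are both
    # expected to propagate NaN rather than ignore it:
    #
    #   torch.tensor([1.0, float("nan"), 3.0]).max()  # tensor(nan)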
    def test_isin(self):
        def helper(dtype):
            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
                      ([5], [10]), ([0], [5]), ([5], [0])]
            for shape_tuple in shapes:
                for inverted in [True, False]:
                    if dtype.is_floating_point:
                        # Half is not supported for CPU isin. Compute reference in FP32
                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
                    else:
                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)

                    A_mps = A.clone().detach().to('mps')
                    B_mps = B.clone().detach().to('mps')

                    cpu_ref = torch.isin(A, B, invert=inverted)
                    if dtype in [torch.float16, torch.bfloat16]:
                        cpu_ref.type(dtype)  # no-op: isin returns a bool tensor

                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
                    self.assertEqual(mps_out, cpu_ref)

        dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
        if product_version < 14.0:
            # Int types expected to fail on MacOS < 14.0
            dtypes = [torch.float32, torch.float16, torch.bfloat16]

        [helper(dtype) for dtype in dtypes]
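    # For reference, torch.isin tests elementwise membership of the first
    # argument's values in the flattened second argument:
    #
    #   torch.isin(torch.tensor([1, 2, 3]), torch.tensor([2, 4]))
    #   # tensor([False,  True, False])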
    def test_isin_asserts(self):
        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
            out = torch.isin(A, B)

        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
            out = torch.isin(C, D)

class TestSmoothL1Loss(TestCaseMPS):

    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
        # CPU
        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
        target_cpu = torch.randn(4, 7)

        # MPS
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
        smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)

        self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)

        if requires_grad:
            smooth_l1_loss_cpu.backward()
            smooth_l1_loss_mps.backward()
            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        return smooth_l1_loss_cpu, smooth_l1_loss_mps

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_sum_backward(self):
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
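    # For reference, smooth_l1_loss with beta=1.0 computes, per element, with
    # d = x - y:
    #
    #   0.5 * d ** 2 / beta    if |d| < beta
    #   |d| - 0.5 * beta       otherwise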
class TestNLLLoss(TestCaseMPS):
    def test_nll_loss_mismatched_batch(self, device='mps'):
        x = torch.randn((10, 3), requires_grad=True, device=device)
        # t should have size (10,)
        t = torch.zeros((3,), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
            F.nll_loss(x, t)

    def test_nll_loss_out_of_bounds_ignore_index(self):

        def test_nll_loss_out_of_bounds_ignore_index_helper(device):
            output = []
            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
            t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
            t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)
            for reduction in ['mean', 'none']:
                # out of bound ignore_index
                output.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
                # default ignore_index
                output.append(F.nll_loss(x, t2, reduction=reduction))
            return output

        output_cpu = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu')
        output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')

        for cpu, mps in zip(output_cpu, output_mps):
            self.assertEqual(cpu, mps)
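    # ignore_index semantics exercised above: targets equal to ignore_index are
    # skipped entirely, contributing neither to the loss nor to the 'mean'
    # denominator, which is why an out-of-range class like 255 is legal as long
    # as it is the ignored value:
    #
    #   F.nll_loss(x, t1, ignore_index=255)  # rows whose target is 255 are skipped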
    def test_nll_loss_invalid_target_dim(self):

        def _test_nll_loss_invalid_target_dim(device):
            output = []
            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
            t = torch.zeros((6, 2), dtype=torch.int64, device=device)
            with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
                F.nll_loss(x, t)

        _test_nll_loss_invalid_target_dim(device='cpu')
        _test_nll_loss_invalid_target_dim(device='mps')

    def test_nll_loss_invalid_weights(self):

        def _test_nll_loss_invalid_weights(device):
            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
            t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
            invalid_weights = [
                torch.zeros(4, device=device),
                torch.zeros((1, 3), device=device),
            ]
            msg = "weight tensor should be defined either for all 3 classes or no classes"
            for weight in invalid_weights:
                with self.assertRaisesRegex(RuntimeError, msg):
                    F.nll_loss(x, t, weight=weight)

        _test_nll_loss_invalid_weights(device='cpu')
        _test_nll_loss_invalid_weights(device='mps')
    def _nll_loss_helper(self, input_size, reduction, expected):
        # `expected` is not used in the comparison below; the CPU output serves
        # as the reference for the MPS result.

        # CPU
        input = torch.rand(input_size, requires_grad=True, device='cpu')
        num_channels = input_size[1]
        target_size = (input_size[0], ) + tuple(input_size[2:])
        target = torch.randint(num_channels, target_size, device='cpu')
        weights = torch.randn(num_channels)

        # MPS
        input_mps = input.detach().clone().to('mps').requires_grad_()
        target_mps = target.detach().clone().to('mps')
        weights_mps = weights.to("mps")

        output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
        output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
        self.assertEqual(output_cpu, output_mps.to('cpu'))

        output_cpu.sum().backward()
        output_mps.sum().backward()
        self.assertEqual(input.grad, input_mps.grad.to('cpu'))

    def _nll_loss_1d_helper(self, input_size, reduction):

        # CPU
        input = torch.rand(input_size, requires_grad=True, device='cpu')
        num_channels = input_size[0]
        target = torch.randint(num_channels, [], device='cpu')

        # MPS
        input_mps = input.detach().clone().to('mps').requires_grad_()
        target_mps = target.detach().clone().to('mps')

        output_cpu = F.nll_loss(input, target, reduction=reduction)
        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
        self.assertEqual(output_cpu, output_mps.to('cpu'))

        output_cpu.sum().backward()
        output_mps.sum().backward()
        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
    def test_nll_loss_1d(self, device='cpu'):
        self._nll_loss_1d_helper([10], "none")
        self._nll_loss_1d_helper([10], "mean")
        self._nll_loss_1d_helper([10], "sum")

    def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
        self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
        self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
        self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
        self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
        self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))

    def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
        nan = torch.tensor(float('nan'), device=device)
        self._nll_loss_helper([1, 3], "mean", nan)
        self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
        self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
        self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
        self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)

    def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
        zero = torch.tensor(0, device=device)
        self._nll_loss_helper([1, 3], "sum", zero)
        self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
        self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
        self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
        self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
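    # The expected values passed above follow from the reduction definitions on
    # empty batches: 'mean' divides a zero sum by zero elements and yields NaN,
    # while 'sum' over no elements is exactly 0.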
    def test_nll_loss_byte_target_matches_long(self, device='cpu'):
        N, C = 10, 4
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

        def compute_result_and_gradient(reduction, target_dtype):
            result, grad = {}, {}
            for dev in ['cpu', 'mps']:
                input_dev = input.to(dev)
                input_ = input_dev.detach()
                input_.requires_grad_()

                target_dev = target.to(dev)

                prob = F.log_softmax(input_, dim=-1)
                loss = nn.NLLLoss(reduction=reduction)
                result[dev] = loss(prob, target_dev.to(target_dtype))
                result[dev].sum().backward()
                grad[dev] = input_.grad

            return result, grad

        for reduction in ["none", "mean", "sum"]:
            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)

            self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
            self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
            # Byte targets should agree with long targets, per the test's name.
            self.assertEqual(result_byte['cpu'], result_long['cpu'])
            self.assertEqual(result_byte['mps'].to('cpu'), result_byte['cpu'])
            self.assertEqual(grad_byte['mps'].to('cpu'), grad_byte['cpu'])

class TestTopK(TestCase):
    def _test_topk(self, shape, largest):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')
        if isinstance(shape, tuple):
            for curr_dim, dim_size in enumerate(shape):
                for k in range(1, dim_size + 1):
                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                    self.assertEqual(topk_values, topk_values_cpu)
                    self.assertEqual(topk_indices, topk_indices_cpu)
        else:
            for k in range(1, shape):
                topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)
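    # For reference, torch.topk returns (values, indices) of the k largest
    # (or smallest, with largest=False) entries along `dim`:
    #
    #   torch.topk(torch.tensor([1., 3., 2.]), k=2)
    #   # values = tensor([3., 2.]), indices = tensor([1, 2])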
class TestNNMPS(NNTestCase):

    def _create_basic_net(self):
        class Layer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_dummy_param = Parameter(torch.empty(3, 5))
                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = Layer()
                self.dummy_param = Parameter(torch.empty(3, 5))
                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))

        l = Layer()
        n = Net()
        s = nn.Sequential(n, n)

        return l, n, s

    def test_requires_grad_(self):
        m = self._create_basic_net()[-1]
        assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()), 'invalid test'
        for requires_grad in (False, True):
            self.assertIs(m.requires_grad_(requires_grad), m)
            for p in m.parameters():
                self.assertEqual(p.requires_grad, requires_grad)
            for b in m.buffers():
                self.assertFalse(b.requires_grad)
    def test_module_backcompat(self):
        from torch.serialization import SourceChangeWarning
        path = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            m = torch.load(path)
        input = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(m(input).size(), (2, 5))

    def test_conv_backcompat(self):
        from torch.serialization import SourceChangeWarning
        # This file was generated by running on PyTorch 1.0.1 on Python 2:
        #
        # import torch
        # from torch import nn
        # m = nn.Conv2d(1, 1, 1)
        # torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This Pickle also contains some Unicode data!
        path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            m = torch.load(path, encoding='utf-8')
        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
        self.assertEqual(m(input).size(), (1, 1, 1, 1))

    def test_conv_expand(self):
        device = 'mps'
        input_ = torch.rand(2, 3, 16, 16, device=device)
        kernel = torch.rand(1, 1, 3, 11, device=device)
        tmp_kernel = kernel.expand(-1, 3, -1, -1)
        # The test should not crash
        output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)

    def test_permute(self):
        M_cpu = torch.randn(5, 5)
        M_mps = M_cpu.to('mps')

        output_cpu = M_cpu.permute(1, 0)
        output_mps = M_mps.permute(1, 0)

        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())
    # Printing of non_contiguous should not crash
    def test_print_non_contiguous(self):
        print(torch.ones(100, 100, device='mps').nonzero())
        print(torch.ones(100, 100, device='mps').nonzero().contiguous())

    def test_zero_grad(self):
        i = torch.randn(2, 5, requires_grad=True)
        module = nn.Linear(5, 5)
        for p in module.parameters():
            p.requires_grad = False
        module.zero_grad()

        module.weight.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)  # uninitialized grad

        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        module.zero_grad()
        self.assertIsNone(module.weight.grad)

        module.bias.requires_grad = True
        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
        module(i).sum().backward()
        self.assertIsNotNone(module.weight.grad)
        self.assertIsNotNone(module.bias.grad)
        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
        self.assertGreater(module.bias.grad.data.abs().sum(), 0)

        # Force set to zeros.
        module.zero_grad(set_to_none=False)
        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

        module.zero_grad()
        self.assertIsNone(module.weight.grad)
        self.assertIsNone(module.bias.grad)
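    # zero_grad(set_to_none=True), the default, drops the .grad tensors
    # entirely, which is why the assertions above expect None; set_to_none=False
    # keeps the tensors and fills them with zeros:
    #
    #   module.zero_grad()                   # module.weight.grad is None
    #   module.zero_grad(set_to_none=False)  # an existing grad is zeroed in place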
    def test_no_grad(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            input = torch.randn(1, 2, 10, 10).to(dtype)
            x = input
            y = input.clone()

            output = module(x)
            self.assertTrue(output.requires_grad)
            output.backward(torch.ones(1, 5, 10, 10))

            with torch.no_grad():
                output2 = module(y)
                self.assertFalse(output2.requires_grad)
                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))

    def test_invalid_conv1d(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(4\). ' +
                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)
    def test_conv2d_discontiguous_weight(self):
        # Test for https://github.com/pytorch/pytorch/issues/55781
        x = torch.ones(64, 16, 16, 16)
        weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
        self.assertFalse(weight.is_contiguous())
        y = torch.nn.functional.conv2d(x, weight, None)
        if torch.backends.mkldnn.is_available():
            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
            with torch.backends.mkldnn.flags(enabled=False):
                y_ = torch.nn.functional.conv2d(x, weight, None)
                self.assertEqual(y, y_)
        self.assertEqual(y.sum(), 4186112.)

    def test_invalid_conv2d(self):
        for dtype in [torch.bfloat16, torch.float, torch.double]:
            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            input = torch.empty(1, 1, 4, 4).to(dtype)
            self.assertRaises(RuntimeError, lambda: module(input))

            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
            input = torch.randn(1, 3, 1, 1)
            with self.assertRaisesRegex(RuntimeError,
                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
                module(input)

            # Negative stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

            # Zero stride check
            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            input = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                module(input)

        # Input and weights on different devices
        self.assertRaisesRegex(RuntimeError,
                               'must be on the same device',
                               lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
        self.assertRaisesRegex(RuntimeError,
                               'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
                               lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))
    def test_conv2d_valid_padding(self, device='mps'):
        # Test F.conv2d padding='valid' is the same as no padding
        x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
        y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)

        expect = F.conv2d(x, y)
        actual = F.conv2d(x, y, padding='valid')
        self.assertEqual(expect.to('cpu'), actual.to('cpu'))

    def test_conv2d_backward_collision(self):
        # Test for https://github.com/pytorch/pytorch/issues/112998
        x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
        m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
        m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
        y1, y2 = m1(x), m2(x)
        self.assertEqual(y1.shape, y2.shape)
        y1.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        y2.sum().backward()
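    # Output-size arithmetic behind the shape check above: for input size L,
    # kernel k, stride s and padding p, conv output is floor((L + 2p - k) / s) + 1,
    # so both (k=3, p=1) and (k=4, p=1) on L=10 with s=2 produce 5, and the two
    # backward passes hit kernels of different sizes with identical geometry.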
    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
    def test_conv3d_backward_collision(self):
        # Conv3D is only available from MacOS 13.2 onwards
        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
        m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
        m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
        y1, y2 = m1(x), m2(x)
        self.assertEqual(y1.shape, y2.shape)
        y1.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        y2.sum().backward()

    def test_gemm_permute_transpose(self):
        batch_size = 32
        n = 20
        hidden = 768
        num_attention_heads = 12
        attention_head_size = hidden // num_attention_heads

        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
            x = x.view(new_x_shape)
            return x.permute(0, 2, 1, 3)

        def attention2(key, *, workaround=False, device):
            key = transpose_for_scores(key)
            res = key.transpose(-1, -2)
            return res

        A = torch.randn(batch_size, n, hidden)
        A_mps = A.detach().clone().to("mps")

        r1 = attention2(A, device="cpu")
        r2 = attention2(A_mps, device="mps")

        r2_cpu = r2.to("cpu")
        self.assertEqual(r1, r2_cpu)
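    # Shape flow in the test above: (batch, n, hidden) --view-->
    # (batch, n, heads, head_size) --permute--> (batch, heads, n, head_size)
    # --transpose(-1, -2)--> (batch, heads, head_size, n); the MPS chain is
    # checked against the identical CPU chain.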

    def test_group_norm_backward(self, device='mps'):
        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
        shape = [1, 4, 16, 16]
        x = torch.full(shape, 7.0, device=device)

        target = torch.ones((1, 3, 128, 128), device=device)

        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)

        with torch.enable_grad():
            x = x.detach().requires_grad_()
            out = 5.5 * x
            out = conv_in(out)
            out = out + norm(out)
            out = out + norm(out)
            out = out + norm(out)
            out = F.interpolate(out, scale_factor=8.0, mode="nearest")
            out = norm(out)
            out = conv_out(out)

            loss = (out - target).norm(dim=-1).sum()
            grad = -torch.autograd.grad(loss, x)[0]
            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')

    # def test_conv2d_same_padding(self, device='mps'):
    #     x = torch.rand(1, 1, 10, 11, device=device)
    #     y = torch.rand(1, 1, 4, 5, device=device)
    #     expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
    #     actual = F.conv2d(x, y, padding='same')
    #     self.assertEqual(expect.to('cpu'), actual.to('cpu'))

    #     # With dilation
    #     y = torch.rand(1, 1, 3, 4, device=device)
    #     expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
    #     actual = F.conv2d(x, y, padding='same', dilation=2)
    #     self.assertEqual(expect, actual)

    #     # Dilation with asymmetric padding
    #     y = torch.rand(1, 1, 4, 4, device=device)
    #     expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
    #     actual = F.conv2d(x, y, padding='same', dilation=3)
    #     self.assertEqual(expect, actual)
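
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): what the disabled test_conv2d_same_padding above
    # checks. padding='same' matches an explicit symmetric pad, with one row
    # cropped when the implied padding is asymmetric (even kernel height). This
    # sketch runs on CPU, so it is independent of MPS support.
    def _example_same_padding_on_cpu(self):
        x = torch.rand(1, 1, 10, 11)
        y = torch.rand(1, 1, 4, 5)
        expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
        actual = F.conv2d(x, y, padding='same')
        torch.testing.assert_close(expect, actual)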


class TestPad(TestCaseMPS):
    def test_constant_pad(self):
        m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
        input_cpu = torch.randn(1, 16, 16, 16)
        input_mps = input_cpu.detach().clone().to("mps")
        r_cpu = m(input_cpu)
        r_mps = m(input_mps)
        self.assertEqual(r_cpu, r_mps.to("cpu"))

        # Arbitrary input dimensions
        pad = (1, 1, 0, 0, 0, 0)
        value = 3.5
        input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
        input_mps = input_cpu.detach().clone().to("mps")
        r_cpu = F.pad(input_cpu, pad=pad, value=value)
        r_mps = F.pad(input_mps, pad=pad, value=value)
        self.assertEqual(r_cpu, r_mps.to("cpu"))

    def test_circular_pad(self):
        # https://github.com/pytorch/pytorch/issues/80856
        k_cpu = torch.ones(3, 3, 9, 9)
        k_mps = k_cpu.detach().clone().to("mps")

        x_cpu = torch.rand(1, 3, 32, 32)
        x_mps = x_cpu.detach().clone().to("mps")

        x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular')
        x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular')

        y_cpu = F.conv2d(x_pad_cpu, k_cpu)
        y_mps = F.conv2d(x_pad_mps, k_mps)

        self.assertEqual(y_cpu, y_mps.cpu())

    def test_constant_pad_4d_warning(self):
        inputCPU = torch.rand((1, 2, 2, 2, 1, 1))
        inputMPS = inputCPU.detach().clone().to('mps')
        outputCPU = F.pad(inputCPU, [0, 0, 0, 0, 0, 0, 1, 0])
        outputMPS = F.pad(inputMPS, [0, 0, 0, 0, 0, 0, 1, 0])
        self.assertEqual(outputCPU, outputMPS)
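
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): mode='circular' wraps values in from the
    # opposite edge, so padding a 1-D ramp by one element on each side prepends
    # the last value and appends the first one.
    def _example_circular_pad_wraps(self):
        x = torch.arange(4.0).reshape(1, 1, 4)      # [[0, 1, 2, 3]]
        padded = F.pad(x, (1, 1), mode='circular')  # [[3, 0, 1, 2, 3, 0]]
        torch.testing.assert_close(padded, torch.tensor([[[3.0, 0.0, 1.0, 2.0, 3.0, 0.0]]]))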

    def test_pad(self):
        def helper(shape, padding, op, value=0):
            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            inputCPU.retain_grad()
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

            if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]):
                padCriteria = op(padding, value)
            else:
                padCriteria = op(padding)
            outputCPU = padCriteria(inputCPU)
            outputMPS = padCriteria(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass (chose 0.6 just to have the grad_output != 1)
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

        # 1D Padding
        helper((2, 4, 3), 2, nn.ReflectionPad1d)
        # verify if a change in shape of input would cause problems with graph caching
        helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
        # Replication 1D
        helper((2, 1, 6), 3, nn.ReplicationPad1d)
        # Constant Pad 1D
        helper((2, 3, 4), 2, nn.ConstantPad1d)
        # Constant Pad 1D with single dimension input
        helper((16), (1, 2), nn.ConstantPad1d)

        # 2D Padding
        helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
        # verify if a change in shape of input would cause problems with graph caching
        helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
        # this should make the padding (2, 2, 2, 2)
        helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
        # verify if a change in shape of padding would cause problems with graph caching
        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
        # Constant Pad 2D
        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
        # input size < pad size
        helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
        # pad dims < input dims
        helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
        # pad dims == input dims
        helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
        # input.numel() == 0 but output.numel() > 0
        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
        # pad dims < input dims - 2
        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)

        # 3D Padding
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
        # verify if a change in shape of padding would cause problems with graph caching
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
        # case where input_d == pad_front/back for ReplicationPad3d
        helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d)
        # Constant Pad 3D
        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
        # input size < pad size
        helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
        # check the workaround for the right padding bug in Monterey
        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)

    def test_constant_pad_nd_preserves_memory_format(self):
        nchw_tensor = torch.rand((1, 2, 5, 3))
        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))

        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
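
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): channels_last keeps the logical NCHW shape but
    # reorders strides so C is the fastest-moving dimension, which is the layout
    # property constant_pad_nd must preserve in the test above.
    def _example_channels_last_strides(self):
        t = torch.rand(1, 2, 5, 3).contiguous(memory_format=torch.channels_last)
        assert t.shape == torch.Size([1, 2, 5, 3])
        assert t.stride() == (30, 1, 6, 2)  # N, C, H, W strides, with C innermost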


class TestLinalgMPS(TestCaseMPS):
    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
        dtype = t.dtype
        numpy_dtype = dtype
        alpha = 1.2 if alpha is None else alpha
        beta = 0.8 if beta is None else beta
        res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = torch.full_like(res1, math.nan)
        if transpose_out:
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
        f(t, m, v, alpha=alpha, beta=beta, out=res2)
        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
        if beta != 0:
            res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        res3 = torch.from_numpy(res3).to(dtype)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res3)

    def test_addmm(self, device="mps", dtype=torch.float32):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)

    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
        dtype = t.dtype
        numpy_dtype = dtype
        alpha = 1.2 if alpha is None else alpha
        beta = 0.8 if beta is None else beta
        res1 = f(t, m, v, alpha=alpha, beta=beta)
        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
        if beta != 0:
            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        res2 = torch.from_numpy(res2).to(dtype)
        self.assertEqual(res1, res2)
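
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): torch.addmm computes beta * t + alpha * (m1 @ m2),
    # the reference the helper above rebuilds in NumPy; with beta == 0 the
    # nan-filled input must be ignored rather than propagated.
    def _example_addmm_definition(self):
        t = torch.full((2, 2), math.nan)
        m1 = torch.eye(2)
        m2 = torch.ones(2, 2)
        out = torch.addmm(t, m1, m2, beta=0, alpha=3.0)
        torch.testing.assert_close(out, torch.full((2, 2), 3.0))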

    def test_addr(self, device="mps", dtype=torch.float32):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, device=device).to(dtype)
        m2 = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, M, m1, m2)

        # Test beta=0, M=nan
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
        m1 = torch.randn(10, device=device).to(dtype)
        m2 = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, M, m1, m2, beta=0)
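
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): torch.addr computes the rank-1 update
    # beta * t + alpha * torch.outer(vec1, vec2), which _test_addr above mirrors
    # with np.outer.
    def _example_addr_definition(self):
        v1 = torch.tensor([1.0, 2.0])
        v2 = torch.tensor([3.0, 4.0])
        out = torch.addr(torch.zeros(2, 2), v1, v2)  # default alpha=1, beta=1
        torch.testing.assert_close(out, torch.outer(v1, v2))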

    def test_matrix_rank(self, device="mps", dtype=torch.float32):
        matrix_rank = torch.linalg.matrix_rank

        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            self.assertEqual(rank_a, matrix_rank(a.mH))
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                run_test(shape0, shape1, batch)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
                    raise e
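
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): matrix_rank counts singular values above a
    # tolerance, so a rank-1 outer product has rank 1 regardless of its size.
    # CPU-only, since it only demonstrates the definition.
    def _example_matrix_rank_via_svd(self):
        a = torch.outer(torch.ones(3), torch.arange(1.0, 4.0))
        assert torch.linalg.matrix_rank(a).item() == 1
        assert (torch.linalg.svdvals(a) > 1e-6).sum().item() == 1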

    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            np_A = A.cpu().numpy()
            np_A_pinv = A_pinv.cpu().numpy()
            if A.numel() > 0:
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
            else:
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), ]
            # Test different types of rcond tensor
            for rcond_type in MPS_DTYPES:
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected, atol=precision, rtol=precision)

        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
            A = torch.randn(*sizes, dtype=dtype, device=device)
            hermitian = False
            run_test_main(A, hermitian)
            run_test_numpy(A, hermitian)

        # Check hermitian = True
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            hermitian = True
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                run_test_main(A, hermitian)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                    raise e
            try:
                run_test_numpy(A, hermitian)
            except NotImplementedError as e:
                with self.assertRaisesRegex(
                        NotImplementedError,
                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                    raise e
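
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): the four Moore-Penrose conditions that
    # run_test_main checks above, shown on a tiny rank-deficient matrix.
    def _example_pinv_conditions(self):
        A = torch.tensor([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
        P = torch.linalg.pinv(A)
        torch.testing.assert_close(A @ P @ A, A)       # A P A == A
        torch.testing.assert_close(P @ A @ P, P)       # P A P == P
        torch.testing.assert_close(A @ P, (A @ P).mT)  # A P is symmetric
        torch.testing.assert_close(P @ A, (P @ A).mT)  # P A is symmetric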

    @parametrize("m", [1, 32, 64])
    @parametrize("n", [48, 64])
    @parametrize("q_group", [32, 64, 128, 256])
    @parametrize("num_groups", [1, 2])
    def test__int4_mm(self, m, n, q_group, num_groups):
        k = q_group * num_groups
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((k, n), device="mps")

        def convert_weight_to_int4pack(b):
            b_int32, b_scales_and_zeros = _group_quantize_tensor(
                b.to("cpu"), n_bit=4, q_group_size=q_group
            )
            b_int32 = b_int32.to("mps")
            b_scales_and_zeros = b_scales_and_zeros.to("mps")
            b_int4pack = torch._convert_weight_to_int4pack(
                b_int32, inner_k_tiles
            )

            return b_int4pack, b_scales_and_zeros

        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
            return torch._weight_int4pack_mm(
                a, b_int4pack, q_group, b_scales_and_zeros
            )

        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)

        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)

    @parametrize("m", [1, 32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [32, 64])
    def test__int8_mm(self, m, k, n):
        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((n, k), device="mps")

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)

        b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32)
        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            b_scales = b_scales_f32.to(dtype=dtype)
            res = weight_int8pack_mm(a, b_int8pack, b_scales)
            ref = torch.mm(a, b.transpose(0, 1))

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)
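
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): the idea behind the per-channel int8 path
    # above. Each output channel gets one scale derived from its absolute
    # maximum, and the dequantized weights stay within about half a scale step
    # of the originals, which is why a 5% mean relative error bound is loose.
    def _example_symmetric_int8_quantization(self):
        w = torch.randn(4, 8)                 # [out_channels, in_features]
        scales = w.abs().amax(dim=1) / 127.0  # one scale per output channel
        q = torch.clamp((w / scales[:, None]).round(), -128, 127).to(torch.int8)
        w_hat = q.to(torch.float32) * scales[:, None]
        assert (w - w_hat).abs().max() <= scales.max() / 2 + 1e-6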


class TestSDPA(TestCaseMPS):
    def _compare_tensors(self, y, ref):
        denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype))
        err = ((y - ref).abs() / denom).mean().item()
        self.assertLess(err, 0.01)

    def _test_sdpa_no_mask(
        self,
        is_causal: bool,
        dtype: torch.dtype,
        L: int = 1,
        S: int = 72,
        NH: int = 32,
        HS: int = 128,
        requires_grad: bool = False
    ):
        torch.manual_seed(1729)
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            q_cpu = q.detach().cpu().requires_grad_(requires_grad)
            k_cpu = k.cpu()
            v_cpu = v.cpu()

            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)

            self._compare_tensors(y.cpu(), y_ref)

            if requires_grad and torch.is_grad_enabled():
                y.sum().backward()
                y_ref.sum().backward()

                self._compare_tensors(q.grad.cpu(), q_cpu.grad)

    def test_sdpa_no_mask_no_causal_fp32(self):
        self._test_sdpa_no_mask(False, torch.float32)

    def test_sdpa_no_mask_no_causal_fp16(self):
        self._test_sdpa_no_mask(False, torch.float16)

    def test_sdpa_no_mask_causal_fp32(self):
        self._test_sdpa_no_mask(True, torch.float32)

    def test_sdpa_no_mask_causal_fp16(self):
        self._test_sdpa_no_mask(True, torch.float16)

    def test_sdpa_no_mask_causal_fp16_L7(self):
        self._test_sdpa_no_mask(True, torch.float16, 7)

    def test_sdpa_no_mask_causal_fp16_L7_S17(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17)

    def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
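
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): the MATH backend pinned by sdpa_kernel above
    # computes softmax(q @ k^T / sqrt(HS)) @ v; a hand-rolled version matches
    # F.scaled_dot_product_attention on CPU to within float32 noise.
    def _example_sdpa_math_reference(self):
        q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
        attn = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1]), dim=-1)
        torch.testing.assert_close(attn @ v, F.scaled_dot_product_attention(q, k, v),
                                   rtol=1e-4, atol=1e-5)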

    def test_sdpa_no_mask_no_causal_fp32_grad(self):
        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

        with torch.no_grad():
            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)

    def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
        torch.manual_seed(1729)
        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            i = 42

            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

            input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
            mask = causal_mask[None, None, input_pos]

            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)

            self._compare_tensors(y.cpu(), y_ref)

    def test_sdpa_mask_fp32(self):
        self._test_sdpa_mask(torch.float32)

    def test_sdpa_mask_fp16(self):
        self._test_sdpa_mask(torch.float16)

    def test_sdpa_mask_fp16_L6(self):
        self._test_sdpa_mask(torch.float16, 6)

    def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
        self._test_sdpa_mask(torch.float16, 6, 17, 23, 121)
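
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): indexing an [S, S] causal mask with a single
    # position, as _test_sdpa_mask does, yields the one-row [1, 1, 1, S] mask a
    # KV-cache decode step needs, True for every slot up to and including i.
    def _example_decode_step_mask(self):
        S, i = 8, 3
        causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
        row = causal[None, None, torch.tensor([i])]
        assert row.shape == torch.Size([1, 1, 1, S])
        assert row.sum().item() == i + 1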


class TestGatherScatter(TestCaseMPS):
    def test_slicing_with_step(self):
        # Slicing with step
        # https://github.com/pytorch/pytorch/issues/78886
        x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
        x_mps[::2] = 1.0

        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
        x_cpu[::2] = 1.0

        self.assertEqual(x_cpu, x_mps)

    def test_cast_gather_scatter(self):
        for _ in range(0, 50):
            input = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0)
                s_cpu = torch.tensor(input, dtype=torch.uint8, device="cpu").unsqueeze(0)
                s = s.long()
                s_cpu = s_cpu.long()
                self.assertEqual(s.cpu(), s_cpu)

                s = s.float()
                s_cpu = s_cpu.float()
                self.assertEqual(s.cpu(), s_cpu)

                s /= 255
                s_cpu /= 255
                self.assertEqual(s.cpu(), s_cpu)

    def test_slicing_replace_column(self):
        # https://github.com/pytorch/pytorch/issues/78074
        def _helper(tensor_data):
            x_cpu = torch.tensor(tensor_data)
            x_mps = x_cpu.to('mps')

            x_cpu[:, 0] = 7
            x_mps[:, 0] = 7

            self.assertEqual(x_cpu, x_mps)

        _helper([[1, 2, 3], [4, 5, 6]])
        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

    def test_inplace_scatter(self):
        # https://github.com/pytorch/pytorch/issues/79672
        a_mps = torch.ones((2, 2),).to(torch.device("mps"))
        b_mps = torch.ones((2, 2),).to(torch.device("mps"))

        a_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
        b_cpu = torch.ones((2, 2),).to(torch.device("cpu"))

        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)


# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. There
# are no additional tests added apart from what is part of test_view_ops.py
class TestViewOpsMPS(TestCaseMPS):
    exact_dtype = True

    def test_permute_slicing(self):
        # test the fix for crash reported in
        # https://github.com/pytorch/pytorch/issues/94190
        cpu_x = (torch.randn([3, 2, 2]).float())
        mps_x = cpu_x.detach().clone().to('mps')
        cpu_out = cpu_x.permute((2, 0, 1)) * 2.0
        mps_out = mps_x.permute((2, 0, 1)) * 2.0
        # this print caused a crash prior to fix PR#94259
        print(torch.zeros_like(mps_out))
        # test the fix for fill_scalar_mps() mentioned in issue #94190
        self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
        self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))

    def is_view_of(self, base, other):
        if (not other._is_view() or
                other is base or
                other._base is not base or
                base.device != other.device):
            return False
        # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
        if base.device.type == 'mps':
            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                return False

        return True

    # Returns True if v1 and v2 are views of the same base
    def is_view_of_same_base(self, v1, v2):
        if (not v1._is_view() or v1 is v2):
            return False
        return self.is_view_of(v1._base, v2)
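
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): what is_view_of encodes. A view reports its
    # originating tensor via ._base and aliases the same underlying storage,
    # so writes through the view are visible in the base.
    def _example_view_shares_storage(self):
        t = torch.zeros(4)
        v = t.view(2, 2)
        assert v._base is t
        assert v.untyped_storage().data_ptr() == t.untyped_storage().data_ptr()
        v[0, 0] = 1.0
        assert t[0].item() == 1.0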

    # Returns x unchanged if contiguous=True, else returns its transpose over (dim0, dim1)
    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
        if contiguous:
            return x
        else:
            return x.transpose(dim0, dim1)

    def test_diagonal_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = torch.diagonal(t)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[0, 0], v[0])

        t = torch.ones((3, 3, 3), device="mps")
        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0, 1], v[0, 0])

    def test_select_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.select(0, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[2, 0], v[0])

    def test_unbind_view(self, device="mps") -> None:
        t = torch.zeros((5, 5), device=device)
        tup = torch.unbind(t)

        for idx, v in enumerate(tup):
            self.assertTrue(self.is_view_of(t, v))

            v[0] = idx + 1
            self.assertEqual(t[idx, 0], v[0])

    def test_expand_view(self, device="mps") -> None:
        t = torch.ones((5, 1), device=device)
        v = t.expand(5, 5)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])
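
    # Illustrative sketch, not part of the original suite (hypothetical helper,
    # never called by the tests): expand materializes no new data. The broadcast
    # dimension simply gets stride 0, which is why the single write in
    # test_expand_view above is visible through the whole expanded row.
    def _example_expand_uses_zero_stride(self):
        t = torch.ones(5, 1)
        v = t.expand(5, 5)
        assert v.stride() == (1, 0)
        assert v.untyped_storage().data_ptr() == t.untyped_storage().data_ptr()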

    def test_expand_as_view(self, device="mps"):
        t = torch.ones((5, 1), device=device)
        e = torch.empty((5, 5), device=device)
        v = t.expand_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_narrow_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = torch.narrow(t, 1, 2, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 2], v[0, 0])

    def test_permute_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.permute(1, 0)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_view(self, device="mps"):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            t = torch.ones((5, 5), device=device)
            v = fn(t, 0, 1)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapdims_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapaxes_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.transpose_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_t_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = t.t()
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_inplace_view_add(self):
        # https://github.com/pytorch/pytorch/issues/96153
        t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
        t_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3)
        t_mps = t_mps + 1
        t_cpu = t_cpu + 1
        self.assertEqual(t_mps, t_cpu)

    def test_t_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.t_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_T_view(self, device="mps"):
        for op in ("T", "H", "mT", "mH"):
            t = torch.ones((5, 5), device=device)
            v = getattr(t, op)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_unfold_view(self, device="mps"):
        t = torch.ones(10, device=device)
        v = t.unfold(0, 3, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[1, 0] = 0
        self.assertEqual(t[2], v[1, 0])

    def test_squeeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.squeeze_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertIs(t, v._base)

    def test_unsqueeze_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.unsqueeze(t, 1)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_unsqueeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.unsqueeze_(1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_as_strided_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.as_strided(t, (25,), (1,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_as_strided_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.as_strided_((25,), (1,))
        self.assertTrue(self.is_view_of(t, v))
        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view(25)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,))
        v = t.view_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_contiguous_self(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        s = t.contiguous()
        self.assertIs(s, t)

    def test_contiguous_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = t.t().contiguous()
        self.assertFalse(self.is_view_of(t, nv))

        nv[0, 0] = 0
        self.assertNotEqual(t[0, 0], nv[0, 0])

    def test_reshape_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.reshape(t, (25,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,), device=device)
        v = t.reshape_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = torch.reshape(t.t(), (25,))
        self.assertFalse(self.is_view_of(t, nv))

        nv[6] = 0
        self.assertNotEqual(t[1, 1], nv[6])
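
    # Illustrative sketch (not part of the original suite): `reshape` aliases
    # its input when the target shape is compatible with the input's strides
    # and silently copies otherwise, which is the distinction the two tests
    # above pin down. `_sketch_reshape_alias_or_copy` is a hypothetical name.
    def _sketch_reshape_alias_or_copy(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        # contiguous source: reshape can alias, so storage pointers match
        self.assertEqual(torch.reshape(t, (25,)).data_ptr(), t.data_ptr())
        # transposed source: reshape must materialize a contiguous copy
        self.assertNotEqual(torch.reshape(t.t(), (25,)).data_ptr(), t.data_ptr())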

    def test_flatten_view(self, device="mps"):
        def test_writes_propagate(t, v):
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        self.assertTrue(self.is_view_of(t, v))
        test_writes_propagate(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of(t, v))

        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of_same_base(t, v))

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        t = torch.ones(720, device=device) \
            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        test_writes_propagate(t, v1)
        self.assertTrue(self.is_view_of_same_base(t, v1))
        test_writes_propagate(t, v2)
        self.assertTrue(self.is_view_of_same_base(t, v2))
        test_writes_propagate(t, v3)
        self.assertTrue(self.is_view_of_same_base(t, v3))

    def test_flatten_nonview(self, device="mps"):
        def assert_is_nonview(t, nv):
            idx_t = (0,) * t.ndim
            idx_nv = (0,) * nv.ndim
            self.assertFalse(nv._is_view())
            nv[idx_nv] = 0
            self.assertNotEqual(t[idx_t], nv[idx_nv])

        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
        nv = t.flatten(1, 3)
        assert_is_nonview(t, nv)

        t = torch.ones(2, 2, device=device).T
        nv = t.flatten()
        assert_is_nonview(t, nv)

        # flatten returns the original object if start_dim=end_dim
        t = torch.ones(2, 2, device=device)
        nv = t.flatten(1, 1)
        self.assertIs(t, nv)

    def test_basic_indexing_slice_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[:2, :3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_ellipses_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[..., :2]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_newaxis_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[None, :2, 3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 3], v[0, 0])

    def test_chunk_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.chunk(t, 3)

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_split_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.split(t, [1, 1, 1])

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])
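
    # Illustrative sketch (not part of the original suite): `chunk` and
    # `split` both return narrowed views into one storage, so an in-place
    # write through any piece is visible in the source tensor. Hypothetical
    # helper name.
    def _sketch_chunk_pieces_alias_source(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        for idx, piece in enumerate(torch.chunk(t, 3)):
            piece.fill_(idx + 1)
        # row i of the source now holds the value written through piece i
        self.assertEqual(t[0, 0].item(), 1)
        self.assertEqual(t[2, 2].item(), 3)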

    def test_movedim_view(self, device="mps"):
        def run_test(device, op):
            t = torch.zeros(3, 3, device=device)
            out = op(t)

            self.assertTrue(self.is_view_of(t, out))

            # Randomly change values in output
            # and verify that original is changed
            # as well.
            for _ in range(3):
                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
                out[idx_1, idx_2] = random.random()
                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])

        for fn in [torch.movedim, torch.moveaxis]:
            op = partial(fn, source=(0, 1), destination=(1, 0))
            run_test(device, op)

            op = partial(fn, source=0, destination=1)
            run_test(device, op)

    # Testing that the generated view_copy kernel and its derivative are implemented correctly
    def test_view_copy(self, device="mps"):
        a = torch.randn(4, device=device, requires_grad=True)
        a_ref = a.clone().detach().requires_grad_()
        a_view = a_ref.view(2, 2)
        a_view_copy = torch.view_copy(a, (2, 2))

        # view_copy ops don't preserve view relationship
        self.assertTrue(self.is_view_of(a_ref, a_view))
        self.assertFalse(self.is_view_of(a, a_view_copy))

        a_view_copy.sum().backward()
        a_view.sum().backward()

        # forward and backward give the same shape + result
        self.assertEqual(a_view_copy, a_view)
        self.assertEqual(a.grad, a_ref.grad)
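
    # Illustrative sketch (not part of the original suite): the `*_copy`
    # variants such as `torch.view_copy` allocate fresh storage, so writes to
    # the result never reach the source -- the property the test above relies
    # on. Hypothetical helper name; values chosen only for demonstration.
    def _sketch_view_copy_owns_storage(self, device="mps"):
        a = torch.arange(4.0, device=device)
        b = torch.view_copy(a, (2, 2))
        b[0, 0] = -1.0
        # the copy owns its storage, so the source is untouched
        self.assertEqual(a[0].item(), 0.0)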

    def test_view_copy_out(self, device="mps"):
        a = torch.randn(2, 2, device=device)
        out = torch.empty(2, device=device)

        torch.diagonal_copy(a, out=out)
        expected = torch.diagonal_copy(a)

        self.assertEqual(expected, out)

        a = torch.randn(4, device=device)
        out1 = torch.empty(2, device=device)
        out2 = torch.empty(2, device=device)

        torch.split_copy(a, 2, out=(out1, out2))
        expected1, expected2 = torch.split_copy(a, 2)

        self.assertEqual(expected1, out1)
        self.assertEqual(expected2, out2)

    def test_detached_view_copy(self, device="mps"):
        # https://github.com/pytorch/pytorch/issues/86052
        x = torch.arange(2)
        # .detach() makes y not a view, but contig tensor
        # with non-zero offset
        y = x[1].detach()
        z = y.to(device)
        self.assertEqual(y, z.cpu())
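
    # Illustrative sketch (not part of the original suite): `x[1].detach()`
    # above produces a contiguous tensor whose storage offset is non-zero,
    # which is exactly the case the cross-device copy has to honor.
    # Hypothetical helper name.
    def _sketch_detach_keeps_storage_offset(self):
        x = torch.arange(4)
        y = x[2].detach()
        # same storage, but the element starts two items in
        self.assertEqual(y.storage_offset(), 2)
        self.assertEqual(y.item(), 2)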

    def test_empty_reshape(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
        # should be viewable -- i.e. data_ptr is the same.
        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())

        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))

    def test_expand(self, device="mps"):
        tensor = torch.rand(1, 8, 1, device=device)
        tensor2 = torch.rand(5, device=device)
        template = torch.rand(4, 8, 5, device=device)
        target = template.size()
        self.assertEqual(tensor.expand_as(template).size(), target)
        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor.expand(target).size(), target)
        self.assertEqual(tensor2.expand_as(template).size(), target)
        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor2.expand(target).size(), target)

        # test double expand
        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

        # test non-contiguous
        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        self.assertFalse(noncontig.is_contiguous())
        self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

        # make sure it's compatible with unsqueeze
        expanded = tensor2.expand(1, 1, 5)
        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
        self.assertEqual(expanded, unsqueezed)
        self.assertEqual(expanded.stride(), unsqueezed.stride())

        # test -1 as target size
        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

        # test expanding empty to empty
        self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
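
    # Illustrative sketch (not part of the original suite): `expand` never
    # allocates -- broadcast dimensions simply get stride 0, unlike `repeat`,
    # which copies. Hypothetical helper name.
    def _sketch_expand_uses_zero_strides(self, device="mps"):
        t = torch.rand(1, 8, 1, device=device)
        e = t.expand(4, 8, 5)
        # expanded dims advance by stride 0, so all copies alias one storage
        self.assertEqual(e.stride(0), 0)
        self.assertEqual(e.stride(2), 0)
        self.assertEqual(e.data_ptr(), t.data_ptr())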

    def test_view_empty(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)

    def test_reshape(self, device="mps"):
        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
        if device != "meta":
            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())

        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

        s = torch.randn((), device=device)
        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
        self.assertEqual(s.reshape(-1).shape, (1,))
        self.assertRaises(RuntimeError, lambda: s.reshape(2))

        empty = torch.tensor([], device=device)
        self.assertEqual(empty, empty.reshape(-1))
        self.assertEqual(empty, empty.reshape([0]))
        # TODO: fix these once we have multi-dimensional empty tensors
        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
        self.assertRaises(RuntimeError, lambda: empty.reshape(1))

        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
        self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))

    def test_narrow(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))

    def test_narrow_tensor(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor(0.), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0]), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0, 1]), 1)

    def test_t(self, device="mps"):
        # Test 0D tensors
        x = torch.randn(())
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 1D tensors
        x = torch.arange(4)
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 2D tensors
        x = torch.rand((2, 2))
        self.assertEqual(x.t(), x.transpose(0, 1))
        x = x.to_sparse()
        self.assertEqual(x.t(), x.transpose(0, 1))

        # Test 3D tensor
        x = torch.rand((2, 2, 2))
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
            x.t()
        x = x.to_sparse()
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
            x.t()
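
    # Illustrative sketch (not part of the original suite): `Tensor.t()` is
    # defined only up to 2-D (0-D and 1-D come back unchanged); for batched
    # matrices, `transpose(-2, -1)` is the generalization. Hypothetical
    # helper name.
    def _sketch_t_is_two_d_only(self):
        x = torch.rand(2, 3, 4)
        with self.assertRaises(RuntimeError):
            x.t()  # t() is matrix-only; a 3-D input is rejected
        # transpose of the last two dims is the batched equivalent
        self.assertEqual(x.transpose(-2, -1).shape, (2, 4, 3))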

    def test_split(self, device="mps"):
        tensor = torch.rand(7, 4)
        split_size = 3
        dim = 0
        target_sizes = ([3, 4], [3, 4], [1, 4])
        splits = tensor.split(split_size, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        # Variable sections split
        tensor = torch.randn(20, 10)
        dim = 0
        split_sizes = [5, 5, 10]
        target_sizes = ([[5, 10], [5, 10], [10, 10]])
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        split_sizes = [2, 2, 6]
        target_sizes = ([20, 2], [20, 2], [20, 6])
        dim = 1
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

    def test_chunk(self, device="mps"):
        tensor = torch.rand(4, 7)
        num_chunks = 3
        dim = 1
        target_sizes = ([4, 3], [4, 3], [4, 1])
        splits = tensor.chunk(num_chunks, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
                             atol=0, rtol=0)
            start = start + target_size[dim]

        # Invalid chunk sizes
        error_regex = 'chunk expects.*greater than 0'
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(0)
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(-2)

    def test_unsqueeze(self, device="mps") -> None:
        x = torch.randn(2, 3, 4)
        y = x.unsqueeze(1)
        self.assertEqual(y, x.view(2, 1, 3, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.view(2, 3, 1, 4))

        x = x[:, 1]
        self.assertFalse(x.is_contiguous())
        y = x.unsqueeze(1)
        self.assertEqual(y, x.contiguous().view(2, 1, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.contiguous().view(2, 4, 1))

    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
    def test_big_transpose(self, device="mps"):
        t = torch.rand(456, 789, device=device)
        t1 = t.t().contiguous()
        t2 = torch.from_numpy(t.cpu().numpy().transpose())
        self.assertEqual(t1, t2)

    def test_T(self, device="mps"):
        a = torch.randn(2, 3, 4, device=device)
        t1 = a.T
        t2 = a.permute(2, 1, 0)
        self.assertEqual(t2, t1)
        b = torch.randn(10, device=device)
        self.assertEqual(b, b.T)

    def test_transposes(self, device="mps", dtype=torch.float32):
        for op in ("T", "H", "mT", "mH", "adjoint"):
            shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                t1 = getattr(a, op)
                if op == "adjoint":
                    t1 = t1()
                t2 = a
                if a.ndim != 0:
                    t2 = t2.transpose(-2, -1)
                if op[-1] == "H" or op == "adjoint":
                    t2 = t2.conj()
                self.assertEqual(t2, t1)

    def test_transposes_errors(self, device="mps", dtype=torch.float32):
        for op in ("H", "mT", "mH", "adjoint"):
            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                    t1 = getattr(a, op)
                    if op == "adjoint":
                        t1 = t1()

    def test_python_types(self, device="mps"):
        a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
        a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
        self.assertEqual(a1.dtype, a2.dtype)

        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
        b2 = torch.arange(10, 20, dtype=int, device=device)
        self.assertEqual(b1.dtype, b2.dtype)

        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
        c2 = torch.tensor([True, False], dtype=bool, device=device)
        self.assertEqual(c1.dtype, c2.dtype)

    # TODO: is resize best put in test_view_ops?
    def test_resize_as_preserves_strides(self, device="mps"):
        x = torch.empty(2, 3).t()
        old_strides = x.stride()
        x.resize_as_(x)
        self.assertEqual(x.stride(), old_strides)

    def test_memory_format_resize_as(self, device="mps"):
        def test_helper(shape, memory_format, device="mps"):
            xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
            flat = torch.randn(xc.numel(), device=device)
            flat.resize_as_(xc, memory_format=torch.preserve_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")

    def test_memory_format_resize_(self, device="mps"):
        def test_helper(shape, numel, memory_format, device="mps"):
            flat = torch.randn(numel, device=device)
            flat.resize_(shape, memory_format=memory_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
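
    # Illustrative sketch (not part of the original suite): a channels-last
    # resize is observable purely through the strides -- for an NCHW shape
    # the channel dimension becomes the fastest-moving one (stride 1).
    # Hypothetical helper name.
    def _sketch_channels_last_strides(self, device="mps"):
        x = torch.randn(10 * 3 * 32 * 32, device=device)
        x.resize_((10, 3, 32, 32), memory_format=torch.channels_last)
        # NHWC layout: C advances by 1 element in memory
        self.assertEqual(x.stride(1), 1)
        self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))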

    # TODO: OpInfo this
    def _test_atleast(self, device, torch_fn):
        # 0-dim
        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), s)
        gradgradcheck(lambda x: torch_fn(x), s)

        # 1-dim
        a = torch.rand(4, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), a)
        gradgradcheck(lambda x: torch_fn(x), a)

        # 2,3,4-dim
        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

        input_tuple = (s, a, b, c, d)
        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)

    def test_atleast_gradient(self, device="mps"):
        self._test_atleast(device, torch.atleast_1d)
        self._test_atleast(device, torch.atleast_2d)
        self._test_atleast(device, torch.atleast_3d)
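
    # Illustrative sketch (not part of the original suite): `_test_atleast`
    # above runs gradcheck on float64 inputs because the finite-difference
    # Jacobian comparison is too noisy in float32. A minimal standalone use,
    # via the public torch.autograd.gradcheck entry point:
    def _sketch_gradcheck_wants_double(self):
        a = torch.rand(4, dtype=torch.double, requires_grad=True)
        self.assertTrue(torch.autograd.gradcheck(torch.atleast_2d, (a,)))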

    def test_view(self, device="mps"):
        tensor = torch.rand(15, device=device)
        template = torch.rand(3, 5, device=device)
        empty = torch.empty(0, device=device)
        target = template.size()
        self.assertEqual(tensor.view_as(template).size(), target)
        self.assertEqual(tensor.view(3, 5).size(), target)
        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
        self.assertEqual(tensor.view(-1, 5).size(), target)
        self.assertEqual(tensor.view(3, -1).size(), target)
        tensor_view = tensor.view(5, 3)
        tensor_view.fill_(random.uniform(0, 1))
        self.assertEqual(empty.view_as(empty), empty)
        self.assertEqual(empty.view(0), empty)
        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

        # test size inference with empty tensors
        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(-1, 0)

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(3, 0, -1, 0)

        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

    def test_contiguous(self, device="mps"):
        x = torch.randn(1, 16, 5, 5, device=device)
        self.assertTrue(x.is_contiguous())
        stride = list(x.stride())
        stride[0] = 20
        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
        x.set_(x.storage(), 0, x.size(), stride)
        self.assertTrue(x.is_contiguous())

    def test_resize_mps_dtypes(self, device="mps"):
        shape = (2, 2)
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            x.resize_(shape)
            self.assertEqual(shape, x.shape)

    def test_resize_as_mps_dtypes(self, device="mps"):
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
            x.resize_as_(y)
            self.assertEqual(y.shape, x.shape)

    def test_resize_overflow(self, device="mps"):
        x = torch.empty((), dtype=torch.float64)
        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
            x.resize_([2, 4, 2**29, 2**29])
        with self.assertRaisesRegex(RuntimeError, 'overflow'):
            x.resize_([8, 8, 2**29, 2**29])

    def test_view_all_dtypes_and_devices(self, device="mps"):
        for dt in (torch.float, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            self.assertEqual(x.view(6).shape, [6])

class TestConvolutionMPS(TestCaseMPS):
    def test_conv1d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/82921
        def helper(stride, padding):
            y_cpu = torch.randn(1, 57, 40)
            conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        for stride in range(1, 4):
            for padding in range(1, 4):
                helper(stride, padding)


    def test_conv1d_channels_last(self):
        # https://github.com/pytorch/pytorch/issues/81557
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_transpose_1d_all_strides(self):
        # https://github.com/pytorch/pytorch/issues/82711
        def helper(stride):
            y_cpu = torch.ones(1, 1, 2)
            deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
            deconv_cpu.weight.data = torch.ones(1, 1, 2)
            deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
            x_cpu = deconv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = deconv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        [helper(stride) for stride in [1, 2, 3]]
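
    # Illustrative sketch (not part of the original suite): the strided
    # transposed convolutions above follow the usual output-length formula
    #     L_out = (L_in - 1) * stride - 2 * padding + kernel_size.
    # Hypothetical helper name; the shape constants are chosen only for
    # demonstration.
    def _sketch_conv_transpose_1d_output_length(self):
        L_in, stride, padding, kernel_size = 2, 3, 1, 2
        deconv = nn.ConvTranspose1d(1, 1, kernel_size=kernel_size,
                                    stride=stride, padding=padding, bias=False)
        out = deconv(torch.ones(1, 1, L_in))
        self.assertEqual(out.shape[-1],
                         (L_in - 1) * stride - 2 * padding + kernel_size)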

    def test_conv_transpose_1d_nn_functional(self):
        # https://github.com/pytorch/pytorch/issues/82563
        tin = torch.rand((1, 512, 1245), dtype=torch.float32)
        tparams = torch.rand((512, 256, 16), dtype=torch.float32)
        tbias = torch.rand((256), dtype=torch.float32)

        device = 'cpu'
        tcpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        device = 'mps'
        tgpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_backward_1d_channels_last(self):
        def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
            # https://github.com/pytorch/pytorch/issues/84511
            conv_cpu = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
            conv_mps = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
            conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
            conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)


            data = torch.rand(shape, dtype=torch.float32)
            x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
            x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
            res_cpu = conv_cpu(x_cpu)
            res_mps = conv_mps(x_mps)
            self.assertEqual(res_cpu, res_mps)
            res_cpu = res_cpu.sum().backward()
            res_mps = res_mps.sum().backward()

            self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x_cpu.grad, x_mps.grad)

        helper(shape=(1, 176, 1))
        helper(shape=(2, 12, 1))
        helper(shape=(3, 176, 1))
        helper(shape=(4, 376, 1))
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
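
    # Illustrative sketch (not part of the original suite): several tests
    # above duplicate one CPU module onto MPS; `copy.deepcopy` followed by
    # `.to("mps")` keeps the CPU and MPS parameters identical before any
    # forward pass is compared. Hypothetical helper name.
    def _sketch_duplicate_module_across_devices(self):
        conv_cpu = nn.Conv1d(3, 8, kernel_size=3)
        conv_mps = copy.deepcopy(conv_cpu).to("mps")
        # same initial parameters on both devices
        self.assertEqual(conv_cpu.weight, conv_mps.weight.cpu())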
    def test_conv1d_contiguous(self):
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.ones(128, 1, 176)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu.shape, out_mps.shape)
        self.assertEqual(out_cpu, out_mps.cpu())

    def test_conv2d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/83180
        def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
            x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
            x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

            if permute_data:
                x_cpu.permute(0, 2, 3, 1)
                x_mps.permute(0, 2, 3, 1)

            for strideX in range(1, 4):
                for strideY in range(1, 4):
                    conv_cpu = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                    conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                    conv_mps = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                    conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                    conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                    res_cpu = conv_cpu(x_cpu)
                    res_mps = conv_mps(x_mps)
                    self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
                    res_cpu = res_cpu.sum().backward()
                    res_mps = res_mps.sum().backward()
                    self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)

                    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                    self.assertEqual(x_cpu.grad, x_mps.grad)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
                for permute_data in [True, False]:
                    helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                    helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                    helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
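    # Added illustration (not part of the original suite): torch.channels_last,
    # exercised by the memory-format sweep above, reorders strides only -- the
    # logical shape of the tensor is untouched.
    def test_channels_last_strides_sketch(self):
        x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
        self.assertEqual(x.shape, (2, 3, 4, 5))
        # NHWC element order: N stride = H*W*C = 60, C = 1, H = W*C = 15, W = C = 3
        self.assertEqual(x.stride(), (60, 1, 15, 3))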
    def test_conv_transpose_2d_strided(self):
        def helper(m_cpu, memory_format):
            m_mps = copy.deepcopy(m_cpu).requires_grad_()
            m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
            m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()

            input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
            input_mps = input_cpu.detach().clone().to("mps")

            output_cpu = m_cpu(input_cpu)
            output_mps = m_mps(input_mps)
            self.assertEqual(output_cpu, output_mps)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            # With square kernels and equal stride
            helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)

            # With non-square kernels, unequal stride, and padding
            helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)
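    # Added illustration (not from the original suite): with stride > 1 a
    # transposed conv maps several input sizes onto the same output size, which
    # is why the test below passes output_size to pin down the exact shape.
    def test_conv_transpose_output_size_ambiguity_sketch(self):
        up = nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1)
        h = torch.randn(1, 1, 6, 6)
        # default: (6 - 1) * 2 - 2 * 1 + (3 - 1) + 1 = 11
        self.assertEqual(up(h).shape[-1], 11)
        # output_size may pick any size in [11, 11 + stride - 1]
        self.assertEqual(up(h, output_size=(12, 12)).shape[-1], 12)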
    def test_conv_transpose_2d_specified_output(self):
        input_cpu = torch.randn(1, 16, 12, 12)
        input_mps = input_cpu.detach().clone().to("mps")

        downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
        downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
        upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        h_cpu = downsample_cpu(input_cpu)
        h_mps = downsample_mps(input_mps)
        self.assertEqual(h_cpu, h_mps)

        size_cpu = h_cpu.size()
        size_mps = h_mps.size()
        self.assertEqual(size_cpu, size_mps)

        output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
        output_mps = upsample_mps(h_mps, output_size=input_mps.size())
        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_conv2d_single_stride(self):
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
    def test_conv3d_single_stride(self):
        # Conv3d is only available from MacOS 13.2 onwards
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    def test_grid_sample(self):
        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                    # grid_dim_contig_order specifies the dimension order that makes
                    # the grid contiguous, i.e. grid.permute(grid_dim_contig_order)
                    # is contiguous.
                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                    # initialized as a contiguous tensor of shape [N, 2, H, W]
                    # and permuted to [N, H, W, 2] afterwards.
                    grid_shape = [N, H, W, 2]
                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                    grid_fwd_permute = [None, None, None, None]
                    for i, d in enumerate(grid_dim_contig_order):
                        grid_fwd_permute[d] = i
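                    # (Illustrative aside, added for clarity; not in the original
                    # test.) grid_fwd_permute is the inverse permutation of
                    # grid_dim_contig_order, so composing the two is the identity:
                    assert [grid_fwd_permute[d] for d in grid_dim_contig_order] == [0, 1, 2, 3]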

                    def get_grid(device='cpu', data=None):
                        if data is not None:
                            assert list(data.shape) == grid_shape
                            data = data.permute(grid_dim_contig_order).to(device)
                        else:
                            data = torch.randn(grid_init_shape, device=device)
                        grid = data.permute(grid_fwd_permute)
                        assert grid.permute(grid_dim_contig_order).is_contiguous()
                        return grid

                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                    grid_cpu = get_grid().requires_grad_()
                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                            align_corners=align_corners)
                    self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))

                    gradients = torch.randn_like(out_cpu)
                    out_cpu.backward(gradients)

                    # Compare against unvectorized CPU fallback

                    # NOTE [ grid_sample CPU fallback ]
                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                    # 32-bit floats. So we also have a fallback that is used only for float tensors
                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                    # also export the fallback and test it here to ensure feature parity with
                    # the vectorized version.
                    input_fallback = input_cpu.float().detach_().requires_grad_()
                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
                        input_fallback, grid_fallback,
                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                        align_corners)
                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                    out_fallback.backward(gradients.float())
                    if input_requires_grad:
                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                    input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
                    grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                    self.assertEqual(out_cpu, out_mps)
                    out_mps.backward(gradients.to("mps"))
                    if input_requires_grad:
                        self.assertEqual(input_cpu.grad, input_mps.grad)
                    self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)

                    # check that zero-dimensional input strides don't error out
                    base_input = torch.randn(N, C, 1, IW)
                    input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                            align_corners=align_corners)

                    input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                    self.assertEqual(out_cpu, out_mps)

            # test same size output
            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

            # test larger output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(IH + 1, 12)
            W = random.randint(IW + 1, 12)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test smaller output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(2, IH)
            W = random.randint(2, IW)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test 1x1 input
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = 1
            IW = 1
            H = random.randint(2, 5)
            W = random.randint(2, 5)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # testing empty grid
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            W = random.randint(3, IW + 2)
            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

            # testing empty channel
            N = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
            # testing empty batch
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

        for mode in ('bilinear', 'nearest'):
            for padding_mode in ('zeros', 'reflection'):
                for align_corners in (True, False):
                    # test known input
                    input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
                    grid = torch.tensor(
                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
                    if mode == 'bilinear':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'nearest':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 0.],
                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'bicubic':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    else:
                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                           align_corners=align_corners)
                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                     msg=f"groundtruth comparison failed for mode={mode}, "
                                         f"padding_mode={padding_mode}")
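    # Added illustration (not from the original suite): the groundtruth tables
    # above assume grid_sample's normalized grid coordinates, where with
    # align_corners=True the value (-1, -1) addresses the corner pixel exactly.
    def test_grid_sample_corner_convention_sketch(self):
        inp = torch.arange(1., 5.).view(1, 1, 2, 2)
        corner = torch.tensor([[[[-1., -1.]]]])  # (N, H_out, W_out, 2): top-left corner
        out = F.grid_sample(inp, corner, mode='bilinear', align_corners=True)
        self.assertEqual(out.item(), 1.0)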
class TestAdvancedIndexing(TestCaseMPS):
    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]

    @unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
    def test_nonzero_no_warning(self):
        device = "mps"
        t = torch.randn((2, 2), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.nonzero(t)
            t.nonzero()
            self.assertEqual(len(w), 0)

    def test_nonzero(self):
        def helper(dtype):
            device = "mps"
            shapes = [
                torch.Size((12,)),
                torch.Size((12, 1)),
                torch.Size((1, 12)),
                torch.Size((6, 2)),
                torch.Size((3, 2, 2)),
                torch.Size((5, 5, 5)),
            ]

            def gen_nontrivial_input(shape, dtype, device):
                if dtype != torch.bfloat16:
                    return torch.randint(2, shape, device=device, dtype=dtype)
                else:
                    # randint does not work for bfloat16 on Windows; generate as float and cast
                    return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)

            for shape in shapes:
                tensor = gen_nontrivial_input(shape, dtype, device)
                dst1 = torch.nonzero(tensor, as_tuple=False)
                dst2 = tensor.nonzero(as_tuple=False)
                dst3 = torch.empty([], dtype=torch.long, device=device)
                dst3 = dst3.resize_(0)
                torch.nonzero(tensor, out=dst3)
                np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
                np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
                self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
                self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
                self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
                tup1 = torch.nonzero(tensor, as_tuple=True)
                tup2 = tensor.nonzero(as_tuple=True)
                tup1 = torch.stack(tup1).t().cpu()
                tup2 = torch.stack(tup2).t().cpu()
                self.assertEqual(tup1, np_result, atol=0, rtol=0)
                self.assertEqual(tup2, np_result, atol=0, rtol=0)
        [helper(dtype) for dtype in self.supported_dtypes]
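    # Added illustration (not from the original suite): the two nonzero() forms
    # checked above are related by a transpose -- as_tuple=True yields one index
    # tensor per dimension, which stacks column-wise into the as_tuple=False matrix.
    def test_nonzero_tuple_vs_matrix_sketch(self):
        t = torch.tensor([[1, 0], [0, 2]], device="mps")
        matrix = torch.nonzero(t, as_tuple=False)
        tupled = torch.nonzero(t, as_tuple=True)
        self.assertEqual(torch.stack(tupled, dim=1), matrix)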
    def test_nonzero_astuple_out(self):
        device = "mps"
        t = torch.randn((3, 3, 3), device=device)
        out = torch.empty([], dtype=torch.long, device=device)
        out = out.resize_(0)

        with self.assertRaises(RuntimeError):
            torch.nonzero(t, as_tuple=True, out=out)

        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))

        # Verifies that JIT script cannot handle the as_tuple kwarg
        # See Issue https://github.com/pytorch/pytorch/issues/45499.
        def _foo(t):
            tuple_result = torch.nonzero(t, as_tuple=True)
            nontuple_result = torch.nonzero(t, as_tuple=False)
            out = torch.empty_like(nontuple_result)
            torch.nonzero(t, as_tuple=False, out=out)
            return tuple_result, nontuple_result, out

        with self.assertRaises(RuntimeError):
            scripted_foo = torch.jit.script(_foo)

        # Verifies that JIT tracing works fine
        traced_foo = torch.jit.trace(_foo, t)
        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
        expected_tuple = torch.nonzero(t, as_tuple=True)
        expected_nontuple = torch.nonzero(t)

        self.assertEqual(traced_tuple, expected_tuple)
        self.assertEqual(traced_nontuple, expected_nontuple)
        self.assertEqual(traced_out, expected_nontuple)
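    # Added illustration (not from the original suite): the resize_(0) dance
    # above hands nonzero() an empty out tensor that the op is then free to
    # resize; a minimal sketch of that out= behavior.
    def test_nonzero_out_resize_sketch(self):
        out = torch.empty([], dtype=torch.long, device="mps").resize_(0)
        torch.nonzero(torch.ones(3, device="mps"), out=out)
        # one row per nonzero element, one column per input dimension
        self.assertEqual(out.shape, (3, 1))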
    def test_nonzero_discontiguous(self):
        device = "mps"
        shape = (4, 4)
        tensor = torch.randint(2, shape, device=device)
        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
        dst1 = tensor.nonzero(as_tuple=False)
        dst2 = tensor_nc.nonzero(as_tuple=False)
        self.assertEqual(dst1, dst2, atol=0, rtol=0)
        dst3 = torch.empty_like(dst1)
        data_ptr = dst3.data_ptr()
        # expect dst3 storage to be reused
        torch.nonzero(tensor, out=dst3)
        self.assertEqual(data_ptr, dst3.data_ptr())
        self.assertEqual(dst1, dst3, atol=0, rtol=0)
        # discontiguous out
        dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
        data_ptr = dst4.data_ptr()
        strides = dst4.stride()
        torch.nonzero(tensor, out=dst4)
        self.assertEqual(data_ptr, dst4.data_ptr())
        self.assertEqual(dst1, dst4, atol=0, rtol=0)
        self.assertEqual(strides, dst4.stride())

    def test_nonzero_non_diff(self):
        device = "mps"
        x = torch.randn(10, requires_grad=True, device=device)
        nz = x.nonzero()
        self.assertFalse(nz.requires_grad)

    def test_nonzero_multi_threading(self):
        # Test that MPS doesn't crash if nonzero is called concurrently
        # See https://github.com/pytorch/pytorch/issues/100285
        x = torch.rand(3, 3, device="mps")
        t1 = threading.Thread(target=torch.nonzero, args=(x,))
        t2 = threading.Thread(target=torch.nonzero, args=(x,))
        t1.start()
        t2.start()

    def test_sliced_view_cast(self):
        # This used to crash on MacOS Sequoia
        # See https://github.com/pytorch/pytorch/issues/137800
        x = torch.rand(16, 16, device='mps', dtype=torch.float16)
        y = x[:, 0:2].view(torch.float32) + 1

    def test_masked_select(self):
        x = torch.randn(3, 4)
        x_mps = x.to("mps")
        mask = x.ge(0.5)
        mask_mps = x_mps.ge(0.5)

        res = torch.masked_select(x, mask)
        res_mps = torch.masked_select(x_mps, mask_mps)

        self.assertEqual(res, res_mps)
    # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
    def test_indexing_get(self):
        def helper(dtype):
            x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            y_cpu = x_cpu[[0, 1, 2], [0, 1, 0]]
            y_mps = x_mps[[0, 1, 2], [0, 1, 0]]
            self.assertEqual(y_cpu, y_mps, str(dtype))
        [helper(dtype) for dtype in self.supported_dtypes]

    def test_indexing_select_corners(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            rows_cpu = torch.tensor([[0, 0], [3, 3]])
            rows_mps = rows_cpu.detach().clone().to("mps")

            cols_cpu = torch.tensor([[0, 2], [0, 2]])
            cols_mps = cols_cpu.detach().clone().to("mps")

            res_cpu = x_cpu[rows_cpu, cols_cpu]
            res_mps = x_mps[rows_mps, cols_mps]

            self.assertEqual(res_cpu, res_mps, str(dtype))
        [helper(dtype) for dtype in self.supported_dtypes]

    # FIXME: uint8 fails for this testcase, needs further debugging
    def test_slicing_using_advanced_index_for_column(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            z_cpu = x_cpu[1:4, 1:3]
            z_mps = x_mps[1:4, 1:3]
            self.assertEqual(z_cpu, z_mps, str(dtype))

            # using advanced index for column
            y_cpu = x_cpu[1:4, [1, 2]]
            y_mps = x_mps[1:4, [1, 2]]
            self.assertEqual(y_cpu, y_mps, str(dtype))
        # FIXME: use supported_dtypes once uint8 is fixed
        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]

    def test_boolean_array_indexing(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            res_cpu = x_cpu[x_cpu > 5]
            res_mps = x_mps[x_mps > 5]

            self.assertEqual(res_cpu, res_mps, str(dtype))
        for dtype in self.supported_dtypes:
            # MPS supports binary ops with uint8 natively starting from macOS 13.0
            if product_version < 13.0 and dtype == torch.uint8:
                continue
            helper(dtype)

    def test_advanced_indexing_3D_get(self):
        def helper(x_cpu):
            x_mps = x_cpu.detach().clone().to("mps")
            self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
            self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
            self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])

        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                               [0.5, 0.6, 0.7, 0.8],
                               [0.9, 1.0, 1.1, 1.2],
                               [1.3, 1.4, 1.5, 1.6]],

                              [[2.0, 2.1, 2.2, 2.3],
                               [2.4, 2.5, 2.6, 2.7],
                               [2.8, 2.9, 3.0, 3.1],
                               [3.2, 3.3, 3.4, 3.5]],

                              [[4.0, 4.1, 4.2, 4.3],
                               [4.4, 4.5, 4.6, 4.7],
                               [4.8, 4.9, 5.0, 5.1],
                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
        helper(x_cpu)
        for idx in range(len(self.supported_np_dtypes)):
            # torch.randn / torch.rand don't work with all dtypes
            # Generate input data for all dtypes with NumPy, then move to torch
            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])

            helper(inputCPU)
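    # Added illustration (not from the original suite): the NumPy detour above
    # exists because torch.rand is float-only; an integer dtype is expected to
    # raise, while the NumPy round-trip lands on the requested dtype.
    def test_rand_integer_dtype_workaround_sketch(self):
        with self.assertRaises(RuntimeError):
            torch.rand(2, 2, dtype=torch.int32)
        arr = np.random.random_sample(size=[2, 2]).astype(np.int32)
        self.assertEqual(torch.tensor(arr).dtype, torch.int32)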
    def test_advanced_indexing_3D_put(self):
        def helper(x_cpu):
            dtype = x_cpu.dtype
            x_mps = x_cpu.detach().clone().to("mps")

            out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
            out_tensor_cpu_view = out_tensor_cpu[1:]

            out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
            out_tensor_mps_view = out_tensor_mps[1:]

            x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
            x_mps[[1, 2], 3, :] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

            x_cpu[[0, 2], :, :] = out_tensor_cpu_view
            x_mps[[0, 2], :, :] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

            x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
            x_mps[:, [1, 0], [1]] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                               [0.5, 0.6, 0.7, 0.8],
                               [0.9, 1.0, 1.1, 1.2],
                               [1.3, 1.4, 1.5, 1.6]],

                              [[2.0, 2.1, 2.2, 2.3],
                               [2.4, 2.5, 2.6, 2.7],
                               [2.8, 2.9, 3.0, 3.1],
                               [3.2, 3.3, 3.4, 3.5]],

                              [[4.0, 4.1, 4.2, 4.3],
                               [4.4, 4.5, 4.6, 4.7],
                               [4.8, 4.9, 5.0, 5.1],
                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
        helper(x_cpu)
        for idx in range(len(self.supported_np_dtypes)):
            # torch.randn / torch.rand don't work with all dtypes
            # Generate input data for all dtypes with NumPy, then move to torch
            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
            helper(inputCPU)

    def test_index_put_with_view_indices(self):
        def helper(dtype):
            target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
            target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)

            indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
            indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")

            value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
            value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)

            target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
            target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)

            self.assertEqual(target_cpu, target_mps)

        [helper(dtype) for dtype in [torch.int32, torch.float]]
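    # Added illustration (not from the original suite): with accumulate=True,
    # duplicate indices add up instead of racing, which is what keeps the
    # view-indices test above deterministic when rows repeat.
    def test_index_put_accumulate_duplicates_sketch(self, device="mps"):
        t = torch.zeros(3, device=device)
        idx = (torch.tensor([1, 1, 1], device=device),)
        t.index_put_(idx, torch.ones(3, device=device), accumulate=True)
        self.assertEqual(t.cpu(), torch.tensor([0., 3., 0.]))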
    # tests from 'test_indexing.py'
    def test_advancedindex_big(self, device="mps"):
        reference = torch.arange(0, 123344, dtype=torch.int, device=device)

        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))

    def test_set_item_to_scalar_tensor(self, device="mps"):
        m = random.randint(1, 10)
        n = random.randint(1, 10)
        z = torch.randn([m, n], device=device)
        a = 1.0
        w = torch.tensor(a, requires_grad=True, device=device)
        z[:, 0] = w
        z.sum().backward()
        self.assertEqual(w.grad, m * a)

    def test_single_int(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))

    def test_multiple_int(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))
        self.assertEqual(v[4, :, 1].shape, (7,))

    def test_none(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[None].shape, (1, 5, 7, 3))
        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))

    def test_step(self, device="mps"):
        v = torch.arange(10, device=device)
        self.assertEqual(v[::1], v)
        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
        self.assertEqual(v[::11].tolist(), [0])
        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])

    def test_step_assignment(self, device="mps"):
        v = torch.zeros(4, 4, device=device)
        v[0, 1::2] = torch.tensor([3., 4.], device=device)
        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
        self.assertEqual(v[1:].sum(), 0)

    def test_bool_indices(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
        boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
        with warnings.catch_warnings(record=True) as w:
            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
            self.assertEqual(v[boolIndices], v[uint8Indices])
            self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
            self.assertEqual(len(w), 2)

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
macOS 12") 11201*da0073e9SAndroid Build Coastguard Worker def test_bool_indices_accumulate(self, device="mps"): 11202*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device) 11203*da0073e9SAndroid Build Coastguard Worker mask = mask > 0 11204*da0073e9SAndroid Build Coastguard Worker y = torch.ones(size=(10, 10), device=device) 11205*da0073e9SAndroid Build Coastguard Worker y.index_put_((mask, ), y[mask], accumulate=True) 11206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, torch.ones(size=(10, 10), device=device)) 11207*da0073e9SAndroid Build Coastguard Worker 11208*da0073e9SAndroid Build Coastguard Worker def test_multiple_bool_indices(self, device="mps"): 11209*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 11210*da0073e9SAndroid Build Coastguard Worker # note: these broadcast together and are transposed to the first dim 11211*da0073e9SAndroid Build Coastguard Worker mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device) 11212*da0073e9SAndroid Build Coastguard Worker mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) 11213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) 11214*da0073e9SAndroid Build Coastguard Worker 11215*da0073e9SAndroid Build Coastguard Worker def test_byte_mask(self, device="mps"): 11216*da0073e9SAndroid Build Coastguard Worker v = torch.randn(5, 7, 3, device=device) 11217*da0073e9SAndroid Build Coastguard Worker mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) 11218*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 11219*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[mask].shape, (3, 7, 3)) 11220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]])) 11221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 2) 11222*da0073e9SAndroid Build Coastguard Worker 11223*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([1.], device=device) 11224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[v == 0], torch.tensor([], device=device)) 11225*da0073e9SAndroid Build Coastguard Worker 11226*da0073e9SAndroid Build Coastguard Worker def test_byte_mask_accumulate(self, device="mps"): 11227*da0073e9SAndroid Build Coastguard Worker mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device) 11228*da0073e9SAndroid Build Coastguard Worker y = torch.ones(size=(10, 10), device=device) 11229*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 11230*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 11231*da0073e9SAndroid Build Coastguard Worker y.index_put_((mask, ), y[mask], accumulate=True) 11232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, torch.ones(size=(10, 10), device=device)) 11233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 2) 11234*da0073e9SAndroid Build Coastguard Worker 11235*da0073e9SAndroid Build Coastguard Worker def test_index_put_accumulate_expanded_values(self, device="mps"): 11236*da0073e9SAndroid Build Coastguard Worker t = torch.zeros((5, 2)) 11237*da0073e9SAndroid Build Coastguard Worker t_dev = t.to(device) 11238*da0073e9SAndroid Build Coastguard Worker indices = [ 11239*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 1, 2, 3]), 11240*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, ]), 11241*da0073e9SAndroid Build Coastguard Worker ] 

    def test_index_put_accumulate_expanded_values(self, device="mps"):
        t = torch.zeros((5, 2))
        t_dev = t.to(device)
        indices = [
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([1, ]),
        ]
        indices_dev = [i.to(device) for i in indices]
        values0d = torch.tensor(1.0)
        values1d = torch.tensor([1.0, ])

        out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values0d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        t = torch.zeros(4, 3, 2)
        t_dev = t.to(device)

        indices = [
            torch.tensor([0, ]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ]
        indices_dev = [i.to(device) for i in indices]
        values1d = torch.tensor([-1.0, -2.0])
        values2d = torch.tensor([[-1.0, -2.0], ])

        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values2d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)
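
    # Editorial sketch (not part of the original suite): the `values` argument
    # of index_put_ broadcasts against the shape produced by the indices, which
    # is why the test above can pass both 0-d and 1-d values with the same
    # index list.
    def _demo_index_put_value_broadcast(self, device="mps"):
        t = torch.zeros(3, 2, device=device)
        idx = (torch.tensor([0, 2], device=device),)
        # the 0-d value broadcasts to the (2, 2) shape selected by the indices
        t.index_put_(idx, torch.tensor(5.0, device=device))
        assert t.tolist() == [[5.0, 5.0], [0.0, 0.0], [5.0, 5.0]]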

    def test_index_put_accumulate_non_contiguous(self, device="mps"):
        t = torch.zeros((5, 2, 2))
        t_dev = t.to(device)
        t1 = t_dev[:, 0, :]
        t2 = t[:, 0, :]
        self.assertFalse(t1.is_contiguous())
        self.assertFalse(t2.is_contiguous())

        indices = [torch.tensor([0, 1]), ]
        indices_dev = [i.to(device) for i in indices]
        value = torch.randn(2, 2)
        out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
        out_cpu = t2.index_put_(indices, value, accumulate=True)
        self.assertFalse(t1.is_contiguous())
        self.assertFalse(t2.is_contiguous())

        self.assertEqual(out_mps.cpu(), out_cpu)

    def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
        # TODO: replace with a better solution.
        # Currently we use TorchScript to put None into the indices. On the C++
        # side this gives the indices as a list of two optional tensors: the
        # first is null and the second is a valid tensor.
        @torch.jit.script
        def func(x, i, v):
            idx = [None, i]
            x.index_put_(idx, v, accumulate=True)
            return x

        n = 4
        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
        t_dev = t.to(device)
        indices = torch.tensor([1, 0])
        indices_dev = indices.to(device)
        value0d = torch.tensor(10.0)
        value1d = torch.tensor([1.0, 2.0])

        out_mps = func(t_dev, indices_dev, value0d.to("mps"))
        out_cpu = func(t, indices, value0d)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = func(t_dev, indices_dev, value1d.to("mps"))
        out_cpu = func(t, indices, value1d)
        self.assertEqual(out_mps.cpu(), out_cpu)

    def test_index_put_accumulate_duplicate_indices(self, device="mps"):
        for i in range(1, 128):
            # generate indices by random walk, this will create indices with
            # lots of duplicates interleaved with each other
            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)

            indices = delta.cumsum(0).long().to("mps")

            # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
            input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
            values = torch.randn(indices.size(0), device=device)
            output = input.index_put((indices,), values, accumulate=True)

            input_list = input.tolist()
            indices_list = indices.tolist()
            values_list = values.tolist()
            for i, v in zip(indices_list, values_list):
                input_list[i] += v

            self.assertEqual(output, input_list)

    def test_index_put_deterministic(self, device="mps"):
        def helper(dtype, accumulate, deterministic, num_tests=128):
            acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
            non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
            t_idx = torch.tensor(
                [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
                 0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
            )
            for _ in range(num_tests):
                try:
                    torch.use_deterministic_algorithms(deterministic)
                    t = torch.zeros(3, dtype=dtype, device=device)
                    t.index_put_((t_idx,), torch.arange(len(t_idx), device=device, dtype=dtype), accumulate=accumulate)
                    if accumulate:
                        self.assertEqual(t, acc_expected)
                    else:
                        self.assertEqual(t, non_acc_expected)
                finally:
                    torch.use_deterministic_algorithms(False)

        for accumulate, deterministic in product((False, True), (False, True)):
            dtype = torch.float if accumulate else torch.long
            if not accumulate and not deterministic:
                with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
                    helper(dtype, accumulate, deterministic)
            else:
                helper(dtype, accumulate, deterministic)
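
    # Editorial sketch (not part of the original suite): with duplicate indices
    # and accumulate=False there is no defined winner among the clashing writes,
    # which is why test_index_put_deterministic above expects a mismatch in the
    # non-deterministic, non-accumulating configuration. accumulate=True is
    # always well defined: clashing values are summed.
    def _demo_duplicate_index_writes(self, device="mps"):
        t = torch.zeros(1, device=device)
        idx = torch.tensor([0, 0, 0], device=device)
        vals = torch.tensor([1.0, 2.0, 3.0], device=device)
        t.index_put_((idx,), vals, accumulate=True)
        assert t.item() == 6.0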

    def test_multiple_byte_mask(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        # note: these broadcast together and are transposed to the first dim
        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
            self.assertEqual(len(w), 2)

    def test_byte_mask2d(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        c = torch.randn(5, 7, device=device)
        num_ones = (c > 0).sum()
        r = v[c > 0]
        self.assertEqual(r.shape, (num_ones, 3))

    def test_jit_indexing(self, device="mps"):
        def fn1(x):
            x[x < 50] = 1.0
            return x

        def fn2(x):
            x[0:50] = 1.0
            return x

        scripted_fn1 = torch.jit.script(fn1)
        scripted_fn2 = torch.jit.script(fn2)
        data = torch.arange(100, device=device, dtype=torch.float)
        out = scripted_fn1(data.detach().clone())
        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
        self.assertEqual(out, ref)
        out = scripted_fn2(data.detach().clone())
        self.assertEqual(out, ref)

    def test_int_indices(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
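
    # Editorial sketch (not part of the original suite): under integer-array
    # indexing the indexed dimension is replaced by the index's shape, the rule
    # behind the shape assertions in test_int_indices above: a (2, 2) index
    # into dim 1 of a (5, 7, 3) tensor yields (5, 2, 2, 3).
    def _demo_int_index_shape_rule(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        idx = torch.tensor([[0, 1], [4, 3]], device=device)
        assert v[:, idx].shape == (5, 2, 2, 3)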

    def test_index_put_src_datatype(self):
        def helper(device, dtype):
            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
            vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
            indices = (torch.tensor([0, 2, 1]),)
            res = src.index_put_(indices, vals, accumulate=True)
            self.assertEqual(res.shape, src.shape)
        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_index_src_datatype(self):
        def helper(device, dtype):
            orig_dtype = dtype
            if dtype is torch.bool:
                dtype = torch.uint8

            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
            if orig_dtype is torch.bool:
                src = src == 1
            # test index
            res = src[[0, 2, 1], :, :]
            self.assertEqual(res.shape, src.shape)
            # test index_put, no accum
            src[[0, 2, 1], :, :] = res
            self.assertEqual(res.shape, src.shape)
        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]

    def test_int_indices2d(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]], device=device)
        columns = torch.tensor([[0, 2], [0, 2]], device=device)
        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])

    def test_int_indices_broadcast(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([0, 3], device=device)
        columns = torch.tensor([0, 2], device=device)
        result = x[rows[:, None], columns]
        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
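
    # Editorial sketch (not part of the original suite): index tensors broadcast
    # against each other, so a (2, 1) row index combined with a (2,) column
    # index selects the outer product of positions, as in the NumPy-derived
    # broadcast test above.
    def _demo_index_broadcasting(self, device="mps"):
        x = torch.arange(12, device=device).view(4, 3)
        rows = torch.tensor([0, 3], device=device)
        cols = torch.tensor([0, 2], device=device)
        assert x[rows[:, None], cols].tolist() == [[0, 2], [9, 11]]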

    def test_empty_index(self, device="mps"):
        x = torch.arange(0, 12, device=device).view(4, 3)
        idx = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(x[idx].numel(), 0)

        # empty assignment should have no effect but not throw an exception
        y = x.clone()
        y[idx] = -1
        self.assertEqual(x, y)

        mask = torch.zeros(4, 3, device=device).bool()
        y[mask] = -1
        self.assertEqual(x, y)

    def test_empty_ndim_index(self, device="mps"):
        x = torch.randn(5, device=device)
        self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])

        x = torch.randn(2, 3, 4, 5, device=device)
        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
                         x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])

        x = torch.empty(10, 0, device=device)
        self.assertEqual(x[[1, 2]].shape, (2, 0))
        self.assertEqual(x[[], []].shape, (0,))
        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
            x[:, [0, 1]]

    def test_empty_ndim_index_bool(self, device="mps"):
        x = torch.randn(5, device=device)
        self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])

    def test_empty_slice(self, device="mps"):
        x = torch.randn(2, 3, 4, 5, device=device)
        y = x[:, :, :, 1]
        z = y[:, 1:1, :]
        self.assertEqual((2, 0, 4), z.shape)
        # this isn't technically necessary, but matches NumPy stride calculations.
        self.assertEqual((60, 20, 5), z.stride())
        self.assertTrue(z.is_contiguous())

    def test_index_getitem_copy_bools_slices(self, device="mps"):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]

        for a in tensors:
            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[False])
            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[false])
            self.assertEqual(a.data_ptr(), a[None].data_ptr())
            self.assertEqual(a.data_ptr(), a[...].data_ptr())
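
    # Editorial sketch (not part of the original suite): indexing with a Python
    # bool prepends a dimension of size 1 (True) or 0 (False), and unlike None
    # or Ellipsis it materializes a copy rather than a view; that is the
    # distinction the data_ptr assertions above draw.
    def _demo_scalar_bool_index_dims(self, device="mps"):
        a = torch.randn(2, 3, device=device)
        assert a[True].shape == (1, 2, 3)
        assert a[False].shape == (0, 2, 3)
        assert a[None].shape == (1, 2, 3)  # same shape as a[True], but a view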

    def test_index_setitem_bools_slices(self, device="mps"):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

        for a in tensors:
            # prefix with 1, 1 to ensure we are compatible with NumPy, which cuts off prefix 1s
            # (some of these ops already prefix a 1 to the size)
            neg_ones = torch.ones_like(a) * -1
            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
            a[True] = neg_ones_expanded
            self.assertEqual(a, neg_ones)
            a[False] = 5
            self.assertEqual(a, neg_ones)
            a[true] = neg_ones_expanded * 2
            self.assertEqual(a, neg_ones * 2)
            a[false] = 5
            self.assertEqual(a, neg_ones * 2)
            a[None] = neg_ones_expanded * 3
            self.assertEqual(a, neg_ones * 3)
            a[...] = neg_ones_expanded * 4
            self.assertEqual(a, neg_ones * 4)
            if a.dim() == 0:
                with self.assertRaises(IndexError):
                    a[:] = neg_ones_expanded * 5

    def test_index_scalar_with_bool_mask(self, device="mps"):
        a = torch.tensor(1, device=device)
        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
        self.assertEqual(a[uintMask], a[boolMask])
        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)

        a = torch.tensor(True, dtype=torch.bool, device=device)
        self.assertEqual(a[uintMask], a[boolMask])
        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)

    def test_setitem_expansion_error(self, device="mps"):
        true = torch.tensor(True, device=device)
        a = torch.randn(2, 3, device=device)
        # check prefix with non-1s doesn't work
        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
        # NumPy: ValueError
        with self.assertRaises(RuntimeError):
            a[True] = a_expanded
        with self.assertRaises(RuntimeError):
            a[true] = a_expanded

    def test_getitem_scalars(self, device="mps"):
        zero = torch.tensor(0, dtype=torch.int64, device=device)
        one = torch.tensor(1, dtype=torch.int64, device=device)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        self.assertEqual(a[0], a[zero])
        self.assertEqual(a[0][1], a[zero][one])
        self.assertEqual(a[0, 1], a[zero, one])
        self.assertEqual(a[0, one], a[zero, 1])

        # indexing by a scalar should slice (not copy)
        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())

        # scalar indexed with scalar
        r = torch.randn((), device=device)
        with self.assertRaises(IndexError):
            r[:]
        with self.assertRaises(IndexError):
            r[zero]
        self.assertEqual(r, r[...])
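
    # Editorial sketch (not part of the original suite): indexing with a 0-d
    # integer tensor behaves like indexing with a Python int and returns a view
    # into the same storage, hence the data_ptr comparisons above.
    def _demo_scalar_tensor_index_is_view(self, device="mps"):
        a = torch.randn(2, 3, device=device)
        one = torch.tensor(1, device=device)
        a[one].fill_(0)  # writes through the view
        assert a[1].abs().sum().item() == 0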

    def test_setitem_scalars(self, device="mps"):
        zero = torch.tensor(0, dtype=torch.int64)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        a_set_with_number = a.clone()
        a_set_with_scalar = a.clone()
        b = torch.randn(3, device=device)

        a_set_with_number[0] = b
        a_set_with_scalar[zero] = b
        self.assertEqual(a_set_with_number, a_set_with_scalar)
        a[1, zero] = 7.7
        self.assertEqual(7.7, a[1, 0])

        # scalar indexed with scalars
        r = torch.randn((), device=device)
        with self.assertRaises(IndexError):
            r[:] = 8.8
        with self.assertRaises(IndexError):
            r[zero] = 8.8
        r[...] = 9.9
        self.assertEqual(9.9, r)

    def test_basic_advanced_combined(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])

        # Check that it is a copy
        unmodified = x.clone()
        x[1:2, [1, 2]].zero_()
        self.assertEqual(x, unmodified)

        # But assignment should modify the original
        unmodified = x.clone()
        x[1:2, [1, 2]] = 0
        self.assertNotEqual(x, unmodified)

    def test_int_assignment(self, device="mps"):
        x = torch.arange(0, 4, device=device).view(2, 2)
        x[1] = 5
        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])

        x = torch.arange(0, 4, device=device).view(2, 2)
        x[1] = torch.arange(5, 7, device=device)
        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])

    def test_byte_tensor_assignment(self, device="mps"):
        x = torch.arange(0., 16, device=device).view(4, 4)
        b = torch.ByteTensor([True, False, True, False]).to(device)
        value = torch.tensor([3., 4., 5., 6.], device=device)

        with warnings.catch_warnings(record=True) as w:
            x[b] = value
            self.assertEqual(len(w), 1)

        self.assertEqual(x[0], value)
        self.assertEqual(x[1], torch.arange(4., 8, device=device))
        self.assertEqual(x[2], value)
        self.assertEqual(x[3], torch.arange(12., 16, device=device))
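
    # Editorial sketch (not part of the original suite): reading through an
    # advanced index materializes a copy, while assigning through the same
    # index writes into the original tensor; test_basic_advanced_combined
    # above checks both directions.
    def _demo_advanced_index_copy_vs_write(self, device="mps"):
        x = torch.arange(6, device=device).view(2, 3)
        x[0, [0, 2]].zero_()   # mutates a temporary copy; x is unchanged
        assert x[0].tolist() == [0, 1, 2]
        x[0, [0, 2]] = -1      # assignment writes through
        assert x[0].tolist() == [-1, 1, -1]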

    def test_variable_slicing(self, device="mps"):
        x = torch.arange(0, 16, device=device).view(4, 4)
        indices = torch.IntTensor([0, 1]).to(device)
        i, j = indices
        self.assertEqual(x[i:j], x[0:1])

    def test_ellipsis_tensor(self, device="mps"):
        x = torch.arange(0, 9, device=device).view(3, 3)
        idx = torch.tensor([0, 2], device=device)
        self.assertEqual(x[..., idx].tolist(), [[0, 2],
                                                [3, 5],
                                                [6, 8]])
        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
                                                [6, 7, 8]])

    def test_invalid_index(self, device="mps"):
        x = torch.arange(0, 16, device=device).view(4, 4)
        self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])

    def test_out_of_bound_index(self, device="mps"):
        x = torch.arange(0, 100, device=device).view(2, 5, 10)
        self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
        self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
        self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
                               lambda: x[0, 1, 15])
        self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
                               lambda: x[:, :, 12])

    def test_zero_dim_index(self, device="mps"):
        x = torch.tensor(10, device=device)
        self.assertEqual(x, x.item())

        def runner():
            print(x[0])
            return x[0]

        self.assertRaisesRegex(IndexError, 'invalid index', runner)

    def test_cpu_indices(self, device="mps"):
        idx = torch.tensor([0, 1])
        b = torch.zeros(2, device=device)
        x = torch.ones(10, device=device)
        x[idx] = b  # index_put_
        ref = torch.ones(10, device=device)
        ref[:2] = 0
        self.assertEqual(x, ref, atol=0, rtol=0)
        out = x[idx]  # index
        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
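
    # Editorial sketch (not part of the original suite): index tensors may live
    # on the CPU even when the indexed tensor is on MPS; both the read and the
    # index_put_ write path in test_cpu_indices above rely on that cross-device
    # dispatch.
    def _demo_cpu_index_on_mps(self, device="mps"):
        idx = torch.tensor([0, 1])        # CPU index tensor
        x = torch.ones(4, device=device)  # MPS data tensor
        x[idx] = 0
        assert x.tolist() == [0.0, 0.0, 1.0, 1.0]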

    def test_nextafter(self, device="mps"):
        for dtype in [torch.float16, torch.float32]:
            x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
            y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
            na = torch.nextafter(x, y)
            na_cpu = torch.nextafter(x.cpu(), y.cpu())
            na_ge_x_mps = na.cpu() > x.cpu()
            # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
            na_ge_x_cpu = na_cpu > x.cpu()
            self.assertEqual(na_ge_x_mps, na_ge_x_cpu)


class TestRNNMPS(TestCaseMPS):
    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
        rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
            device="cpu"
        )
        bidirectional_mul = 2 if bidirectional else 1

        if batch_first:
            input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
                             requires_grad=backward)
            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
                             requires_grad=backward)
        else:
            input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
                             requires_grad=backward)
            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
                             requires_grad=backward)

        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))

        rnn = rnn.to(device)
        input = input.to(device)
        hx = hx.to(device)
        cx = cx.to(device)
        output, (hn, cn) = rnn(input, (hx, cx))

        self.assertEqual(cpu_output, output)
        self.assertEqual(cpu_hn, hn)
        self.assertEqual(cpu_cn, cn)

        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
            rnn = rnn.to(device)
            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"

            f = 0
            if output_grad_presented:
                f = f + 3 * output.sum()
            if states_grad_presented:
                f = f + (hx_out * cx_out).sum()

            param_names, params = zip(*rnn.named_parameters())
            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))

            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
            return output, param_grads, input_grad, hx_grad, cx_grad

        if backward:
            grad_cases = [
                dict(output_grad_presented=True, states_grad_presented=True),
                dict(output_grad_presented=False, states_grad_presented=True),
                dict(output_grad_presented=True, states_grad_presented=False),
            ]

            for grad_case in grad_cases:
                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
                    get_backward_results(rnn, device, input, hx, cx, **grad_case)

                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)
                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    LSTM_TEST_CASES = [
        {},  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
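
    # Editorial sketch (not part of the original suite): the forward/backward
    # tests above follow a CPU-reference pattern: build the module and inputs on
    # the CPU, record the CPU result, move everything to MPS, and require the
    # two results to match. A minimal version of that pattern:
    def _demo_cpu_reference_pattern(self, device="mps"):
        lin = nn.Linear(4, 2)
        x = torch.randn(3, 4)
        ref = lin(x)                        # CPU reference result
        out = lin.to(device)(x.to(device))  # same module and data on MPS
        self.assertEqual(ref, out)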

    def test_RNN_cell_no_broadcasting(self):
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        bad_hx = torch.randn(3, 1)
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1)
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        input = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        input = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))


class TestFallbackWarning(TestCase):
    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        out = subprocess.check_output(
            [sys.executable, "-W", "always", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)
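
    # Editorial sketch (not part of the original suite): PYTORCH_ENABLE_MPS_FALLBACK
    # must be set before `import torch`; with it, an op that is unimplemented on
    # MPS runs on the CPU and emits a performance warning instead of raising
    # NotImplementedError. A minimal standalone reproduction, assuming torch.lcm
    # is still unimplemented on MPS:
    #
    #   import os
    #   os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # before importing torch
    #   import torch
    #   torch.lcm(torch.tensor([1], device="mps"), torch.tensor([2], device="mps"))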
warnings.catch_warnings(record=True) as w: 11901*da0073e9SAndroid Build Coastguard Worker {op} 11902*da0073e9SAndroid Build Coastguard Worker 11903*da0073e9SAndroid Build Coastguard Workerif len(w) != 1: 11904*da0073e9SAndroid Build Coastguard Worker print(w) 11905*da0073e9SAndroid Build Coastguard Worker exit(2) 11906*da0073e9SAndroid Build Coastguard Worker""" 11907*da0073e9SAndroid Build Coastguard Worker try: 11908*da0073e9SAndroid Build Coastguard Worker subprocess.check_output( 11909*da0073e9SAndroid Build Coastguard Worker [sys.executable, '-W', 'always', '-c', script], 11910*da0073e9SAndroid Build Coastguard Worker stderr=subprocess.STDOUT, 11911*da0073e9SAndroid Build Coastguard Worker # On Windows, opening the subprocess with the default CWD makes `import torch` 11912*da0073e9SAndroid Build Coastguard Worker # fail, so just set CWD to this script's directory 11913*da0073e9SAndroid Build Coastguard Worker cwd=os.path.dirname(os.path.realpath(__file__)),) 11914*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 11915*da0073e9SAndroid Build Coastguard Worker if e.returncode == 1: 11916*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." + 11917*da0073e9SAndroid Build Coastguard Worker e.output.decode("utf-8")) 11918*da0073e9SAndroid Build Coastguard Worker elif e.returncode == 2: 11919*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False, "There wasn't exactly one warning when running not implemented op with " 11920*da0073e9SAndroid Build Coastguard Worker f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}") 11921*da0073e9SAndroid Build Coastguard Worker else: 11922*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " + 11923*da0073e9SAndroid Build Coastguard Worker e.output.decode("utf-8")) 11924*da0073e9SAndroid Build Coastguard Worker 11925*da0073e9SAndroid Build Coastguard Workerclass TestNoRegression(TestCase): 11926*da0073e9SAndroid Build Coastguard Worker def test_assert_close(self): 11927*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, device="mps") 11928*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1, device="mps") 11929*da0073e9SAndroid Build Coastguard Worker inf = a / b 11930*da0073e9SAndroid Build Coastguard Worker nan = b / b 11931*da0073e9SAndroid Build Coastguard Worker 11932*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"): 11933*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(a, inf) 11934*da0073e9SAndroid Build Coastguard Worker 11935*da0073e9SAndroid Build Coastguard Worker # TODO: The NaN test is failing when all the tests in test_mps are run 11936*da0073e9SAndroid Build Coastguard Worker # together but passes when run separately. There seems to be memory 11937*da0073e9SAndroid Build Coastguard Worker # corruption which needs to be fixed for this test to be enabled. 
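
# For reference, a minimal standalone-script sketch (not executed here) of the CPU-fallback
# workflow that `test_warn_on_not_implemented_with_fallback` exercises: the environment
# variable must be set before the `import torch`, after which an op without an MPS kernel
# (`torch.lcm` at the time of writing) runs on the CPU and emits a single performance warning.
# Kept as a string, mirroring the subprocess-based test above, since torch is already
# imported in this module.
_EXAMPLE_MPS_FALLBACK_SCRIPT = """
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # must precede the torch import
import torch

x = torch.tensor([4], device="mps")
y = torch.tensor([6], device="mps")
print(torch.lcm(x, y))  # falls back to CPU and warns once about performance
"""
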
class TestNoRegression(TestCase):
    def test_assert_close(self):
        a = torch.ones(1, device="mps")
        b = torch.zeros(1, device="mps")
        inf = a / b
        nan = b / b

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(a, inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan)

    def test_double_error(self):
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = torch.ones(2, dtype=torch.float64, device="mps")

        a = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = a.double()

    def test_legacy_constructor(self):
        a = torch.ones(2, device="mps")

        b = a.new(1)

    def test_serialization_map_location(self):

        # Ensures that cpu Tensor can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2)
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f)

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on cpu
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="cpu")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "cpu")

        # Ensures that `mps:0` Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps:0")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps:0")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")


# Dtypes exercised by the consistency tests below: everything from get_all_dtypes()
# minus the ones the MPS backend does not (fully) support
MPS_DTYPES = get_all_dtypes()
for t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]:
    del MPS_DTYPES[MPS_DTYPES.index(t)]

MPS_GRAD_DTYPES = [torch.float32, torch.float16]


class TestConsistency(TestCaseMPS):
    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    # Returns an (atol, rtol) override for the given op/dtype combination;
    # (None, None) means the default tolerances of assertEqual are used.
    def _compute_tolerances(self, op, dtype):
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # The result of pow(9, 8) is showing 43046716, whereas it should've been 43046721.
            # fixed in macOS 13.3+
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            forward_failed = False
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            if forward_failed:
                # We would've failed immediately anyway, but this error is clearer
                # We error instead of continuing so that all_backward_pass would not be True
                raise RuntimeError("Forward pass already failed")

            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)


class TestErrorInputs(TestCase):
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        mps_samples = op.error_inputs(device, set_seed=False)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)
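
# Minimal sketch of the CPU-vs-MPS parity pattern that `TestConsistency` applies to every
# OpInfo sample above, shown here as a plain helper for readability. It is not invoked by
# the test suite; `torch.square` and the shapes are arbitrary illustrative choices.
def _example_cpu_mps_parity_check():
    cpu_x = torch.randn(3, 4, requires_grad=True)
    mps_x = cpu_x.detach().to("mps").requires_grad_(True)

    # Forward check: the op must produce (numerically) matching results on both devices
    cpu_out = torch.square(cpu_x)
    mps_out = torch.square(mps_x)
    torch.testing.assert_close(cpu_out, mps_out.cpu())

    # Backward check: feed the same random cotangent to both graphs and compare gradients
    grad_out = torch.rand_like(cpu_out)
    (cpu_grad,) = torch.autograd.grad(cpu_out, cpu_x, grad_outputs=grad_out)
    (mps_grad,) = torch.autograd.grad(mps_out, mps_x, grad_outputs=grad_out.to("mps"))
    torch.testing.assert_close(cpu_grad, mps_grad.cpu())
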
class TestComplex(TestCase):
    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))
        # Add scalars
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all permutations of types (int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(inputs, inputs):
            for op_name in ["__add__", "__sub__", "__mul__"]:
                x_cpu, y_cpu = map(to_cpu, (x, y))
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")


# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Assure no opinfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ('where',)

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        inputs = (
            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
            else op.sample_inputs(device, dtype, set_seed=False)
        )
        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)
        if dtype not in MPS_DTYPES + ([torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]):
            with self.assertRaises(TypeError):
                ones(device)
        else:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)


# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
instantiate_parametrized_tests(TestMPS)

if __name__ == "__main__":
    run_tests()
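
# Typical invocations (for reference):
#   python test_mps.py                                          # run the full suite
#   python test_mps.py TestConsistencyCPU                       # just the CPU-vs-MPS consistency checks
#   EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU   # regenerate the allowlist (accept mode)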