xref: /aosp_15_r20/external/pytorch/test/test_mps.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: mps"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport platform
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport math
7*da0073e9SAndroid Build Coastguard Workerimport random
8*da0073e9SAndroid Build Coastguard Workerimport unittest
9*da0073e9SAndroid Build Coastguard Workerimport warnings
10*da0073e9SAndroid Build Coastguard Workerimport subprocess
11*da0073e9SAndroid Build Coastguard Workerimport tempfile
12*da0073e9SAndroid Build Coastguard Workerimport os
13*da0073e9SAndroid Build Coastguard Workerimport copy
14*da0073e9SAndroid Build Coastguard Workerimport gc
15*da0073e9SAndroid Build Coastguard Workerimport threading
16*da0073e9SAndroid Build Coastguard Workerimport torch
17*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
18*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
19*da0073e9SAndroid Build Coastguard Workerimport itertools
20*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
21*da0073e9SAndroid Build Coastguard Workerfrom torch import inf
22*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import Buffer, Parameter
23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import opinfo
24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import \
25*da0073e9SAndroid Build Coastguard Worker    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
26*da0073e9SAndroid Build Coastguard Worker     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor
28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import get_all_dtypes, integral_types
29*da0073e9SAndroid Build Coastguard Workerimport torch.backends.mps
30*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import Uniform, Exponential
31*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
34*da0073e9SAndroid Build Coastguard Worker    op_db,
35*da0073e9SAndroid Build Coastguard Worker    DecorateInfo,
36*da0073e9SAndroid Build Coastguard Worker    UnaryUfuncInfo,
37*da0073e9SAndroid Build Coastguard Worker    ReductionOpInfo,
38*da0073e9SAndroid Build Coastguard Worker    SpectralFuncInfo,
39*da0073e9SAndroid Build Coastguard Worker    BinaryUfuncInfo,
40*da0073e9SAndroid Build Coastguard Worker)
41*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
42*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import NNTestCase
43*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
44*da0073e9SAndroid Build Coastguard Workerimport numpy as np
45*da0073e9SAndroid Build Coastguard Workerimport torch
46*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree
47*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
48*da0073e9SAndroid Build Coastguard Workerimport operator
49*da0073e9SAndroid Build Coastguard Worker
# Independent copies of the OpInfo database so consistency and error-input
# tests can attach their own decorators without mutating the shared op_db.
test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`:
# keep only ops that have a NumPy reference and are not covered by the
# specialized ufunc/reduction/spectral test classes.
_ref_test_ops = tuple(
    op
    for op in op_db
    if op.ref is not None
    and not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo))
)
63*da0073e9SAndroid Build Coastguard Worker
def xfailIf(condition):
    """Return a decorator marking a test as an expected failure iff *condition* is truthy."""
    def wrapper(func):
        return unittest.expectedFailure(func) if condition else func
    return wrapper
71*da0073e9SAndroid Build Coastguard Worker
def xfailIfMacOS14_4Plus(func):
    # Mark *func* as an expected failure on macOS 14.4 and newer.
    # `product_version` is a module-level float defined elsewhere in this file.
    if product_version > 14.3:  # noqa: F821
        return unittest.expectedFailure(func)
    return func
74*da0073e9SAndroid Build Coastguard Worker
def mps_ops_grad_modifier(ops):
    """Decorate gradient tests for ops known to be broken on the MPS backend.

    Iterates over *ops* (OpInfo entries), appends xfail/skip ``DecorateInfo``
    decorators based on the tables below, and yields each (possibly modified)
    op. Table keys are ``op.name + op.variant_test_name``; values are the list
    of dtypes to mark, or ``None`` to mark the op for every dtype.

    NOTE(review): ``product_version`` used below is a module-level float
    (macOS version, e.g. 13.3) presumably defined elsewhere in this file —
    not visible in this chunk; confirm against the full source.
    """
    XFAILLIST_GRAD = {

        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    # Applies only on macOS older than 13 (see version check in the loop below).
    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],

    }

    # Applies on macOS 13.0 - 13.2 (13.x but before 13.3).
    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode, forward pass success as fallback to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    # Hard skips (unittest.skip) rather than xfails — these crash or assert.
    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    # Applies on macOS 13.3 and newer.
    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # Forward pass is passing since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    # Unconditional xfails, regardless of macOS version.
    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        # OpInfo.decorators may be a tuple or None; normalize to a mutable list
        # before appending so we never mutate a shared tuple.
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        # Key matches the "<name><variant>" convention used by the tables above.
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.skip,
                         dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=ON_MPS_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Workerdef mps_ops_modifier(ops):
242*da0073e9SAndroid Build Coastguard Worker    # Supported complex OPS
243*da0073e9SAndroid Build Coastguard Worker    SUPPORTED_COMPLEX_OPS = {
244*da0073e9SAndroid Build Coastguard Worker        '__radd__',
245*da0073e9SAndroid Build Coastguard Worker        '__rmul__',
246*da0073e9SAndroid Build Coastguard Worker        '__getitem__',
247*da0073e9SAndroid Build Coastguard Worker        'abs',
248*da0073e9SAndroid Build Coastguard Worker        'add',
249*da0073e9SAndroid Build Coastguard Worker        'alias_copy',
250*da0073e9SAndroid Build Coastguard Worker        'argwhere',
251*da0073e9SAndroid Build Coastguard Worker        'atleast_1d',
252*da0073e9SAndroid Build Coastguard Worker        'atleast_2d',
253*da0073e9SAndroid Build Coastguard Worker        'atleast_3d',
254*da0073e9SAndroid Build Coastguard Worker        'as_strided',
255*da0073e9SAndroid Build Coastguard Worker        'as_strided_copy',
256*da0073e9SAndroid Build Coastguard Worker        'as_strided_scatter',
257*da0073e9SAndroid Build Coastguard Worker        'broadcast_tensors',
258*da0073e9SAndroid Build Coastguard Worker        'broadcast_to',
259*da0073e9SAndroid Build Coastguard Worker        'chalf',
260*da0073e9SAndroid Build Coastguard Worker        'cfloat',
261*da0073e9SAndroid Build Coastguard Worker        'chunk',
262*da0073e9SAndroid Build Coastguard Worker        'clone',
263*da0073e9SAndroid Build Coastguard Worker        'conj',
264*da0073e9SAndroid Build Coastguard Worker        'conj_physical',
265*da0073e9SAndroid Build Coastguard Worker        'contiguous',
266*da0073e9SAndroid Build Coastguard Worker        'diag',
267*da0073e9SAndroid Build Coastguard Worker        'diag_embed',
268*da0073e9SAndroid Build Coastguard Worker        'diagflat',
269*da0073e9SAndroid Build Coastguard Worker        'diagonal',
270*da0073e9SAndroid Build Coastguard Worker        'diagonal_copy',
271*da0073e9SAndroid Build Coastguard Worker        'diagonal_scatter',
272*da0073e9SAndroid Build Coastguard Worker        'dsplit',
273*da0073e9SAndroid Build Coastguard Worker        'empty',
274*da0073e9SAndroid Build Coastguard Worker        'empty_permuted',
275*da0073e9SAndroid Build Coastguard Worker        'empty_strided',
276*da0073e9SAndroid Build Coastguard Worker        'eye',
277*da0073e9SAndroid Build Coastguard Worker        'exp',
278*da0073e9SAndroid Build Coastguard Worker        'expand',
279*da0073e9SAndroid Build Coastguard Worker        'expand_as',
280*da0073e9SAndroid Build Coastguard Worker        'expand_copy',
281*da0073e9SAndroid Build Coastguard Worker        'flatten',
282*da0073e9SAndroid Build Coastguard Worker        'fill',
283*da0073e9SAndroid Build Coastguard Worker        'full',
284*da0073e9SAndroid Build Coastguard Worker        'H',
285*da0073e9SAndroid Build Coastguard Worker        'hsplit',
286*da0073e9SAndroid Build Coastguard Worker        'imag',
287*da0073e9SAndroid Build Coastguard Worker        'index_select',
288*da0073e9SAndroid Build Coastguard Worker        'isfinite',
289*da0073e9SAndroid Build Coastguard Worker        'isinf',
290*da0073e9SAndroid Build Coastguard Worker        'isreal',
291*da0073e9SAndroid Build Coastguard Worker        'item',
292*da0073e9SAndroid Build Coastguard Worker        'kron',
293*da0073e9SAndroid Build Coastguard Worker        'linalg.diagonal',
294*da0073e9SAndroid Build Coastguard Worker        'linalg.svd',
295*da0073e9SAndroid Build Coastguard Worker        'linspace',
296*da0073e9SAndroid Build Coastguard Worker        'logspace',
297*da0073e9SAndroid Build Coastguard Worker        'linspacetensor_overload',
298*da0073e9SAndroid Build Coastguard Worker        'logspacetensor_overload',
299*da0073e9SAndroid Build Coastguard Worker        'mH',
300*da0073e9SAndroid Build Coastguard Worker        'mT',
301*da0073e9SAndroid Build Coastguard Worker        'masked_scatter',
302*da0073e9SAndroid Build Coastguard Worker        'masked_select',
303*da0073e9SAndroid Build Coastguard Worker        'meshgridlist_of_tensors',
304*da0073e9SAndroid Build Coastguard Worker        'meshgridvariadic_tensors',
305*da0073e9SAndroid Build Coastguard Worker        'movedim',
306*da0073e9SAndroid Build Coastguard Worker        'mul',
307*da0073e9SAndroid Build Coastguard Worker        'narrow',
308*da0073e9SAndroid Build Coastguard Worker        'narrow_copy',
309*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv1d',
310*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv2d',
311*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose1d',
312*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose2d',
313*da0073e9SAndroid Build Coastguard Worker        'nn.functional.feature_alpha_dropoutwithout_train',
314*da0073e9SAndroid Build Coastguard Worker        'nn.functional.padcircular',
315*da0073e9SAndroid Build Coastguard Worker        'nn.functional.tanhshrink',
316*da0073e9SAndroid Build Coastguard Worker        'nn.functional.unfold',
317*da0073e9SAndroid Build Coastguard Worker        'nonzero',
318*da0073e9SAndroid Build Coastguard Worker        'ones',
319*da0073e9SAndroid Build Coastguard Worker        'outer',
320*da0073e9SAndroid Build Coastguard Worker        'permute',
321*da0073e9SAndroid Build Coastguard Worker        'positive',
322*da0073e9SAndroid Build Coastguard Worker        'randn',
323*da0073e9SAndroid Build Coastguard Worker        'ravel',
324*da0073e9SAndroid Build Coastguard Worker        'real',
325*da0073e9SAndroid Build Coastguard Worker        'repeat_interleave',
326*da0073e9SAndroid Build Coastguard Worker        'reshape_as',
327*da0073e9SAndroid Build Coastguard Worker        'reshape',
328*da0073e9SAndroid Build Coastguard Worker        'resolve_conj',
329*da0073e9SAndroid Build Coastguard Worker        'resolve_neg',
330*da0073e9SAndroid Build Coastguard Worker        'scalar_tensor',
331*da0073e9SAndroid Build Coastguard Worker        'select',
332*da0073e9SAndroid Build Coastguard Worker        'sgn',
333*da0073e9SAndroid Build Coastguard Worker        'slice',
334*da0073e9SAndroid Build Coastguard Worker        'split',
335*da0073e9SAndroid Build Coastguard Worker        'split_with_sizes',
336*da0073e9SAndroid Build Coastguard Worker        'split_with_sizes_copy',
337*da0073e9SAndroid Build Coastguard Worker        'splitlist_args',
338*da0073e9SAndroid Build Coastguard Worker        'squeeze',
339*da0073e9SAndroid Build Coastguard Worker        'squeezemultiple',
340*da0073e9SAndroid Build Coastguard Worker        'sub',
341*da0073e9SAndroid Build Coastguard Worker        'svd',
342*da0073e9SAndroid Build Coastguard Worker        't',
343*da0073e9SAndroid Build Coastguard Worker        't_copy',
344*da0073e9SAndroid Build Coastguard Worker        'tanh',
345*da0073e9SAndroid Build Coastguard Worker        'tensor_split',
346*da0073e9SAndroid Build Coastguard Worker        'transpose',
347*da0073e9SAndroid Build Coastguard Worker        'T',
348*da0073e9SAndroid Build Coastguard Worker        'unbind',
349*da0073e9SAndroid Build Coastguard Worker        'unflatten',
350*da0073e9SAndroid Build Coastguard Worker        'unfold',
351*da0073e9SAndroid Build Coastguard Worker        'unfold_copy',
352*da0073e9SAndroid Build Coastguard Worker        'unsafe_chunk',
353*da0073e9SAndroid Build Coastguard Worker        'unsafe_split',
354*da0073e9SAndroid Build Coastguard Worker        'unsqueeze',
355*da0073e9SAndroid Build Coastguard Worker        'unsqueeze_copy',
356*da0073e9SAndroid Build Coastguard Worker        'view_as',
357*da0073e9SAndroid Build Coastguard Worker        'view_as_real',
358*da0073e9SAndroid Build Coastguard Worker        'view',
359*da0073e9SAndroid Build Coastguard Worker        'view_copy',
360*da0073e9SAndroid Build Coastguard Worker        'vsplit',
361*da0073e9SAndroid Build Coastguard Worker        'zero_',
362*da0073e9SAndroid Build Coastguard Worker        'zeros',
363*da0073e9SAndroid Build Coastguard Worker    }
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
366*da0073e9SAndroid Build Coastguard Worker        '__rdiv__',
367*da0073e9SAndroid Build Coastguard Worker        '__rmatmul__',
368*da0073e9SAndroid Build Coastguard Worker        '_chunk_cat',
369*da0073e9SAndroid Build Coastguard Worker        '_unsafe_masked_index',
370*da0073e9SAndroid Build Coastguard Worker        'acos',
371*da0073e9SAndroid Build Coastguard Worker        'acosh',
372*da0073e9SAndroid Build Coastguard Worker        'all',
373*da0073e9SAndroid Build Coastguard Worker        'allclose',
374*da0073e9SAndroid Build Coastguard Worker        'any',
375*da0073e9SAndroid Build Coastguard Worker        'addcdiv',
376*da0073e9SAndroid Build Coastguard Worker        'addcmul',
377*da0073e9SAndroid Build Coastguard Worker        'addmmdecomposed',
378*da0073e9SAndroid Build Coastguard Worker        'addmv',
379*da0073e9SAndroid Build Coastguard Worker        'asin',
380*da0073e9SAndroid Build Coastguard Worker        'atan',
381*da0073e9SAndroid Build Coastguard Worker        'atanh',
382*da0073e9SAndroid Build Coastguard Worker        'bfloat16',
383*da0073e9SAndroid Build Coastguard Worker        'bmm',
384*da0073e9SAndroid Build Coastguard Worker        'bool',
385*da0073e9SAndroid Build Coastguard Worker        'cartesian_prod',
386*da0073e9SAndroid Build Coastguard Worker        'cat',
387*da0073e9SAndroid Build Coastguard Worker        'char',
388*da0073e9SAndroid Build Coastguard Worker        'column_stack',
389*da0073e9SAndroid Build Coastguard Worker        'combinations',
390*da0073e9SAndroid Build Coastguard Worker        'corrcoef',
391*da0073e9SAndroid Build Coastguard Worker        'constant_pad_nd',
392*da0073e9SAndroid Build Coastguard Worker        'cos',
393*da0073e9SAndroid Build Coastguard Worker        'cosh',
394*da0073e9SAndroid Build Coastguard Worker        'count_nonzero',
395*da0073e9SAndroid Build Coastguard Worker        'diff',
396*da0073e9SAndroid Build Coastguard Worker        'div',
397*da0073e9SAndroid Build Coastguard Worker        'divno_rounding_mode',
398*da0073e9SAndroid Build Coastguard Worker        'dot',
399*da0073e9SAndroid Build Coastguard Worker        'dstack',
400*da0073e9SAndroid Build Coastguard Worker        'einsum',
401*da0073e9SAndroid Build Coastguard Worker        'eq',
402*da0073e9SAndroid Build Coastguard Worker        'equal',
403*da0073e9SAndroid Build Coastguard Worker        'exp2',
404*da0073e9SAndroid Build Coastguard Worker        'expm1',
405*da0073e9SAndroid Build Coastguard Worker        'fft.fft',
406*da0073e9SAndroid Build Coastguard Worker        'fft.fft2',
407*da0073e9SAndroid Build Coastguard Worker        'fft.fftn',
408*da0073e9SAndroid Build Coastguard Worker        'fft.fftshift',
409*da0073e9SAndroid Build Coastguard Worker        'fft.ifft',
410*da0073e9SAndroid Build Coastguard Worker        'fft.ifft2',
411*da0073e9SAndroid Build Coastguard Worker        'fft.ifftn',
412*da0073e9SAndroid Build Coastguard Worker        'fft.ifftshift',
413*da0073e9SAndroid Build Coastguard Worker        'fft.irfftn',
414*da0073e9SAndroid Build Coastguard Worker        'fft.irfft2',
415*da0073e9SAndroid Build Coastguard Worker        'fft.irfft',
416*da0073e9SAndroid Build Coastguard Worker        'fft.hfftn',
417*da0073e9SAndroid Build Coastguard Worker        'fft.hfft2',
418*da0073e9SAndroid Build Coastguard Worker        'fft.hfft',
419*da0073e9SAndroid Build Coastguard Worker        'flip',
420*da0073e9SAndroid Build Coastguard Worker        'fliplr',
421*da0073e9SAndroid Build Coastguard Worker        'flipud',
422*da0073e9SAndroid Build Coastguard Worker        'float',
423*da0073e9SAndroid Build Coastguard Worker        'gradient',
424*da0073e9SAndroid Build Coastguard Worker        'half',
425*da0073e9SAndroid Build Coastguard Worker        'hstack',
426*da0073e9SAndroid Build Coastguard Worker        'inner',
427*da0073e9SAndroid Build Coastguard Worker        'int',
428*da0073e9SAndroid Build Coastguard Worker        'isclose',
429*da0073e9SAndroid Build Coastguard Worker        'isnan',
430*da0073e9SAndroid Build Coastguard Worker        'ldexp',
431*da0073e9SAndroid Build Coastguard Worker        'linalg.multi_dot',
432*da0073e9SAndroid Build Coastguard Worker        'linalg.pinv',
433*da0073e9SAndroid Build Coastguard Worker        'log10',
434*da0073e9SAndroid Build Coastguard Worker        'log1p',
435*da0073e9SAndroid Build Coastguard Worker        'log2',
436*da0073e9SAndroid Build Coastguard Worker        'log',
437*da0073e9SAndroid Build Coastguard Worker        'logical_and',
438*da0073e9SAndroid Build Coastguard Worker        'logical_not',
439*da0073e9SAndroid Build Coastguard Worker        'logical_or',
440*da0073e9SAndroid Build Coastguard Worker        'logical_xor',
441*da0073e9SAndroid Build Coastguard Worker        'logsumexp',
442*da0073e9SAndroid Build Coastguard Worker        'long',
443*da0073e9SAndroid Build Coastguard Worker        'masked_fill',
444*da0073e9SAndroid Build Coastguard Worker        'masked.mean',
445*da0073e9SAndroid Build Coastguard Worker        'masked.prod',
446*da0073e9SAndroid Build Coastguard Worker        'masked.std',
447*da0073e9SAndroid Build Coastguard Worker        'masked.sum',
448*da0073e9SAndroid Build Coastguard Worker        'masked.var',
449*da0073e9SAndroid Build Coastguard Worker        'masked.logsumexp',
450*da0073e9SAndroid Build Coastguard Worker        'matmul',
451*da0073e9SAndroid Build Coastguard Worker        'mean',
452*da0073e9SAndroid Build Coastguard Worker        'mm',
453*da0073e9SAndroid Build Coastguard Worker        'mv',
454*da0073e9SAndroid Build Coastguard Worker        'ne',
455*da0073e9SAndroid Build Coastguard Worker        'neg',
456*da0073e9SAndroid Build Coastguard Worker        'nn.functional.padconstant',
457*da0073e9SAndroid Build Coastguard Worker        'nn.functional.padreflect',
458*da0073e9SAndroid Build Coastguard Worker        'nn.functional.padreplicate',
459*da0073e9SAndroid Build Coastguard Worker        'nn.functional.pixel_shuffle',
460*da0073e9SAndroid Build Coastguard Worker        'nn.functional.pixel_unshuffle',
461*da0073e9SAndroid Build Coastguard Worker        'nn.functional.rms_norm',
462*da0073e9SAndroid Build Coastguard Worker        'nn.functional.softsign',
463*da0073e9SAndroid Build Coastguard Worker        'pinverse',
464*da0073e9SAndroid Build Coastguard Worker        'prod',
465*da0073e9SAndroid Build Coastguard Worker        'reciprocal',
466*da0073e9SAndroid Build Coastguard Worker        'roll',
467*da0073e9SAndroid Build Coastguard Worker        'rot90',
468*da0073e9SAndroid Build Coastguard Worker        'rsqrt',
469*da0073e9SAndroid Build Coastguard Worker        'short',
470*da0073e9SAndroid Build Coastguard Worker        'sigmoid',
471*da0073e9SAndroid Build Coastguard Worker        'sin',
472*da0073e9SAndroid Build Coastguard Worker        'sinh',
473*da0073e9SAndroid Build Coastguard Worker        'sqrt',
474*da0073e9SAndroid Build Coastguard Worker        'square',
475*da0073e9SAndroid Build Coastguard Worker        'stack',
476*da0073e9SAndroid Build Coastguard Worker        'stft',
477*da0073e9SAndroid Build Coastguard Worker        'sum',
478*da0073e9SAndroid Build Coastguard Worker        'sum_to_size',
479*da0073e9SAndroid Build Coastguard Worker        'tan',
480*da0073e9SAndroid Build Coastguard Worker        'tensordot',
481*da0073e9SAndroid Build Coastguard Worker        'trace',
482*da0073e9SAndroid Build Coastguard Worker        'trapz',
483*da0073e9SAndroid Build Coastguard Worker        'trapezoid',
484*da0073e9SAndroid Build Coastguard Worker        'tril',
485*da0073e9SAndroid Build Coastguard Worker        'triu',
486*da0073e9SAndroid Build Coastguard Worker        'true_divide',
487*da0073e9SAndroid Build Coastguard Worker        'vstack',
488*da0073e9SAndroid Build Coastguard Worker        'where',
489*da0073e9SAndroid Build Coastguard Worker        'byte',
490*da0073e9SAndroid Build Coastguard Worker    }
491*da0073e9SAndroid Build Coastguard Worker    # Those ops worked on MacOS12, but broken on MacOS13, see https://github.com/pytorch/pytorch/issues/85758
492*da0073e9SAndroid Build Coastguard Worker    MACOS_12_3_XFAILLIST = {
493*da0073e9SAndroid Build Coastguard Worker        # Top 60
494*da0073e9SAndroid Build Coastguard Worker        # expected failures
495*da0073e9SAndroid Build Coastguard Worker        # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
496*da0073e9SAndroid Build Coastguard Worker        # fixed in macOS 13.3. Currently error is not raised.
497*da0073e9SAndroid Build Coastguard Worker        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
498*da0073e9SAndroid Build Coastguard Worker        # expected failures
499*da0073e9SAndroid Build Coastguard Worker        '__rpow__': [torch.uint8, torch.int8],
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
502*da0073e9SAndroid Build Coastguard Worker        'cdist': [torch.float32],
503*da0073e9SAndroid Build Coastguard Worker        'tan': [torch.uint8, torch.float32],
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker        # Data type support starts from macOS 13
506*da0073e9SAndroid Build Coastguard Worker        'nn.functional.avg_pool1d': [torch.int64],
507*da0073e9SAndroid Build Coastguard Worker        'nn.functional.avg_pool2d': [torch.int64],
508*da0073e9SAndroid Build Coastguard Worker        'nn.functional.local_response_norm': [torch.int64],
509*da0073e9SAndroid Build Coastguard Worker        '__radd__': [torch.uint8],
510*da0073e9SAndroid Build Coastguard Worker        '__rdiv__': [torch.uint8],
511*da0073e9SAndroid Build Coastguard Worker        '__rmul__': [torch.uint8],
512*da0073e9SAndroid Build Coastguard Worker        'abs': [torch.uint8],
513*da0073e9SAndroid Build Coastguard Worker        'acos': [torch.uint8],
514*da0073e9SAndroid Build Coastguard Worker        'acosh': [torch.uint8],
515*da0073e9SAndroid Build Coastguard Worker        'add': [torch.uint8],
516*da0073e9SAndroid Build Coastguard Worker        'asin': [torch.uint8],
517*da0073e9SAndroid Build Coastguard Worker        'asinh': [torch.uint8],
518*da0073e9SAndroid Build Coastguard Worker        'atan': [torch.uint8],
519*da0073e9SAndroid Build Coastguard Worker        'atanh': [torch.uint8],
520*da0073e9SAndroid Build Coastguard Worker        'ceil': [torch.uint8],
521*da0073e9SAndroid Build Coastguard Worker        'corrcoef': [torch.uint8],
522*da0073e9SAndroid Build Coastguard Worker        'cos': [torch.uint8],
523*da0073e9SAndroid Build Coastguard Worker        'cosh': [torch.uint8],
524*da0073e9SAndroid Build Coastguard Worker        'cov': [torch.uint8],
525*da0073e9SAndroid Build Coastguard Worker        'cumulative_trapezoid': [torch.uint8],
526*da0073e9SAndroid Build Coastguard Worker        'deg2rad': [torch.uint8],
527*da0073e9SAndroid Build Coastguard Worker        'diff': [torch.uint8],
528*da0073e9SAndroid Build Coastguard Worker        'eq': [torch.uint8],
529*da0073e9SAndroid Build Coastguard Worker        'equal': [torch.uint8],
530*da0073e9SAndroid Build Coastguard Worker        'erf': [torch.uint8],
531*da0073e9SAndroid Build Coastguard Worker        'exp2': [torch.uint8],
532*da0073e9SAndroid Build Coastguard Worker        'exp': [torch.uint8],
533*da0073e9SAndroid Build Coastguard Worker        'expm1': [torch.uint8],
534*da0073e9SAndroid Build Coastguard Worker        'floor': [torch.uint8],
535*da0073e9SAndroid Build Coastguard Worker        'fmax': [torch.uint8],
536*da0073e9SAndroid Build Coastguard Worker        'fmin': [torch.uint8],
537*da0073e9SAndroid Build Coastguard Worker        'fmod': [torch.uint8],
538*da0073e9SAndroid Build Coastguard Worker        'ge': [torch.uint8],
539*da0073e9SAndroid Build Coastguard Worker        'gt': [torch.uint8],
540*da0073e9SAndroid Build Coastguard Worker        'isclose': [torch.uint8],
541*da0073e9SAndroid Build Coastguard Worker        'isnan': [torch.uint8],
542*da0073e9SAndroid Build Coastguard Worker        'kron': [torch.uint8],
543*da0073e9SAndroid Build Coastguard Worker        'le': [torch.uint8],
544*da0073e9SAndroid Build Coastguard Worker        'log10': [torch.uint8],
545*da0073e9SAndroid Build Coastguard Worker        'log1p': [torch.uint8],
546*da0073e9SAndroid Build Coastguard Worker        'log2': [torch.uint8],
547*da0073e9SAndroid Build Coastguard Worker        'log': [torch.uint8],
548*da0073e9SAndroid Build Coastguard Worker        'logical_and': [torch.uint8],
549*da0073e9SAndroid Build Coastguard Worker        'logical_or': [torch.uint8],
550*da0073e9SAndroid Build Coastguard Worker        'logical_xor': [torch.uint8],
551*da0073e9SAndroid Build Coastguard Worker        'logit': [torch.uint8],
552*da0073e9SAndroid Build Coastguard Worker        'lt': [torch.uint8],
553*da0073e9SAndroid Build Coastguard Worker        'masked.mean': [torch.uint8],
554*da0073e9SAndroid Build Coastguard Worker        'masked.std': [torch.uint8],
555*da0073e9SAndroid Build Coastguard Worker        'masked.var': [torch.uint8],
556*da0073e9SAndroid Build Coastguard Worker        'maximum': [torch.uint8],
557*da0073e9SAndroid Build Coastguard Worker        'minimum': [torch.uint8],
558*da0073e9SAndroid Build Coastguard Worker        'mul': [torch.uint8],
559*da0073e9SAndroid Build Coastguard Worker        'ne': [torch.uint8],
560*da0073e9SAndroid Build Coastguard Worker        'neg': [torch.uint8],
561*da0073e9SAndroid Build Coastguard Worker        'nn.functional.cosine_embedding_loss': [torch.uint8],
562*da0073e9SAndroid Build Coastguard Worker        'nn.functional.margin_ranking_loss': [torch.uint8],
563*da0073e9SAndroid Build Coastguard Worker        'nn.functional.poisson_nll_loss': [torch.uint8],
564*da0073e9SAndroid Build Coastguard Worker        'nn.functional.softsign': [torch.uint8],
565*da0073e9SAndroid Build Coastguard Worker        'nn.functional.tanhshrink': [torch.uint8],
566*da0073e9SAndroid Build Coastguard Worker        'nn.functional.triplet_margin_loss': [torch.uint8],
567*da0073e9SAndroid Build Coastguard Worker        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
568*da0073e9SAndroid Build Coastguard Worker        'nn.functional.pairwise_distance': [torch.uint8],
569*da0073e9SAndroid Build Coastguard Worker        'outer': [torch.uint8],
570*da0073e9SAndroid Build Coastguard Worker        'rad2deg': [torch.uint8],
571*da0073e9SAndroid Build Coastguard Worker        'reciprocal': [torch.uint8],
572*da0073e9SAndroid Build Coastguard Worker        'remainder': [torch.uint8],
573*da0073e9SAndroid Build Coastguard Worker        'round': [torch.uint8],
574*da0073e9SAndroid Build Coastguard Worker        'rsqrt': [torch.uint8],
575*da0073e9SAndroid Build Coastguard Worker        'sigmoid': [torch.uint8],
576*da0073e9SAndroid Build Coastguard Worker        'sign': [torch.uint8],
577*da0073e9SAndroid Build Coastguard Worker        'signbit': [torch.uint8],
578*da0073e9SAndroid Build Coastguard Worker        'sin': [torch.uint8],
579*da0073e9SAndroid Build Coastguard Worker        'sinh': [torch.uint8],
580*da0073e9SAndroid Build Coastguard Worker        'special.ndtr': [torch.uint8],
581*da0073e9SAndroid Build Coastguard Worker        'sqrt': [torch.uint8],
582*da0073e9SAndroid Build Coastguard Worker        'sub': [torch.uint8],
583*da0073e9SAndroid Build Coastguard Worker        'trapezoid': [torch.uint8],
584*da0073e9SAndroid Build Coastguard Worker        'trapz': [torch.uint8],
585*da0073e9SAndroid Build Coastguard Worker        'true_divide': [torch.uint8],
586*da0073e9SAndroid Build Coastguard Worker        'trunc': [torch.uint8],
587*da0073e9SAndroid Build Coastguard Worker        'xlogy': [torch.uint8],
588*da0073e9SAndroid Build Coastguard Worker        'minbinary': [torch.uint8],
589*da0073e9SAndroid Build Coastguard Worker        'maxbinary': [torch.uint8],
590*da0073e9SAndroid Build Coastguard Worker        'divtrunc_rounding': [torch.uint8],
591*da0073e9SAndroid Build Coastguard Worker        'divfloor_rounding': [torch.uint8],
592*da0073e9SAndroid Build Coastguard Worker        'divno_rounding_mode': [torch.uint8],
593*da0073e9SAndroid Build Coastguard Worker        'floor_divide': [torch.uint8],
594*da0073e9SAndroid Build Coastguard Worker        'ldexp': [torch.uint8],
595*da0073e9SAndroid Build Coastguard Worker        # square internally calls into power, and will type cast to int64, which supports starting from macOS 13
596*da0073e9SAndroid Build Coastguard Worker        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker        # cpu not giving nan for x/0.0
599*da0073e9SAndroid Build Coastguard Worker        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        # inconsistency errors between cpu and mps, max seen atol is 2
602*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolatebilinear': [torch.uint8],
603*da0073e9SAndroid Build Coastguard Worker    }
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker    MACOS_BEFORE_13_3_XFAILLIST = {
606*da0073e9SAndroid Build Coastguard Worker        # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
607*da0073e9SAndroid Build Coastguard Worker        'tan': [torch.float32],
608*da0073e9SAndroid Build Coastguard Worker        'cdist': [torch.float32],
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker        # CPU Error: cpu not giving nan for x/0.0
611*da0073e9SAndroid Build Coastguard Worker        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        # test blow pass on macOS 12 as it falls back to cpu
614*da0073e9SAndroid Build Coastguard Worker        # Argsort case using duplicate indices (undefined behaviour):
615*da0073e9SAndroid Build Coastguard Worker        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], devuce='cpu')
616*da0073e9SAndroid Build Coastguard Worker        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
617*da0073e9SAndroid Build Coastguard Worker        # Elements from index 30 and 5133 are both equal.
618*da0073e9SAndroid Build Coastguard Worker        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
619*da0073e9SAndroid Build Coastguard Worker        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
620*da0073e9SAndroid Build Coastguard Worker        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
621*da0073e9SAndroid Build Coastguard Worker        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
622*da0073e9SAndroid Build Coastguard Worker        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
623*da0073e9SAndroid Build Coastguard Worker        # Unsupported dtypes
624*da0073e9SAndroid Build Coastguard Worker        'cumsum': [torch.int64],
625*da0073e9SAndroid Build Coastguard Worker        'cumprod': [torch.int64],
626*da0073e9SAndroid Build Coastguard Worker        'cumulative_trapezoid': [torch.int64],
627*da0073e9SAndroid Build Coastguard Worker        'masked.cumsum': [torch.int64],
628*da0073e9SAndroid Build Coastguard Worker        'masked.cumprod': [torch.int64],
629*da0073e9SAndroid Build Coastguard Worker        'linalg.vander': [torch.int64],
630*da0073e9SAndroid Build Coastguard Worker    }
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker    MACOS_AFTER_13_1_XFAILLIST = {
633*da0073e9SAndroid Build Coastguard Worker        # before macOS 13.2 it falls back to cpu and pass the forward pass
634*da0073e9SAndroid Build Coastguard Worker        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
635*da0073e9SAndroid Build Coastguard Worker        # inconsistency errors between cpu and mps, max seen atol is 2
636*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolatebilinear': [torch.uint8],
637*da0073e9SAndroid Build Coastguard Worker    }
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker    MACOS_13_3_XFAILLIST = {
640*da0073e9SAndroid Build Coastguard Worker        # Failure due to precision issue for fp16
641*da0073e9SAndroid Build Coastguard Worker        # on both cpu and mps there are test cases that might produce inf result
642*da0073e9SAndroid Build Coastguard Worker        # 'nn.functional.pairwise_distance': [torch.float16],
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Worker        # test blow pass on macOS 12 as it falls back to cpu
645*da0073e9SAndroid Build Coastguard Worker        # Argsort case using duplicate indices (undefined behaviour):
646*da0073e9SAndroid Build Coastguard Worker        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], devuce='cpu')
647*da0073e9SAndroid Build Coastguard Worker        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
648*da0073e9SAndroid Build Coastguard Worker        # Elements from index 30 and 5133 are both equal.
649*da0073e9SAndroid Build Coastguard Worker        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
650*da0073e9SAndroid Build Coastguard Worker        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
651*da0073e9SAndroid Build Coastguard Worker        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
652*da0073e9SAndroid Build Coastguard Worker        # The values of the sorted tensor match the CPU, but in case of the returned indices this results in undefined behaviour.
653*da0073e9SAndroid Build Coastguard Worker        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
654*da0073e9SAndroid Build Coastguard Worker    }
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker    MACOS_BEFORE_14_4_XFAILLIST = {
657*da0073e9SAndroid Build Coastguard Worker        # These ops work fine in 14.4 but fail in 14.2 or 13.x
658*da0073e9SAndroid Build Coastguard Worker        'fft.hfft2': [torch.complex64],
659*da0073e9SAndroid Build Coastguard Worker    }
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker    # Those ops are not expected to work
662*da0073e9SAndroid Build Coastguard Worker    UNIMPLEMENTED_XFAILLIST = {
663*da0073e9SAndroid Build Coastguard Worker        # Failures due to lack of op implementation on MPS backend
664*da0073e9SAndroid Build Coastguard Worker        'login': None,
665*da0073e9SAndroid Build Coastguard Worker        'linalg.eig': None,
666*da0073e9SAndroid Build Coastguard Worker        'linalg.eigvals': None,
667*da0073e9SAndroid Build Coastguard Worker        'put': None,
668*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose3d': None,
669*da0073e9SAndroid Build Coastguard Worker        'rounddecimals_neg_3': None,
670*da0073e9SAndroid Build Coastguard Worker        'rounddecimals_3': None,
671*da0073e9SAndroid Build Coastguard Worker        'rounddecimals_0': None,
672*da0073e9SAndroid Build Coastguard Worker        '__rsub__': None,
673*da0073e9SAndroid Build Coastguard Worker        'angle': None,
674*da0073e9SAndroid Build Coastguard Worker        'cauchy_': None,
675*da0073e9SAndroid Build Coastguard Worker        'cauchy': None,
676*da0073e9SAndroid Build Coastguard Worker        'cholesky': None,
677*da0073e9SAndroid Build Coastguard Worker        'cholesky_inverse': None,
678*da0073e9SAndroid Build Coastguard Worker        'cholesky_solve': None,
679*da0073e9SAndroid Build Coastguard Worker        'cummax': None,
680*da0073e9SAndroid Build Coastguard Worker        'cummin': None,
681*da0073e9SAndroid Build Coastguard Worker        'erfc': None,
682*da0073e9SAndroid Build Coastguard Worker        'frexp': None,
683*da0073e9SAndroid Build Coastguard Worker        'gcd': None,
684*da0073e9SAndroid Build Coastguard Worker        'geqrf': None,
685*da0073e9SAndroid Build Coastguard Worker        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
686*da0073e9SAndroid Build Coastguard Worker        'heaviside': None,
687*da0073e9SAndroid Build Coastguard Worker        'i0': None,
688*da0073e9SAndroid Build Coastguard Worker        'igamma': None,
689*da0073e9SAndroid Build Coastguard Worker        'igammac': None,
690*da0073e9SAndroid Build Coastguard Worker        'index_copy': None,
691*da0073e9SAndroid Build Coastguard Worker        'index_reduceprod': None,
692*da0073e9SAndroid Build Coastguard Worker        'index_reducemean': None,
693*da0073e9SAndroid Build Coastguard Worker        'index_reduceamax': None,
694*da0073e9SAndroid Build Coastguard Worker        'index_reduceamin': None,
695*da0073e9SAndroid Build Coastguard Worker        'isneginf': None,
696*da0073e9SAndroid Build Coastguard Worker        'isposinf': None,
697*da0073e9SAndroid Build Coastguard Worker        'kthvalue': None,
698*da0073e9SAndroid Build Coastguard Worker        'lcm': None,
699*da0073e9SAndroid Build Coastguard Worker        'linalg.cholesky': None,
700*da0073e9SAndroid Build Coastguard Worker        'linalg.cholesky_ex': None,
701*da0073e9SAndroid Build Coastguard Worker        'linalg.cond': None,
702*da0073e9SAndroid Build Coastguard Worker        'linalg.detsingular': None,
703*da0073e9SAndroid Build Coastguard Worker        'linalg.det': None,
704*da0073e9SAndroid Build Coastguard Worker        'linalg.eigh': None,
705*da0073e9SAndroid Build Coastguard Worker        'linalg.eigvalsh': None,
706*da0073e9SAndroid Build Coastguard Worker        'linalg.householder_product': None,
707*da0073e9SAndroid Build Coastguard Worker        'linalg.ldl_factor': None,
708*da0073e9SAndroid Build Coastguard Worker        'linalg.ldl_factor_ex': None,
709*da0073e9SAndroid Build Coastguard Worker        'linalg.ldl_solve': None,
710*da0073e9SAndroid Build Coastguard Worker        'linalg.lstsq': None,
711*da0073e9SAndroid Build Coastguard Worker        'linalg.lstsqgrad_oriented': None,
712*da0073e9SAndroid Build Coastguard Worker        'linalg.lu': None,
713*da0073e9SAndroid Build Coastguard Worker        'linalg.lu_factor_ex': None,
714*da0073e9SAndroid Build Coastguard Worker        'linalg.lu_solve': None,
715*da0073e9SAndroid Build Coastguard Worker        'linalg.matrix_norm': [torch.float32],
716*da0073e9SAndroid Build Coastguard Worker        'linalg.norm': [torch.float32],
717*da0073e9SAndroid Build Coastguard Worker        'linalg.normsubgradients_at_zero': [torch.float32],
718*da0073e9SAndroid Build Coastguard Worker        'linalg.qr': None,
719*da0073e9SAndroid Build Coastguard Worker        'linalg.slogdet': None,
720*da0073e9SAndroid Build Coastguard Worker        'linalg.solve': None,
721*da0073e9SAndroid Build Coastguard Worker        'linalg.solve_ex': None,
722*da0073e9SAndroid Build Coastguard Worker        'linalg.svdvals': None,
723*da0073e9SAndroid Build Coastguard Worker        'linalg.tensorsolve': None,
724*da0073e9SAndroid Build Coastguard Worker        'linalg.vecdot': None,
725*da0073e9SAndroid Build Coastguard Worker        'logcumsumexp': None,
726*da0073e9SAndroid Build Coastguard Worker        'logdet': None,
727*da0073e9SAndroid Build Coastguard Worker        'lu': None,
728*da0073e9SAndroid Build Coastguard Worker        'lu_solve': None,
729*da0073e9SAndroid Build Coastguard Worker        'lu_unpack': None,
730*da0073e9SAndroid Build Coastguard Worker        'masked.median': None,
731*da0073e9SAndroid Build Coastguard Worker        'matrix_exp': None,
732*da0073e9SAndroid Build Coastguard Worker        'mode': None,
733*da0073e9SAndroid Build Coastguard Worker        'nanmedian': None,
734*da0073e9SAndroid Build Coastguard Worker        'native_dropout_backward': None,
735*da0073e9SAndroid Build Coastguard Worker        'normnuc': None,
736*da0073e9SAndroid Build Coastguard Worker        'nn.functional.fractional_max_pool2d': None,
737*da0073e9SAndroid Build Coastguard Worker        'nn.functional.fractional_max_pool3d': None,
738*da0073e9SAndroid Build Coastguard Worker        'nn.functional.adaptive_avg_pool3d': None,
739*da0073e9SAndroid Build Coastguard Worker        'nn.functional.adaptive_max_pool3d': None,
740*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolatearea': None,
741*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolatebicubic': None,
742*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolatetrilinear': None,
743*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool1dgrad': None,
744*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool2dgrad': None,
745*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool3dgrad': None,
746*da0073e9SAndroid Build Coastguard Worker        'nn.functional.avg_pool3d': None,
747*da0073e9SAndroid Build Coastguard Worker        'nn.functional.ctc_loss': None,
748*da0073e9SAndroid Build Coastguard Worker        'nn.functional.embedding_bag': None,
749*da0073e9SAndroid Build Coastguard Worker        'nn.functional.hardshrink': None,
750*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_pool3d': None,
751*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool1d': None,
752*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool2d': None,
753*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_unpool3d': None,
754*da0073e9SAndroid Build Coastguard Worker        'nn.functional.multi_margin_loss': None,
755*da0073e9SAndroid Build Coastguard Worker        'nn.functional.multilabel_margin_loss': None,
756*da0073e9SAndroid Build Coastguard Worker        'nn.functional.pdist': None,
757*da0073e9SAndroid Build Coastguard Worker        'nn.functional.rrelu': None,
758*da0073e9SAndroid Build Coastguard Worker        'nn.functional.norm': None,
759*da0073e9SAndroid Build Coastguard Worker        'ormqr': None,
760*da0073e9SAndroid Build Coastguard Worker        'pca_lowrank': None,
761*da0073e9SAndroid Build Coastguard Worker        'qr': None,
762*da0073e9SAndroid Build Coastguard Worker        'rsub': None,
763*da0073e9SAndroid Build Coastguard Worker        'scatter_reduceamax': None,
764*da0073e9SAndroid Build Coastguard Worker        'scatter_reduceamin': None,
765*da0073e9SAndroid Build Coastguard Worker        'scatter_reducemin': None,
766*da0073e9SAndroid Build Coastguard Worker        'scatter_reducemean': None,
767*da0073e9SAndroid Build Coastguard Worker        'scatter_reduceprod': None,
768*da0073e9SAndroid Build Coastguard Worker        'scatter_reducesum': None,
769*da0073e9SAndroid Build Coastguard Worker        'segment_reduce': None,
770*da0073e9SAndroid Build Coastguard Worker        '_segment.reduce': None,
771*da0073e9SAndroid Build Coastguard Worker        'segment.reduce': None,
772*da0073e9SAndroid Build Coastguard Worker        'segment_reduce_offsets': None,
773*da0073e9SAndroid Build Coastguard Worker        '_segment_reduce_offsets': None,
774*da0073e9SAndroid Build Coastguard Worker        '_segment_reduce_lengths': None,
775*da0073e9SAndroid Build Coastguard Worker        '_segment_reducelengths': None,
776*da0073e9SAndroid Build Coastguard Worker        '_segment_reduceoffsets': None,
777*da0073e9SAndroid Build Coastguard Worker        'sinc': None,
778*da0073e9SAndroid Build Coastguard Worker        'sparse.mm': None,
779*da0073e9SAndroid Build Coastguard Worker        'sparse.mmreduce': None,
780*da0073e9SAndroid Build Coastguard Worker        'special.airy_ai': None,
781*da0073e9SAndroid Build Coastguard Worker        'special.bessel_j0': None,
782*da0073e9SAndroid Build Coastguard Worker        'special.bessel_j1': None,
783*da0073e9SAndroid Build Coastguard Worker        'special.bessel_y0': None,
784*da0073e9SAndroid Build Coastguard Worker        'special.bessel_y1': None,
785*da0073e9SAndroid Build Coastguard Worker        'special.chebyshev_polynomial_t': None,
786*da0073e9SAndroid Build Coastguard Worker        'special.chebyshev_polynomial_u': None,
787*da0073e9SAndroid Build Coastguard Worker        'special.entr': None,
788*da0073e9SAndroid Build Coastguard Worker        'special.erfcx': None,
789*da0073e9SAndroid Build Coastguard Worker        'special.hermite_polynomial_h': None,
790*da0073e9SAndroid Build Coastguard Worker        'special.hermite_polynomial_he': None,
791*da0073e9SAndroid Build Coastguard Worker        'special.i0e': None,
792*da0073e9SAndroid Build Coastguard Worker        'special.i1': None,
793*da0073e9SAndroid Build Coastguard Worker        'special.i1e': None,
794*da0073e9SAndroid Build Coastguard Worker        'special.laguerre_polynomial_l': None,
795*da0073e9SAndroid Build Coastguard Worker        'special.log_ndtr': None,
796*da0073e9SAndroid Build Coastguard Worker        'special.modified_bessel_i0': None,
797*da0073e9SAndroid Build Coastguard Worker        'special.modified_bessel_i1': None,
798*da0073e9SAndroid Build Coastguard Worker        'special.modified_bessel_k0': None,
799*da0073e9SAndroid Build Coastguard Worker        'special.modified_bessel_k1': None,
800*da0073e9SAndroid Build Coastguard Worker        'special.ndtri': None,
801*da0073e9SAndroid Build Coastguard Worker        'special.scaled_modified_bessel_k0': None,
802*da0073e9SAndroid Build Coastguard Worker        'special.scaled_modified_bessel_k1': None,
803*da0073e9SAndroid Build Coastguard Worker        'special.spherical_bessel_j0': None,
804*da0073e9SAndroid Build Coastguard Worker        'special.xlog1py': None,
805*da0073e9SAndroid Build Coastguard Worker        'special.zeta': None,
806*da0073e9SAndroid Build Coastguard Worker        'svd_lowrank': None,
807*da0073e9SAndroid Build Coastguard Worker        'symeig': None,
808*da0073e9SAndroid Build Coastguard Worker        'take': None,
809*da0073e9SAndroid Build Coastguard Worker        'to': None,
810*da0073e9SAndroid Build Coastguard Worker        'to_sparse': None,
811*da0073e9SAndroid Build Coastguard Worker        'unique': None,
812*da0073e9SAndroid Build Coastguard Worker        'vdot': None,
813*da0073e9SAndroid Build Coastguard Worker        'segment_reduce_': None,
814*da0073e9SAndroid Build Coastguard Worker        '_upsample_bilinear2d_aa': None,
815*da0073e9SAndroid Build Coastguard Worker        'geometric' : None,
816*da0073e9SAndroid Build Coastguard Worker        'geometric_': None,
817*da0073e9SAndroid Build Coastguard Worker        'log_normal_': None,
818*da0073e9SAndroid Build Coastguard Worker        'log_normal': None,
819*da0073e9SAndroid Build Coastguard Worker        'cdouble': None,
820*da0073e9SAndroid Build Coastguard Worker        'double': None,
821*da0073e9SAndroid Build Coastguard Worker        'nn.functional.softminwith_dtype': None,
822*da0073e9SAndroid Build Coastguard Worker        'log_softmaxwith_dtype': None,
823*da0073e9SAndroid Build Coastguard Worker        'softmaxwith_dtype': None,
824*da0073e9SAndroid Build Coastguard Worker        'float_power': None,
825*da0073e9SAndroid Build Coastguard Worker        'full_like': None,
826*da0073e9SAndroid Build Coastguard Worker        'linalg.matrix_rankhermitian': None,
827*da0073e9SAndroid Build Coastguard Worker        'linalg.pinvhermitian': None,
828*da0073e9SAndroid Build Coastguard Worker        'nonzero_static': None,
829*da0073e9SAndroid Build Coastguard Worker
830*da0073e9SAndroid Build Coastguard Worker        # MPS: input sizes must be divisible by output sizes
831*da0073e9SAndroid Build Coastguard Worker        'nn.functional.adaptive_avg_pool1d': None,
832*da0073e9SAndroid Build Coastguard Worker        'nn.functional.adaptive_avg_pool2d': None,
833*da0073e9SAndroid Build Coastguard Worker
834*da0073e9SAndroid Build Coastguard Worker        # Unsupported dtypes
835*da0073e9SAndroid Build Coastguard Worker        # bmm is not supported for integral types
836*da0073e9SAndroid Build Coastguard Worker        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
837*da0073e9SAndroid Build Coastguard Worker        'ones_like': None,
838*da0073e9SAndroid Build Coastguard Worker        'zeros_like': None,
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker        # Convolution for integral types is not supported on MPS
841*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv1d': [torch.int64],
842*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv2d': [torch.int64],
843*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv3d': [torch.int64],
844*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose1d': [torch.int64],
845*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose2d': [torch.int64],
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker        # Unsupported dtypes
848*da0073e9SAndroid Build Coastguard Worker        'dot': [torch.int64],
849*da0073e9SAndroid Build Coastguard Worker        'histc': [torch.float16],
850*da0073e9SAndroid Build Coastguard Worker        'index_add': [torch.int64],
851*da0073e9SAndroid Build Coastguard Worker        'log1p': [torch.int64],
852*da0073e9SAndroid Build Coastguard Worker        'sigmoid': [torch.int64],
853*da0073e9SAndroid Build Coastguard Worker        'atan2': [torch.int64],
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker        # GEMM on MPS is not supported for integral types
856*da0073e9SAndroid Build Coastguard Worker        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
857*da0073e9SAndroid Build Coastguard Worker        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
858*da0073e9SAndroid Build Coastguard Worker        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
859*da0073e9SAndroid Build Coastguard Worker        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
860*da0073e9SAndroid Build Coastguard Worker        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
861*da0073e9SAndroid Build Coastguard Worker        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
862*da0073e9SAndroid Build Coastguard Worker        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
863*da0073e9SAndroid Build Coastguard Worker        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
864*da0073e9SAndroid Build Coastguard Worker        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
865*da0073e9SAndroid Build Coastguard Worker        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
866*da0073e9SAndroid Build Coastguard Worker        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
867*da0073e9SAndroid Build Coastguard Worker        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
868*da0073e9SAndroid Build Coastguard Worker        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
869*da0073e9SAndroid Build Coastguard Worker        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
870*da0073e9SAndroid Build Coastguard Worker        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
871*da0073e9SAndroid Build Coastguard Worker        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
872*da0073e9SAndroid Build Coastguard Worker        'unravel_index': [torch.int32, torch.int64],
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
875*da0073e9SAndroid Build Coastguard Worker        # the MPS framework doesn't support float64
876*da0073e9SAndroid Build Coastguard Worker        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
877*da0073e9SAndroid Build Coastguard Worker        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
878*da0073e9SAndroid Build Coastguard Worker        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
879*da0073e9SAndroid Build Coastguard Worker        # returned output on CPU is float64
880*da0073e9SAndroid Build Coastguard Worker        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker        # trunc_tensor not working properly for float16
883*da0073e9SAndroid Build Coastguard Worker        'divtrunc_rounding': [torch.float16],
884*da0073e9SAndroid Build Coastguard Worker        'fmod': [torch.float16],
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker        # round not working properly for float16
887*da0073e9SAndroid Build Coastguard Worker        'round': [torch.float16],
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker        # atomic operations not supported
890*da0073e9SAndroid Build Coastguard Worker        '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
891*da0073e9SAndroid Build Coastguard Worker    }
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker    if product_version < 14.0:
894*da0073e9SAndroid Build Coastguard Worker        # FFT and BFloat16 support was added in MacOS 14
895*da0073e9SAndroid Build Coastguard Worker        UNIMPLEMENTED_XFAILLIST.update({
896*da0073e9SAndroid Build Coastguard Worker            'bfloat16': None,
897*da0073e9SAndroid Build Coastguard Worker            'fft.fft': None,
898*da0073e9SAndroid Build Coastguard Worker            'fft.fft2': None,
899*da0073e9SAndroid Build Coastguard Worker            'fft.fftn': None,
900*da0073e9SAndroid Build Coastguard Worker            'fft.hfft': None,
901*da0073e9SAndroid Build Coastguard Worker            'fft.hfft2': None,
902*da0073e9SAndroid Build Coastguard Worker            'fft.hfftn': None,
903*da0073e9SAndroid Build Coastguard Worker            'fft.ifft': None,
904*da0073e9SAndroid Build Coastguard Worker            'fft.ifft2': None,
905*da0073e9SAndroid Build Coastguard Worker            'fft.ifftn': None,
906*da0073e9SAndroid Build Coastguard Worker            'fft.ihfft': None,
907*da0073e9SAndroid Build Coastguard Worker            'fft.ihfft2': None,
908*da0073e9SAndroid Build Coastguard Worker            'fft.ihfftn': None,
909*da0073e9SAndroid Build Coastguard Worker            'fft.irfft': None,
910*da0073e9SAndroid Build Coastguard Worker            'fft.irfft2': None,
911*da0073e9SAndroid Build Coastguard Worker            'fft.irfftn': None,
912*da0073e9SAndroid Build Coastguard Worker            'fft.rfft': None,
913*da0073e9SAndroid Build Coastguard Worker            'fft.rfft2': None,
914*da0073e9SAndroid Build Coastguard Worker            'fft.rfftn': None,
915*da0073e9SAndroid Build Coastguard Worker            'stft': None,
916*da0073e9SAndroid Build Coastguard Worker            # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers,
917*da0073e9SAndroid Build Coastguard Worker            # not reproducible in later OS. Added assert to op if used in < 14.0
918*da0073e9SAndroid Build Coastguard Worker            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
919*da0073e9SAndroid Build Coastguard Worker            'nn.functional.max_pool2d': [torch.uint8],
920*da0073e9SAndroid Build Coastguard Worker        })
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker    if product_version < 15.0:
923*da0073e9SAndroid Build Coastguard Worker        UNIMPLEMENTED_XFAILLIST.update({
924*da0073e9SAndroid Build Coastguard Worker            'quantile': None,
925*da0073e9SAndroid Build Coastguard Worker            'nanquantile': None,
926*da0073e9SAndroid Build Coastguard Worker        })
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker    UNDEFINED_XFAILLIST = {
929*da0073e9SAndroid Build Coastguard Worker        # Top 60 operators
930*da0073e9SAndroid Build Coastguard Worker        # topk fails with duplicate indices
931*da0073e9SAndroid Build Coastguard Worker        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker        # Failures due to random output that they generate using
934*da0073e9SAndroid Build Coastguard Worker        # Philox engine causing mismatch with CPU results
935*da0073e9SAndroid Build Coastguard Worker        'multinomial': [torch.float16, torch.float32],  # random results
936*da0073e9SAndroid Build Coastguard Worker        'uniform': [torch.float16, torch.float32],
937*da0073e9SAndroid Build Coastguard Worker        'rand_like': [torch.float16, torch.float32],
938*da0073e9SAndroid Build Coastguard Worker        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
939*da0073e9SAndroid Build Coastguard Worker        'randn_like': [torch.float16, torch.float32],
940*da0073e9SAndroid Build Coastguard Worker        'bernoulli': [torch.float16, torch.float32],
941*da0073e9SAndroid Build Coastguard Worker        'exponential': [torch.float16, torch.float32],
942*da0073e9SAndroid Build Coastguard Worker        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
943*da0073e9SAndroid Build Coastguard Worker        'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
944*da0073e9SAndroid Build Coastguard Worker        'normalin_place': [torch.float16, torch.float32],
945*da0073e9SAndroid Build Coastguard Worker        'normalnumber_mean': [torch.float16, torch.float32],
946*da0073e9SAndroid Build Coastguard Worker        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
947*da0073e9SAndroid Build Coastguard Worker        'nn.functional.dropout': [torch.float16, torch.float32],
948*da0073e9SAndroid Build Coastguard Worker        'nn.functional.dropout2d': [torch.float16, torch.float32],
949*da0073e9SAndroid Build Coastguard Worker        'nn.functional.dropout3d': [torch.float16, torch.float32],
950*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/111479
951*da0073e9SAndroid Build Coastguard Worker        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker        # duplicate indices are used in the testcase - undefined behaviour
954*da0073e9SAndroid Build Coastguard Worker        'index_put': None,
955*da0073e9SAndroid Build Coastguard Worker        # zero to negative integer powers are undefined
956*da0073e9SAndroid Build Coastguard Worker        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
957*da0073e9SAndroid Build Coastguard Worker        'resize_': [torch.float16, torch.float32],
958*da0073e9SAndroid Build Coastguard Worker        'resize_as_': [torch.float16, torch.float32],
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker        # CPU Errors:
961*da0073e9SAndroid Build Coastguard Worker        'addr': [torch.bool, torch.int16, torch.int32,
962*da0073e9SAndroid Build Coastguard Worker                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
963*da0073e9SAndroid Build Coastguard Worker        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
964*da0073e9SAndroid Build Coastguard Worker                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
965*da0073e9SAndroid Build Coastguard Worker        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
966*da0073e9SAndroid Build Coastguard Worker                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker        # random results
969*da0073e9SAndroid Build Coastguard Worker        # mps vs cpu:
970*da0073e9SAndroid Build Coastguard Worker        # Mismatched elements: 40 / 96 (41.7%)
971*da0073e9SAndroid Build Coastguard Worker        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
972*da0073e9SAndroid Build Coastguard Worker        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
973*da0073e9SAndroid Build Coastguard Worker        # cuda(2.0.0.dev20230301+cu117) vs cpu:
974*da0073e9SAndroid Build Coastguard Worker        # Mismatched elements: 56 / 96 (58.3%)
975*da0073e9SAndroid Build Coastguard Worker        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
976*da0073e9SAndroid Build Coastguard Worker        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
977*da0073e9SAndroid Build Coastguard Worker        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        # float output for float16 input on MPS
980*da0073e9SAndroid Build Coastguard Worker        'logit': [torch.float16],
981*da0073e9SAndroid Build Coastguard Worker    }
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker    ON_MPS_XFAILLIST = {
984*da0073e9SAndroid Build Coastguard Worker        # Failures due to lack of implementation of downstream functions on MPS backend
985*da0073e9SAndroid Build Coastguard Worker        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
986*da0073e9SAndroid Build Coastguard Worker        'linalg.matrix_rank': None,
987*da0073e9SAndroid Build Coastguard Worker    }
988*da0073e9SAndroid Build Coastguard Worker
989*da0073e9SAndroid Build Coastguard Worker    EMPTY_OPS_SKIPLIST = {
990*da0073e9SAndroid Build Coastguard Worker        # Fill tensors with uninitialized data, causing mismatch with CPU.
991*da0073e9SAndroid Build Coastguard Worker        # They occasionally match, thus skipping them.
992*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/100175
993*da0073e9SAndroid Build Coastguard Worker        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
994*da0073e9SAndroid Build Coastguard Worker        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
995*da0073e9SAndroid Build Coastguard Worker                              torch.int32, torch.int64, torch.uint8, torch.int8],
996*da0073e9SAndroid Build Coastguard Worker        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
997*da0073e9SAndroid Build Coastguard Worker        # CPU: empty is returning all 0's and there is a mismatch with MPS
998*da0073e9SAndroid Build Coastguard Worker        # allocation (MacOS 13). According to
999*da0073e9SAndroid Build Coastguard Worker        # https://pytorch.org/docs/2.0/generated/torch.empty.html
1000*da0073e9SAndroid Build Coastguard Worker        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
1001*da0073e9SAndroid Build Coastguard Worker                  torch.int32, torch.int64, torch.uint8, torch.int8],
1002*da0073e9SAndroid Build Coastguard Worker        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
1003*da0073e9SAndroid Build Coastguard Worker        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
1004*da0073e9SAndroid Build Coastguard Worker                           torch.int32, torch.int64, torch.uint8, torch.int8],
1005*da0073e9SAndroid Build Coastguard Worker    }
1006*da0073e9SAndroid Build Coastguard Worker
1007*da0073e9SAndroid Build Coastguard Worker    SKIPLIST = {
1008*da0073e9SAndroid Build Coastguard Worker        # Unsupported
1009*da0073e9SAndroid Build Coastguard Worker        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
1010*da0073e9SAndroid Build Coastguard Worker        'nn.functional.avg_pool2d': [torch.float16],
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Worker        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
1013*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv3d': None,
1014*da0073e9SAndroid Build Coastguard Worker    }
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker    def addDecorator(op, d) -> None:
1017*da0073e9SAndroid Build Coastguard Worker        op.decorators = list(op.decorators) if op.decorators is not None else []
1018*da0073e9SAndroid Build Coastguard Worker        op.decorators.append(d)
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker    for op in ops:
1021*da0073e9SAndroid Build Coastguard Worker        key = op.name + op.variant_test_name
1022*da0073e9SAndroid Build Coastguard Worker        if key in EMPTY_OPS_SKIPLIST:
1023*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1024*da0073e9SAndroid Build Coastguard Worker                         unittest.skip("Skipping empty ops."),
1025*da0073e9SAndroid Build Coastguard Worker                         dtypes=EMPTY_OPS_SKIPLIST[key]))
1026*da0073e9SAndroid Build Coastguard Worker        if key in SKIPLIST:
1027*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
1028*da0073e9SAndroid Build Coastguard Worker        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
1029*da0073e9SAndroid Build Coastguard Worker            if key in xfaillist:
1030*da0073e9SAndroid Build Coastguard Worker                addDecorator(op, DecorateInfo(
1031*da0073e9SAndroid Build Coastguard Worker                             unittest.expectedFailure,
1032*da0073e9SAndroid Build Coastguard Worker                             dtypes=xfaillist[key]))
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
1035*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1036*da0073e9SAndroid Build Coastguard Worker                         unittest.expectedFailure,
1037*da0073e9SAndroid Build Coastguard Worker                         dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
1040*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1041*da0073e9SAndroid Build Coastguard Worker                         unittest.expectedFailure,
1042*da0073e9SAndroid Build Coastguard Worker                         dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
1045*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1046*da0073e9SAndroid Build Coastguard Worker                         unittest.expectedFailure,
1047*da0073e9SAndroid Build Coastguard Worker                         dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
1050*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1051*da0073e9SAndroid Build Coastguard Worker                         unittest.expectedFailure,
1052*da0073e9SAndroid Build Coastguard Worker                         dtypes=MACOS_13_3_XFAILLIST[key]))
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
1055*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(
1056*da0073e9SAndroid Build Coastguard Worker                         unittest.expectedFailure,
1057*da0073e9SAndroid Build Coastguard Worker                         dtypes=MACOS_12_3_XFAILLIST[key]))
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker        # If ops is not supported for complex types, expect it to fail
1060*da0073e9SAndroid Build Coastguard Worker        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
1061*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker        yield op
1064*da0073e9SAndroid Build Coastguard Worker
1065*da0073e9SAndroid Build Coastguard Workerdef mps_ops_error_inputs_modifier(ops):
1066*da0073e9SAndroid Build Coastguard Worker    # Error input samples do not take a dtype argument.
1067*da0073e9SAndroid Build Coastguard Worker    XFAILLIST = {
1068*da0073e9SAndroid Build Coastguard Worker        # Exceptions are not raised
1069*da0073e9SAndroid Build Coastguard Worker        '__rmod__',
1070*da0073e9SAndroid Build Coastguard Worker        '__rsub__',
1071*da0073e9SAndroid Build Coastguard Worker        '__rpow__',
1072*da0073e9SAndroid Build Coastguard Worker        'bernoulli',
1073*da0073e9SAndroid Build Coastguard Worker        'clamp_max',
1074*da0073e9SAndroid Build Coastguard Worker        'clamp_min',
1075*da0073e9SAndroid Build Coastguard Worker        'masked_scatter',
1076*da0073e9SAndroid Build Coastguard Worker
1077*da0073e9SAndroid Build Coastguard Worker        # unsupported float64 dtype
1078*da0073e9SAndroid Build Coastguard Worker        'cat',
1079*da0073e9SAndroid Build Coastguard Worker        'complex',
1080*da0073e9SAndroid Build Coastguard Worker        'multinomial',
1081*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv1d',
1082*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv2d',
1083*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv3d',
1084*da0073e9SAndroid Build Coastguard Worker        'gather',
1085*da0073e9SAndroid Build Coastguard Worker        'scatter',
1086*da0073e9SAndroid Build Coastguard Worker        'scatter_add',
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker        # unsupported complex dtypes
1089*da0073e9SAndroid Build Coastguard Worker        'masked_fill',
1090*da0073e9SAndroid Build Coastguard Worker
1091*da0073e9SAndroid Build Coastguard Worker        # MPS does not support tensor dimensions > 16
1092*da0073e9SAndroid Build Coastguard Worker        'amax',
1093*da0073e9SAndroid Build Coastguard Worker        'amin',
1094*da0073e9SAndroid Build Coastguard Worker        'aminmax',
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker        # memory overlapping checks
1097*da0073e9SAndroid Build Coastguard Worker        'index_select',
1098*da0073e9SAndroid Build Coastguard Worker
1099*da0073e9SAndroid Build Coastguard Worker        # unimplemented
1100*da0073e9SAndroid Build Coastguard Worker        'logcumsumexp',
1101*da0073e9SAndroid Build Coastguard Worker    }
1102*da0073e9SAndroid Build Coastguard Worker
1103*da0073e9SAndroid Build Coastguard Worker    def addDecorator(op, d) -> None:
1104*da0073e9SAndroid Build Coastguard Worker        op.decorators = list(op.decorators) if op.decorators is not None else []
1105*da0073e9SAndroid Build Coastguard Worker        op.decorators.append(d)
1106*da0073e9SAndroid Build Coastguard Worker
1107*da0073e9SAndroid Build Coastguard Worker    for op in ops:
1108*da0073e9SAndroid Build Coastguard Worker        if op.error_inputs_func is None:
1109*da0073e9SAndroid Build Coastguard Worker            continue
1110*da0073e9SAndroid Build Coastguard Worker        key = op.name + op.variant_test_name
1111*da0073e9SAndroid Build Coastguard Worker        if key in XFAILLIST:
1112*da0073e9SAndroid Build Coastguard Worker            addDecorator(op, DecorateInfo(unittest.expectedFailure))
1113*da0073e9SAndroid Build Coastguard Worker        yield op
1114*da0073e9SAndroid Build Coastguard Worker
1115*da0073e9SAndroid Build Coastguard Worker# Same logic as test_cuda.py
1116*da0073e9SAndroid Build Coastguard Workerif not torch.backends.mps.is_available():
1117*da0073e9SAndroid Build Coastguard Worker    print('MPS not available, skipping tests', file=sys.stderr)
1118*da0073e9SAndroid Build Coastguard Worker    TestCase = NoTest  # noqa: F811
1119*da0073e9SAndroid Build Coastguard Worker    NNTestCase = NoTest  # noqa: F811
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Workerproduct_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
1122*da0073e9SAndroid Build Coastguard Workertotal_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker# Determine whether to enable MPS memory leak check (uses same code as CUDA).
1125*da0073e9SAndroid Build Coastguard WorkerTEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Workerdef skipMPSMemoryLeakCheckIf(condition):
1128*da0073e9SAndroid Build Coastguard Worker    def dec(fn):
1129*da0073e9SAndroid Build Coastguard Worker        if getattr(fn, '_do_mps_memory_leak_check', True):
1130*da0073e9SAndroid Build Coastguard Worker            fn._do_mps_memory_leak_check = not condition
1131*da0073e9SAndroid Build Coastguard Worker        return fn
1132*da0073e9SAndroid Build Coastguard Worker    return dec
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Workerclass MpsMemoryLeakCheck:
1135*da0073e9SAndroid Build Coastguard Worker    def __init__(self, testcase, name=None):
1136*da0073e9SAndroid Build Coastguard Worker        self.name = testcase.id() if name is None else name
1137*da0073e9SAndroid Build Coastguard Worker        self.testcase = testcase
1138*da0073e9SAndroid Build Coastguard Worker
1139*da0073e9SAndroid Build Coastguard Worker    def __enter__(self):
1140*da0073e9SAndroid Build Coastguard Worker        # Performs a gc if required (required if any memory is held)
1141*da0073e9SAndroid Build Coastguard Worker        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1142*da0073e9SAndroid Build Coastguard Worker        if caching_allocator_mem_allocated > 0:
1143*da0073e9SAndroid Build Coastguard Worker            gc.collect()
1144*da0073e9SAndroid Build Coastguard Worker            torch.mps.empty_cache()
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker        # Acquires caching allocator and driver statistics before the test is run
1147*da0073e9SAndroid Build Coastguard Worker        self.caching_allocator_before = torch.mps.current_allocated_memory()
1148*da0073e9SAndroid Build Coastguard Worker        self.driver_before = torch.mps.driver_allocated_memory()
1149*da0073e9SAndroid Build Coastguard Worker
1150*da0073e9SAndroid Build Coastguard Worker    def __exit__(self, exec_type, exec_value, traceback):
1151*da0073e9SAndroid Build Coastguard Worker        # Don't check for leaks if an exception was thrown
1152*da0073e9SAndroid Build Coastguard Worker        if exec_type is not None:
1153*da0073e9SAndroid Build Coastguard Worker            return
1154*da0073e9SAndroid Build Coastguard Worker        # Compares caching allocator before/after statistics
1155*da0073e9SAndroid Build Coastguard Worker        # An increase in allocated memory is a discrepancy indicating a possible memory leak
1156*da0073e9SAndroid Build Coastguard Worker        discrepancy_detected = False
1157*da0073e9SAndroid Build Coastguard Worker        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1158*da0073e9SAndroid Build Coastguard Worker        if caching_allocator_mem_allocated > self.caching_allocator_before:
1159*da0073e9SAndroid Build Coastguard Worker            discrepancy_detected = True
1160*da0073e9SAndroid Build Coastguard Worker
1161*da0073e9SAndroid Build Coastguard Worker        # Short-circuits if no discrepancy detected
1162*da0073e9SAndroid Build Coastguard Worker        if not discrepancy_detected:
1163*da0073e9SAndroid Build Coastguard Worker            return
1164*da0073e9SAndroid Build Coastguard Worker        # Validates the discrepancy persists after garbage collection and
1165*da0073e9SAndroid Build Coastguard Worker        # is confirmed by the driver API
1166*da0073e9SAndroid Build Coastguard Worker        gc.collect()
1167*da0073e9SAndroid Build Coastguard Worker        torch.mps.empty_cache()
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker        discrepancy_detected = True
1170*da0073e9SAndroid Build Coastguard Worker        # Query memory multiple items to ensure leak was not transient
1171*da0073e9SAndroid Build Coastguard Worker        for n in range(3):
1172*da0073e9SAndroid Build Coastguard Worker            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1173*da0073e9SAndroid Build Coastguard Worker            driver_mem_allocated = torch.mps.driver_allocated_memory()
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker            caching_allocator_discrepancy = False
1176*da0073e9SAndroid Build Coastguard Worker            driver_discrepancy = False
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker            if caching_allocator_mem_allocated > self.caching_allocator_before:
1179*da0073e9SAndroid Build Coastguard Worker                caching_allocator_discrepancy = True
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker            if driver_mem_allocated > self.driver_before:
1182*da0073e9SAndroid Build Coastguard Worker                driver_discrepancy = True
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker            if not (caching_allocator_discrepancy or driver_discrepancy):
1185*da0073e9SAndroid Build Coastguard Worker                # Leak was false positive, exit loop
1186*da0073e9SAndroid Build Coastguard Worker                discrepancy_detected = False
1187*da0073e9SAndroid Build Coastguard Worker                break
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker        if caching_allocator_discrepancy and not driver_discrepancy:
1190*da0073e9SAndroid Build Coastguard Worker            # Just raises a warning if the leak is not validated by the driver API
1191*da0073e9SAndroid Build Coastguard Worker            msg = ("MPS caching allocator reports a memory leak not "
1192*da0073e9SAndroid Build Coastguard Worker                   f"verified by the driver API in {self.name}! "
1193*da0073e9SAndroid Build Coastguard Worker                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
1194*da0073e9SAndroid Build Coastguard Worker                   f"and is now reported as {caching_allocator_mem_allocated}. "
1195*da0073e9SAndroid Build Coastguard Worker                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
1196*da0073e9SAndroid Build Coastguard Worker            warnings.warn(msg)
1197*da0073e9SAndroid Build Coastguard Worker        elif caching_allocator_discrepancy and driver_discrepancy:
1198*da0073e9SAndroid Build Coastguard Worker            # A caching allocator discrepancy validated by the driver API is a failure
1199*da0073e9SAndroid Build Coastguard Worker            msg = (f"MPS driver API confirmed a leak in {self.name}! "
1200*da0073e9SAndroid Build Coastguard Worker                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
1201*da0073e9SAndroid Build Coastguard Worker                   f"and is now reported as {caching_allocator_mem_allocated}. "
1202*da0073e9SAndroid Build Coastguard Worker                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(msg)
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Workerclass TestAutocastMPS(TestCase):
1207*da0073e9SAndroid Build Coastguard Worker
1208*da0073e9SAndroid Build Coastguard Worker    def test_matmul_autocast(self):
1209*da0073e9SAndroid Build Coastguard Worker        autocast_tensor_A = torch.rand((8, 8), device="mps")
1210*da0073e9SAndroid Build Coastguard Worker        autocast_tensor_B = torch.rand((8, 8), device="mps")
1211*da0073e9SAndroid Build Coastguard Worker        tensor_A = autocast_tensor_A.clone().detach()
1212*da0073e9SAndroid Build Coastguard Worker        tensor_B = autocast_tensor_B.clone().detach()
1213*da0073e9SAndroid Build Coastguard Worker        autocast_output_tensor = torch.empty(8, 8)
1214*da0073e9SAndroid Build Coastguard Worker        output_tensor = autocast_output_tensor.clone().detach()
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        with torch.autocast(device_type="mps"):
1217*da0073e9SAndroid Build Coastguard Worker            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
1218*da0073e9SAndroid Build Coastguard Worker            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)
1219*da0073e9SAndroid Build Coastguard Worker
1220*da0073e9SAndroid Build Coastguard Worker        output_tensor = torch.mm(tensor_A, tensor_B)
1221*da0073e9SAndroid Build Coastguard Worker        output_tensor = torch.mm(tensor_A, output_tensor)
1222*da0073e9SAndroid Build Coastguard Worker
1223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
1224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(autocast_output_tensor,
1225*da0073e9SAndroid Build Coastguard Worker                         output_tensor.to(torch.float16),
1226*da0073e9SAndroid Build Coastguard Worker                         f"Autocast & non-autocast tensors did not match, \
1227*da0073e9SAndroid Build Coastguard Worker                         got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker# Expand TestCase class with Memory Leak Detection on MPS device
1230*da0073e9SAndroid Build Coastguard Workerclass TestCaseMPS(TestCase):
1231*da0073e9SAndroid Build Coastguard Worker    _do_mps_memory_leak_check = True
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker    def __init__(self, method_name='runTest'):
1234*da0073e9SAndroid Build Coastguard Worker        super().__init__(method_name)
1235*da0073e9SAndroid Build Coastguard Worker        test_method = getattr(self, method_name, None)
1236*da0073e9SAndroid Build Coastguard Worker        if test_method is not None:
1237*da0073e9SAndroid Build Coastguard Worker            # Wraps the tested method if we should do MPS memory check.
1238*da0073e9SAndroid Build Coastguard Worker            if TEST_MPS_MEM_LEAK_CHECK:
1239*da0073e9SAndroid Build Coastguard Worker                if self._do_mps_memory_leak_check:
1240*da0073e9SAndroid Build Coastguard Worker                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker    def assertLeaksNoMpsTensors(self, name=None):
1243*da0073e9SAndroid Build Coastguard Worker        name = self.id() if name is None else name
1244*da0073e9SAndroid Build Coastguard Worker        return MpsMemoryLeakCheck(self, name)
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker    def wrap_with_mps_policy(self, method_name, policy):
1247*da0073e9SAndroid Build Coastguard Worker        test_method = getattr(self, method_name)
1248*da0073e9SAndroid Build Coastguard Worker        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
1251*da0073e9SAndroid Build Coastguard Worker    def wrap_with_mps_memory_check(self, method):
1252*da0073e9SAndroid Build Coastguard Worker        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Workerclass TestMemoryLeak(TestCaseMPS):
1255*da0073e9SAndroid Build Coastguard Worker    def test_mps_memory_leak_detection(self):
1256*da0073e9SAndroid Build Coastguard Worker        l = []
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker        @self.wrap_with_mps_memory_check
1259*da0073e9SAndroid Build Coastguard Worker        def no_leak():
1260*da0073e9SAndroid Build Coastguard Worker            pass
1261*da0073e9SAndroid Build Coastguard Worker
1262*da0073e9SAndroid Build Coastguard Worker        # Trigger an intentional memory leak
1263*da0073e9SAndroid Build Coastguard Worker        @self.wrap_with_mps_memory_check
1264*da0073e9SAndroid Build Coastguard Worker        def leak_gpu0():
1265*da0073e9SAndroid Build Coastguard Worker            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
1266*da0073e9SAndroid Build Coastguard Worker            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker        no_leak()
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker        # check if a runtime error for memory leak was emitted which would
1271*da0073e9SAndroid Build Coastguard Worker        # confirm whether memory leak detection worked successfully or not.
1272*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
1273*da0073e9SAndroid Build Coastguard Worker            leak_gpu0()
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker    def test_copy_cast_no_leak(self):
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        def step(x):
1278*da0073e9SAndroid Build Coastguard Worker            x = x.to(device='cpu', dtype=torch.float32)
1279*da0073e9SAndroid Build Coastguard Worker            x = x.to(device='mps', dtype=torch.float16)
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
1282*da0073e9SAndroid Build Coastguard Worker        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
1283*da0073e9SAndroid Build Coastguard Worker        step(a)
1284*da0073e9SAndroid Build Coastguard Worker        torch.mps.empty_cache()
1285*da0073e9SAndroid Build Coastguard Worker        driver_before = torch.mps.driver_allocated_memory()
1286*da0073e9SAndroid Build Coastguard Worker        step(a)
1287*da0073e9SAndroid Build Coastguard Worker        torch.mps.empty_cache()
1288*da0073e9SAndroid Build Coastguard Worker        driver_after = torch.mps.driver_allocated_memory()
1289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(driver_before, driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
1290*da0073e9SAndroid Build Coastguard Worker
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Workerclass TestPixelShuffle(TestCaseMPS):
1293*da0073e9SAndroid Build Coastguard Worker    def test_pixel_shuffle_unshuffle(self):
1294*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
1295*da0073e9SAndroid Build Coastguard Worker                                                 upscale_factor=None, is_contiguous=True):
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker            def generate_input():
1298*da0073e9SAndroid Build Coastguard Worker                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
1299*da0073e9SAndroid Build Coastguard Worker                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
1300*da0073e9SAndroid Build Coastguard Worker                height = random.randint(5, 10)
1301*da0073e9SAndroid Build Coastguard Worker                width = random.randint(5, 10)
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker                if num_input_dims == 1:
1304*da0073e9SAndroid Build Coastguard Worker                    input = torch.rand(channels, requires_grad=True, device='mps')
1305*da0073e9SAndroid Build Coastguard Worker                    assert is_contiguous
1306*da0073e9SAndroid Build Coastguard Worker                elif num_input_dims == 2:
1307*da0073e9SAndroid Build Coastguard Worker                    input = torch.rand(width, height, requires_grad=True, device='mps').T
1308*da0073e9SAndroid Build Coastguard Worker                    if is_contiguous:
1309*da0073e9SAndroid Build Coastguard Worker                        input = input.contiguous()
1310*da0073e9SAndroid Build Coastguard Worker                else:
1311*da0073e9SAndroid Build Coastguard Worker                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1312*da0073e9SAndroid Build Coastguard Worker                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
1313*da0073e9SAndroid Build Coastguard Worker                    input = input.transpose(-1, -2)
1314*da0073e9SAndroid Build Coastguard Worker                    if is_contiguous:
1315*da0073e9SAndroid Build Coastguard Worker                        input = input.contiguous()
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker                if not is_contiguous and len(input.reshape(-1)) > 0:
1318*da0073e9SAndroid Build Coastguard Worker                    assert not input.is_contiguous()
1319*da0073e9SAndroid Build Coastguard Worker
1320*da0073e9SAndroid Build Coastguard Worker                input = input.detach().clone()
1321*da0073e9SAndroid Build Coastguard Worker                input.requires_grad = True
1322*da0073e9SAndroid Build Coastguard Worker                return input
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker            # Function to imperatively ensure pixels are shuffled to the correct locations.
1325*da0073e9SAndroid Build Coastguard Worker            # Used to validate the batch operations in pixel_shuffle.
1326*da0073e9SAndroid Build Coastguard Worker            def _verify_pixel_shuffle(input, output, upscale_factor):
1327*da0073e9SAndroid Build Coastguard Worker                for c in range(output.size(-3)):
1328*da0073e9SAndroid Build Coastguard Worker                    for h in range(output.size(-2)):
1329*da0073e9SAndroid Build Coastguard Worker                        for w in range(output.size(-1)):
1330*da0073e9SAndroid Build Coastguard Worker                            height_idx = h // upscale_factor
1331*da0073e9SAndroid Build Coastguard Worker                            weight_idx = w // upscale_factor
1332*da0073e9SAndroid Build Coastguard Worker                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
1333*da0073e9SAndroid Build Coastguard Worker                                          (c * upscale_factor ** 2)
1334*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
1335*da0073e9SAndroid Build Coastguard Worker
1336*da0073e9SAndroid Build Coastguard Worker            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
1337*da0073e9SAndroid Build Coastguard Worker            input = generate_input()
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker            ps = nn.PixelShuffle(upscale_factor)
1340*da0073e9SAndroid Build Coastguard Worker            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
1343*da0073e9SAndroid Build Coastguard Worker                output = ps(input)
1344*da0073e9SAndroid Build Coastguard Worker                _verify_pixel_shuffle(input, output, upscale_factor)
1345*da0073e9SAndroid Build Coastguard Worker                output.backward(output.data)
1346*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input.data, input.grad.data)
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker                # Ensure unshuffle properly inverts shuffle.
1349*da0073e9SAndroid Build Coastguard Worker                unshuffle_output = pus(output)
1350*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input, unshuffle_output)
1351*da0073e9SAndroid Build Coastguard Worker            else:
1352*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: ps(input))
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
1355*da0073e9SAndroid Build Coastguard Worker                                                    downscale_factor=None):
1356*da0073e9SAndroid Build Coastguard Worker            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
1357*da0073e9SAndroid Build Coastguard Worker            channels = random.randint(1, 4)
1358*da0073e9SAndroid Build Coastguard Worker            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
1359*da0073e9SAndroid Build Coastguard Worker            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
1360*da0073e9SAndroid Build Coastguard Worker            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
1361*da0073e9SAndroid Build Coastguard Worker            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
1362*da0073e9SAndroid Build Coastguard Worker
1363*da0073e9SAndroid Build Coastguard Worker            if num_input_dims == 1:
1364*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(channels, requires_grad=True, device='mps')
1365*da0073e9SAndroid Build Coastguard Worker            elif num_input_dims == 2:
1366*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(height, width, requires_grad=True, device='mps')
1367*da0073e9SAndroid Build Coastguard Worker            else:
1368*da0073e9SAndroid Build Coastguard Worker                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1369*da0073e9SAndroid Build Coastguard Worker                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker            pus = nn.PixelUnshuffle(downscale_factor)
1372*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: pus(input))
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
1375*da0073e9SAndroid Build Coastguard Worker            # For 1D - 2D, this is an error case.
1376*da0073e9SAndroid Build Coastguard Worker            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
1377*da0073e9SAndroid Build Coastguard Worker            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
1378*da0073e9SAndroid Build Coastguard Worker            for is_contiguous in is_contiguous_check:
1379*da0073e9SAndroid Build Coastguard Worker                _test_pixel_shuffle_unshuffle_helper(
1380*da0073e9SAndroid Build Coastguard Worker                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
1381*da0073e9SAndroid Build Coastguard Worker                )
1382*da0073e9SAndroid Build Coastguard Worker                _test_pixel_shuffle_unshuffle_helper(
1383*da0073e9SAndroid Build Coastguard Worker                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
1384*da0073e9SAndroid Build Coastguard Worker                )
1385*da0073e9SAndroid Build Coastguard Worker                _test_pixel_shuffle_unshuffle_helper(
1386*da0073e9SAndroid Build Coastguard Worker                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
1387*da0073e9SAndroid Build Coastguard Worker                )
1388*da0073e9SAndroid Build Coastguard Worker                _test_pixel_shuffle_unshuffle_helper(
1389*da0073e9SAndroid Build Coastguard Worker                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
1390*da0073e9SAndroid Build Coastguard Worker                )
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker                # Error cases for pixel_unshuffle.
1393*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
1394*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
1395*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
1396*da0073e9SAndroid Build Coastguard Worker            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_1D():
1399*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_2D():
1402*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_3D():
1405*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_4D():
1408*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker        def test_pixel_shuffle_unshuffle_5D():
1411*da0073e9SAndroid Build Coastguard Worker            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_1D()
1414*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_2D()
1415*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_3D()
1416*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_4D()
1417*da0073e9SAndroid Build Coastguard Worker        test_pixel_shuffle_unshuffle_5D()
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Workerclass MPSReluTest(TestCaseMPS):
1420*da0073e9SAndroid Build Coastguard Worker    def _npRelu(self, np_features):
1421*da0073e9SAndroid Build Coastguard Worker        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker    def testNpRelu(self):
1424*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
1425*da0073e9SAndroid Build Coastguard Worker            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
1426*da0073e9SAndroid Build Coastguard Worker            self._npRelu(
1427*da0073e9SAndroid Build Coastguard Worker                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
1428*da0073e9SAndroid Build Coastguard Worker                                                         0.9]])))
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker    def _testRelu(self, np_features, device):
1431*da0073e9SAndroid Build Coastguard Worker        np_relu = self._npRelu(np_features)
1432*da0073e9SAndroid Build Coastguard Worker        # Convert the numpy array to a PyTorch Tensor,
1433*da0073e9SAndroid Build Coastguard Worker        # and move the Tensor to the CPU/GPU based on the "device" parameter
1434*da0073e9SAndroid Build Coastguard Worker        py_tensor = torch.from_numpy(np_features).to(device)
1435*da0073e9SAndroid Build Coastguard Worker        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
1436*da0073e9SAndroid Build Coastguard Worker        py_relu_cpu = py_relu.to("cpu")
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_relu, py_relu_cpu)
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker    def _testReluInPlace(self, np_features, device):
1441*da0073e9SAndroid Build Coastguard Worker        np_relu = self._npRelu(np_features)
1442*da0073e9SAndroid Build Coastguard Worker        # Convert the numpy array to a PyTorch Tensor,
1443*da0073e9SAndroid Build Coastguard Worker        # and move the Tensor to the CPU/GPU based on the "device" parameter
1444*da0073e9SAndroid Build Coastguard Worker        py_tensor = torch.from_numpy(np_features).to(device)
1445*da0073e9SAndroid Build Coastguard Worker        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
1446*da0073e9SAndroid Build Coastguard Worker        py_relu_cpu = py_relu.to("cpu")
1447*da0073e9SAndroid Build Coastguard Worker
1448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_relu, py_relu_cpu)
1449*da0073e9SAndroid Build Coastguard Worker        # Inplace Relu modifies the initial input and it should match the output of Relu
1450*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np_relu, py_tensor.to("cpu"))
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker    def testNumbersCPU(self):
1453*da0073e9SAndroid Build Coastguard Worker        for t in [np.int32]:
1454*da0073e9SAndroid Build Coastguard Worker            # Force execution on CPU even if a GPU kernel is available for the type.
1455*da0073e9SAndroid Build Coastguard Worker            self._testRelu(
1456*da0073e9SAndroid Build Coastguard Worker                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1457*da0073e9SAndroid Build Coastguard Worker                device="cpu")
1458*da0073e9SAndroid Build Coastguard Worker            self._testReluInPlace(
1459*da0073e9SAndroid Build Coastguard Worker                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1460*da0073e9SAndroid Build Coastguard Worker                device="cpu")
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker    def testNumbersGPU(self):
1463*da0073e9SAndroid Build Coastguard Worker        for t in [np.float16, np.float32]:
1464*da0073e9SAndroid Build Coastguard Worker            self._testRelu(
1465*da0073e9SAndroid Build Coastguard Worker                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1466*da0073e9SAndroid Build Coastguard Worker                device="mps")
1467*da0073e9SAndroid Build Coastguard Worker            self._testReluInPlace(
1468*da0073e9SAndroid Build Coastguard Worker                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1469*da0073e9SAndroid Build Coastguard Worker                device="mps")
1470*da0073e9SAndroid Build Coastguard Worker            self._testRelu(np.array([]).astype(t), device="mps")
1471*da0073e9SAndroid Build Coastguard Worker            self._testReluInPlace(np.array([]).astype(t), device="mps")
1472*da0073e9SAndroid Build Coastguard Worker
1473*da0073e9SAndroid Build Coastguard Workerclass MatmulTest(TestCaseMPS):
1474*da0073e9SAndroid Build Coastguard Worker    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
1475*da0073e9SAndroid Build Coastguard Worker        if expand_tensor_1_shape:
1476*da0073e9SAndroid Build Coastguard Worker            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
1477*da0073e9SAndroid Build Coastguard Worker        else:
1478*da0073e9SAndroid Build Coastguard Worker            tensor1_mps = torch.randn(shape_tensor_1, device="mps")
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        if expand_tensor_2_shape:
1481*da0073e9SAndroid Build Coastguard Worker            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
1482*da0073e9SAndroid Build Coastguard Worker        else:
1483*da0073e9SAndroid Build Coastguard Worker            tensor2_mps = torch.randn(shape_tensor_2, device="mps")
1484*da0073e9SAndroid Build Coastguard Worker
1485*da0073e9SAndroid Build Coastguard Worker        tensor1_cpu = tensor1_mps.to("cpu")
1486*da0073e9SAndroid Build Coastguard Worker        tensor2_cpu = tensor2_mps.to("cpu")
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
1489*da0073e9SAndroid Build Coastguard Worker        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))
1492*da0073e9SAndroid Build Coastguard Worker
1493*da0073e9SAndroid Build Coastguard Worker    def test_vector_x_vector(self):
1494*da0073e9SAndroid Build Coastguard Worker        # uses `dot`
1495*da0073e9SAndroid Build Coastguard Worker        self._helper(3, 3)
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker    def test_matrix_x_vector(self):
1498*da0073e9SAndroid Build Coastguard Worker        # uses `addmv`
1499*da0073e9SAndroid Build Coastguard Worker        self._helper((3, 4), 4)
1500*da0073e9SAndroid Build Coastguard Worker
1501*da0073e9SAndroid Build Coastguard Worker    def test_batched_matrix_x_broadcasted_vector(self):
1502*da0073e9SAndroid Build Coastguard Worker        self._helper((10, 3, 4), 4)
1503*da0073e9SAndroid Build Coastguard Worker
1504*da0073e9SAndroid Build Coastguard Worker    def test_batched_matrix_x_batched_matrix(self):
1505*da0073e9SAndroid Build Coastguard Worker        # uses `bmm.out`
1506*da0073e9SAndroid Build Coastguard Worker        self._helper((10, 3, 4), (10, 4, 5))
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker    def test_batched_matrix_x_broadcasted_matrix(self):
1509*da0073e9SAndroid Build Coastguard Worker        self._helper((10, 3, 4), (4, 5))
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker
class MPSLeakyReluTest(TestCaseMPS):
    """Compare LeakyReLU forward/backward on MPS against CPU and a NumPy reference."""

    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        """NumPy reference implementation of leaky_relu (elementwise max(x, slope*x))."""
        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        # Sanity-check the NumPy reference itself against hand-computed values.
        torch.testing.assert_close(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]]),
                negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        """Run LeakyReLU forward + backward on CPU and MPS and compare results.

        When `contiguous` is False (and the shape has >= 2 non-empty dims),
        the inputs are transposed first so the non-contiguous path is exercised.
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            # unittest assertion instead of bare `assert`: survives `python -O`
            # and reports a proper test failure.
            self.assertFalse(mps_x.is_contiguous())

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass

        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        # Check that the grad is well-populated (was a bare `assert`, which is
        # stripped under `python -O`).
        self.assertIsNotNone(cpu_x.grad)
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        # Sweep dtypes x shapes (incl. scalar/empty shapes) x contiguity.
        for t in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    self._testLeakyRelu(shape,
                                        dtype=t,
                                        negative_slope=0.2,
                                        contiguous=contiguous)
1563*da0073e9SAndroid Build Coastguard Worker
class TestAvgPool(TestCaseMPS):
    """avg_pool tests, validated against sum-pooling references built on unfold."""

    def _sum_pool2d(self, x, kernel_size):
        """Reference 2-D sum pooling: one summed window per output element."""
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        """Reference 3-D sum pooling built from 2-D slabs."""
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        h = kernel_size[0]
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        """Reference 2-D average pooling (sum pooling divided by window size)."""
        # Fix: the original called `reduce(operator.mul, kernel_size)` with
        # neither `reduce` nor `operator` imported (suppressed via
        # `# noqa: F821`), so invoking this helper raised NameError.
        # math.prod is the equivalent stdlib call and `math` is imported.
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        """Reference 3-D average pooling (sum pooling divided by window size)."""
        # math.prod replaces the previously undefined reduce/operator pair.
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        # A zero divisor_override must be rejected up front.
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        # Every kernel size over a 3x3 input, with several divisor overrides,
        # must match the sum-pool reference divided by the same divisor.
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977: ceil_mode + padding must not produce NaNs.
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker
1614*da0073e9SAndroid Build Coastguard Workerclass TestMPS(TestCaseMPS):
1615*da0073e9SAndroid Build Coastguard Worker    def test_exp(self, device="mps", dtype=torch.float):
1616*da0073e9SAndroid Build Coastguard Worker        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
1617*da0073e9SAndroid Build Coastguard Worker            b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
1618*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(v, dtype=dtype, device="mps") * b
1619*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch.exp, np.exp, a)
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker    def test_conv_raises_error(self, device='mps', dtype=torch.float):
1622*da0073e9SAndroid Build Coastguard Worker        conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        x = torch.ones([1, 1, 3])
1625*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(NotImplementedError):
1626*da0073e9SAndroid Build Coastguard Worker            y = conv(x.to("mps"))
1627*da0073e9SAndroid Build Coastguard Worker
1628*da0073e9SAndroid Build Coastguard Worker    def test_triu_inf(self, device="mps", dtype=torch.float):
1629*da0073e9SAndroid Build Coastguard Worker        for diag in [-1, 0, 1]:
1630*da0073e9SAndroid Build Coastguard Worker            mask = torch.full((3, 6, 6), float("-inf"))
1631*da0073e9SAndroid Build Coastguard Worker            mask_mps = mask.clone().detach().to('mps')
1632*da0073e9SAndroid Build Coastguard Worker            cpu_ref = torch.triu(mask, diagonal=diag)
1633*da0073e9SAndroid Build Coastguard Worker            mps_out = torch.triu(mask_mps, diagonal=diag)
1634*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_ref, mps_out)
1635*da0073e9SAndroid Build Coastguard Worker
1636*da0073e9SAndroid Build Coastguard Worker    def test_exp1(self, device="mps", dtype=torch.float):
1637*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
1638*da0073e9SAndroid Build Coastguard Worker        output = torch.exp(input)
1639*da0073e9SAndroid Build Coastguard Worker        output_cpu = torch.exp(input.cpu())
1640*da0073e9SAndroid Build Coastguard Worker        # If exponentWithTensor: MPS call is used on M1 running 14.5 test will fail with
1641*da0073e9SAndroid Build Coastguard Worker        # Mismatched elements: 3 / 4 (75.0%)
1642*da0073e9SAndroid Build Coastguard Worker        # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
1643*da0073e9SAndroid Build Coastguard Worker        # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
1644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker    def test_exp_strided_output(self):
1647*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((256, 10), device='mps')
1648*da0073e9SAndroid Build Coastguard Worker        x_cpu = x.to("cpu")
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        x = x.permute(1, 0)
1651*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_cpu.permute(1, 0)
1652*da0073e9SAndroid Build Coastguard Worker
1653*da0073e9SAndroid Build Coastguard Worker        res = x.exp()
1654*da0073e9SAndroid Build Coastguard Worker        res_cpu = x_cpu.exp()
1655*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Worker    def _testLeakyRelu(self, np_features, negative_slope, device):
1658*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.from_numpy(np_features).requires_grad_()
1659*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
1660*da0073e9SAndroid Build Coastguard Worker        relu_op = torch.nn.LeakyReLU(negative_slope)
1661*da0073e9SAndroid Build Coastguard Worker
1662*da0073e9SAndroid Build Coastguard Worker        cpu_leaky_relu = relu_op(cpu_x)
1663*da0073e9SAndroid Build Coastguard Worker        mps_leaky_relu = relu_op(mps_x)
1664*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
1665*da0073e9SAndroid Build Coastguard Worker
1666*da0073e9SAndroid Build Coastguard Worker        # test backward pass
1667*da0073e9SAndroid Build Coastguard Worker        cpu_grad = torch.ones_like(cpu_leaky_relu)
1668*da0073e9SAndroid Build Coastguard Worker        mps_grad = cpu_grad.to('mps')
1669*da0073e9SAndroid Build Coastguard Worker        cpu_leaky_relu.backward(gradient=cpu_grad)
1670*da0073e9SAndroid Build Coastguard Worker        mps_leaky_relu.backward(gradient=mps_grad)
1671*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
1672*da0073e9SAndroid Build Coastguard Worker
1673*da0073e9SAndroid Build Coastguard Worker    def testNumbersGPU(self):
1674*da0073e9SAndroid Build Coastguard Worker        for t in [np.float32]:
1675*da0073e9SAndroid Build Coastguard Worker            self._testLeakyRelu(
1676*da0073e9SAndroid Build Coastguard Worker                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1677*da0073e9SAndroid Build Coastguard Worker                negative_slope=0.1,
1678*da0073e9SAndroid Build Coastguard Worker                device="mps")
1679*da0073e9SAndroid Build Coastguard Worker
1680*da0073e9SAndroid Build Coastguard Worker    def test_fill(self):
1681*da0073e9SAndroid Build Coastguard Worker
1682*da0073e9SAndroid Build Coastguard Worker        def helper(val, shape, dtype):
1683*da0073e9SAndroid Build Coastguard Worker            tensor = torch.zeros(shape, device='mps', dtype=dtype)
1684*da0073e9SAndroid Build Coastguard Worker            tensor_mps = tensor.fill_(val)
1685*da0073e9SAndroid Build Coastguard Worker
1686*da0073e9SAndroid Build Coastguard Worker            tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
1687*da0073e9SAndroid Build Coastguard Worker            tensor_cpu = tensor_0.fill_(val)
1688*da0073e9SAndroid Build Coastguard Worker
1689*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor_mps, tensor_cpu)
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker        helper(0, [1024], torch.float32)
1692*da0073e9SAndroid Build Coastguard Worker        helper(0.2, [2, 3], torch.float32)
1693*da0073e9SAndroid Build Coastguard Worker        helper(0.2 + 0.5j, [2, 3], torch.complex64)
1694*da0073e9SAndroid Build Coastguard Worker
    def test_fill_storage_offset(self):
        # fill_ on views that start at a non-zero storage offset must write
        # through to the base tensor on MPS exactly as it does on CPU.
        shape = [2, 10]
        val = 0.2
        tensor = torch.ones(shape, device="mps")
        # `tensor[:][1]` is a row view at a non-zero storage offset.
        tensor_mps = tensor[:][1].fill_(val)
        tensor_0 = torch.ones(shape, device="cpu")
        tensor_cpu = tensor_0[:][1].fill_(val)

        self.assertEqual(tensor_mps, tensor_cpu)
        # The write must also be visible through the base tensors.
        self.assertEqual(tensor, tensor_0)

        shape = [1, 10]
        val = 0.0
        tensor = torch.ones(shape, device="mps")
        val_tensor_mps = torch.tensor(val, device="mps")
        # Fill a single-column view with a 0-d tensor value.
        tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
        # Regression test for https://github.com/pytorch/pytorch/issues/114692
        tensor[:, 5].fill_(val_tensor_mps)
        tensor_0 = torch.ones(shape, device="cpu")
        val_tensor_cpu = torch.tensor(val, device="cpu")
        tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
        tensor_0[:, 5].fill_(val_tensor_cpu)

        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
        self.assertEqual(tensor.to(device="cpu"), tensor_0)
1720*da0073e9SAndroid Build Coastguard Worker
1721*da0073e9SAndroid Build Coastguard Worker    def test_cdist_large(self, device="mps"):
1722*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1723*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(100, 10, device=device)
1724*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(100, 10, device=device)
1725*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1726*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1727*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker    def test_cdist_large_batch(self, device="mps"):
1730*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1731*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 3, 100, 10, device=device)
1732*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 100, 10, device=device)
1733*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1734*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1735*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1736*da0073e9SAndroid Build Coastguard Worker
1737*da0073e9SAndroid Build Coastguard Worker    def test_cdist_non_contiguous(self, device="mps"):
1738*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1739*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 7, device=device).mT
1740*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(5, 3, device=device).mT
1741*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1742*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1743*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
1744*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
1745*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1746*da0073e9SAndroid Build Coastguard Worker
1747*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(7, 5, device=device)
1748*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(5, 3, device=device).t()
1749*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1750*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1751*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_contiguous())
1752*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
1753*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 7, device=device).t()
1756*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(3, 5, device=device)
1757*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1758*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1759*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
1760*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_contiguous())
1761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1762*da0073e9SAndroid Build Coastguard Worker
1763*da0073e9SAndroid Build Coastguard Worker    def test_cdist_non_contiguous_batch(self, device="mps"):
1764*da0073e9SAndroid Build Coastguard Worker        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1765*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
1766*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
1767*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1768*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1769*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
1770*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
1771*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1772*da0073e9SAndroid Build Coastguard Worker
1773*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(7, 2, 7, 5, device=device)
1774*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(7, 2, 5, 3, device=device).mT
1775*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1776*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1777*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.is_contiguous())
1778*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(y.is_contiguous())
1779*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(4, 5, 7, device=device).mT
1782*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(4, 3, 5, device=device)
1783*da0073e9SAndroid Build Coastguard Worker            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1784*da0073e9SAndroid Build Coastguard Worker            expected = self._brute_cdist(x, y, p=2)
1785*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.is_contiguous())
1786*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.is_contiguous())
1787*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expected, actual)
1788*da0073e9SAndroid Build Coastguard Worker
1789*da0073e9SAndroid Build Coastguard Worker    def test_cdist_euclidean_large(self, device="mps"):
1790*da0073e9SAndroid Build Coastguard Worker        def _test_euclidean_large_cdist(sizex, sizey=None):
1791*da0073e9SAndroid Build Coastguard Worker            if sizey is None:
1792*da0073e9SAndroid Build Coastguard Worker                sizey = sizex
1793*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(sizex, device=device, dtype=torch.float)
1794*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(sizey, device=device, dtype=torch.float)
1795*da0073e9SAndroid Build Coastguard Worker            eps = 1e-6
1796*da0073e9SAndroid Build Coastguard Worker            # to avoid extremum
1797*da0073e9SAndroid Build Coastguard Worker            x = x - (((x - y) < eps).float() * 2 * eps)
1798*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = True
1799*da0073e9SAndroid Build Coastguard Worker            y.requires_grad = True
1800*da0073e9SAndroid Build Coastguard Worker            dist = torch.cdist(x, y, p=2)
1801*da0073e9SAndroid Build Coastguard Worker            # Do a backward pass to check that it is valid for large
1802*da0073e9SAndroid Build Coastguard Worker            # matrices
1803*da0073e9SAndroid Build Coastguard Worker            loss = dist.sum()
1804*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1805*da0073e9SAndroid Build Coastguard Worker
1806*da0073e9SAndroid Build Coastguard Worker        _test_euclidean_large_cdist((2000, 5))
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker    def test_cdist_same_inputs(self, device="mps"):
1809*da0073e9SAndroid Build Coastguard Worker        # Test to detect issues in cdist gradient calculation
1810*da0073e9SAndroid Build Coastguard Worker        # When the distances are 0
1811*da0073e9SAndroid Build Coastguard Worker        sizex = (1, 27, 32)
1812*da0073e9SAndroid Build Coastguard Worker        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
1813*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(sizex, device=device, dtype=torch.float)
1814*da0073e9SAndroid Build Coastguard Worker            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
1815*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
1816*da0073e9SAndroid Build Coastguard Worker            eps = 1e-6
1817*da0073e9SAndroid Build Coastguard Worker            x.requires_grad = True
1818*da0073e9SAndroid Build Coastguard Worker            d = torch.cdist(x, y)
1819*da0073e9SAndroid Build Coastguard Worker            d.backward(dist_grad)
1820*da0073e9SAndroid Build Coastguard Worker            # Check that the backward passs does not contain invalid
1821*da0073e9SAndroid Build Coastguard Worker            # values such as nan or inf
1822*da0073e9SAndroid Build Coastguard Worker            assert torch.isfinite(x.grad).all()
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Worker    def _brute_cdist(self, x, y, p=2):
1826*da0073e9SAndroid Build Coastguard Worker        r1 = x.shape[-2]
1827*da0073e9SAndroid Build Coastguard Worker        r2 = y.shape[-2]
1828*da0073e9SAndroid Build Coastguard Worker        if r1 == 0 or r2 == 0:
1829*da0073e9SAndroid Build Coastguard Worker            return torch.empty(r1, r2, device=x.device)
1830*da0073e9SAndroid Build Coastguard Worker        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
1831*da0073e9SAndroid Build Coastguard Worker
1832*da0073e9SAndroid Build Coastguard Worker    def test_cdist_norm(self, device="mps"):
1833*da0073e9SAndroid Build Coastguard Worker        for r1 in [3, 4]:
1834*da0073e9SAndroid Build Coastguard Worker            for m in [2, 3]:
1835*da0073e9SAndroid Build Coastguard Worker                for r2 in [4, 6]:
1836*da0073e9SAndroid Build Coastguard Worker                    for p in [0, 1, 1.5, 2.5, float('inf')]:
1837*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(r1, m, device=device)
1838*da0073e9SAndroid Build Coastguard Worker                        y = torch.randn(r2, m, device=device)
1839*da0073e9SAndroid Build Coastguard Worker                        if p == 2:
1840*da0073e9SAndroid Build Coastguard Worker                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1841*da0073e9SAndroid Build Coastguard Worker                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
1842*da0073e9SAndroid Build Coastguard Worker                                expected = self._brute_cdist(x, y, p=2)
1843*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
1844*da0073e9SAndroid Build Coastguard Worker                        else:
1845*da0073e9SAndroid Build Coastguard Worker                            actual = torch.cdist(x, y, p=p)
1846*da0073e9SAndroid Build Coastguard Worker                            expected = self._brute_cdist(x, y, p=p)
1847*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(expected, actual)
1848*da0073e9SAndroid Build Coastguard Worker
1849*da0073e9SAndroid Build Coastguard Worker    def test_cdist_norm_batch(self, device="mps"):
1850*da0073e9SAndroid Build Coastguard Worker        for r1 in [3, 4]:
1851*da0073e9SAndroid Build Coastguard Worker            for m in [2, 3]:
1852*da0073e9SAndroid Build Coastguard Worker                for r2 in [4, 6]:
1853*da0073e9SAndroid Build Coastguard Worker                    for p in [0, 3, 1.5, 2.5, float('inf')]:
1854*da0073e9SAndroid Build Coastguard Worker                        x = torch.randn(2, 3, 6, r1, m, device=device)
1855*da0073e9SAndroid Build Coastguard Worker                        y = torch.randn(2, 3, 6, r2, m, device=device)
1856*da0073e9SAndroid Build Coastguard Worker                        if p == 2:
1857*da0073e9SAndroid Build Coastguard Worker                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1858*da0073e9SAndroid Build Coastguard Worker                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
1859*da0073e9SAndroid Build Coastguard Worker                                expected = self._brute_cdist(x, y, p=2)
1860*da0073e9SAndroid Build Coastguard Worker                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
1861*da0073e9SAndroid Build Coastguard Worker                        else:
1862*da0073e9SAndroid Build Coastguard Worker                            actual = torch.cdist(x, y, p=p)
1863*da0073e9SAndroid Build Coastguard Worker                            expected = self._brute_cdist(x, y, p=p)
1864*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(expected, actual)
1865*da0073e9SAndroid Build Coastguard Worker
1866*da0073e9SAndroid Build Coastguard Worker    def test_mm(self):
1867*da0073e9SAndroid Build Coastguard Worker        B = torch.ones(5, 6).to("mps")
1868*da0073e9SAndroid Build Coastguard Worker        C = torch.ones(6, 5).to("mps")
1869*da0073e9SAndroid Build Coastguard Worker        D = torch.mm(B, C).cpu()
1870*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(D, torch.full((5, 5), 6.0))
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker    def test_linalg_cross(self):
1873*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
1874*da0073e9SAndroid Build Coastguard Worker            device = "mps"
1875*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.int32 or dtype is torch.int64:
1876*da0073e9SAndroid Build Coastguard Worker                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1877*da0073e9SAndroid Build Coastguard Worker                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1878*da0073e9SAndroid Build Coastguard Worker            else:
1879*da0073e9SAndroid Build Coastguard Worker                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
1880*da0073e9SAndroid Build Coastguard Worker                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
1881*da0073e9SAndroid Build Coastguard Worker            x_cpu = x.to("cpu")
1882*da0073e9SAndroid Build Coastguard Worker            y_cpu = y.to("cpu")
1883*da0073e9SAndroid Build Coastguard Worker            res1 = torch.linalg.cross(x, y, dim=1)
1884*da0073e9SAndroid Build Coastguard Worker            res2 = torch.tensor((), dtype=dtype, device=device)
1885*da0073e9SAndroid Build Coastguard Worker            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1886*da0073e9SAndroid Build Coastguard Worker            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1887*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cross(x, y, dim=1, out=res2)
1888*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1889*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
1890*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res1_cpu)
1891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res2, res2_cpu)
1892*da0073e9SAndroid Build Coastguard Worker
1893*da0073e9SAndroid Build Coastguard Worker            # test for broadcastable inputs
1894*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.int32 or dtype is torch.int64:
1895*da0073e9SAndroid Build Coastguard Worker                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
1896*da0073e9SAndroid Build Coastguard Worker                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
1897*da0073e9SAndroid Build Coastguard Worker            else:
1898*da0073e9SAndroid Build Coastguard Worker                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
1899*da0073e9SAndroid Build Coastguard Worker                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
1900*da0073e9SAndroid Build Coastguard Worker            x_cpu = x.to("cpu")
1901*da0073e9SAndroid Build Coastguard Worker            y_cpu = y.to("cpu")
1902*da0073e9SAndroid Build Coastguard Worker            res1 = torch.linalg.cross(x, y, dim=1)
1903*da0073e9SAndroid Build Coastguard Worker            res2 = torch.tensor((), dtype=dtype, device=device)
1904*da0073e9SAndroid Build Coastguard Worker            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1905*da0073e9SAndroid Build Coastguard Worker            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1906*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cross(x, y, dim=1, out=res2)
1907*da0073e9SAndroid Build Coastguard Worker            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1908*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res2)
1909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1, res1_cpu)
1910*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res2, res2_cpu)
1911*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker    def test_cross(self):
1914*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, 3, device="mps")
1915*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4, 3, device="mps")
1916*da0073e9SAndroid Build Coastguard Worker        a_cpu = a.to("cpu")
1917*da0073e9SAndroid Build Coastguard Worker        b_cpu = b.to("cpu")
1918*da0073e9SAndroid Build Coastguard Worker        res = torch.cross(a, b, dim=1)
1919*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
1920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
1921*da0073e9SAndroid Build Coastguard Worker
1922*da0073e9SAndroid Build Coastguard Worker    def test_addmm(self):
1923*da0073e9SAndroid Build Coastguard Worker        A = torch.ones(5, 5).to("mps")
1924*da0073e9SAndroid Build Coastguard Worker        B = torch.ones(5, 6).to("mps")
1925*da0073e9SAndroid Build Coastguard Worker        C = torch.ones(6, 5).to("mps")
1926*da0073e9SAndroid Build Coastguard Worker        D = torch.addmm(A, B, C).to("cpu")
1927*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(D, torch.full((5, 5), 7.0))
1928*da0073e9SAndroid Build Coastguard Worker
1929*da0073e9SAndroid Build Coastguard Worker    def test_bmm(self):
1930*da0073e9SAndroid Build Coastguard Worker        batch1_cpu = torch.randn(10, 3, 4)
1931*da0073e9SAndroid Build Coastguard Worker        batch2_cpu = torch.randn(10, 4, 5)
1932*da0073e9SAndroid Build Coastguard Worker
1933*da0073e9SAndroid Build Coastguard Worker        batch1_mps = batch1_cpu.detach().clone().to("mps")
1934*da0073e9SAndroid Build Coastguard Worker        batch2_mps = batch2_cpu.detach().clone().to("mps")
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker        output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
1937*da0073e9SAndroid Build Coastguard Worker        output_mps = torch.bmm(batch1_mps, batch2_mps)
1938*da0073e9SAndroid Build Coastguard Worker
1939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps)
1940*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
1941*da0073e9SAndroid Build Coastguard Worker
1942*da0073e9SAndroid Build Coastguard Worker    @xfailIf(product_version < 15.0)
1943*da0073e9SAndroid Build Coastguard Worker    @parametrize("dtype", [torch.float16, torch.bfloat16])
1944*da0073e9SAndroid Build Coastguard Worker    def test_large_bmm(self, dtype):
1945*da0073e9SAndroid Build Coastguard Worker        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
1946*da0073e9SAndroid Build Coastguard Worker        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
1947*da0073e9SAndroid Build Coastguard Worker        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
1948*da0073e9SAndroid Build Coastguard Worker        output_mps = torch.bmm(batch1, batch2)
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker        # Using the low precision comparison for FP16
1951*da0073e9SAndroid Build Coastguard Worker        tol = 1e-2 if dtype == torch.float16 else None
1952*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
1953*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker    def test_addr(self):
1957*da0073e9SAndroid Build Coastguard Worker        A = torch.ones(5, 10).to("mps")
1958*da0073e9SAndroid Build Coastguard Worker        B = torch.ones(5).to("mps")
1959*da0073e9SAndroid Build Coastguard Worker        C = torch.ones(10).to("mps")
1960*da0073e9SAndroid Build Coastguard Worker        D = torch.addr(A, B, C).to("cpu")
1961*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(D, torch.full((5, 10), 2.0))
1962*da0073e9SAndroid Build Coastguard Worker
1963*da0073e9SAndroid Build Coastguard Worker    def test_trace(self):
1964*da0073e9SAndroid Build Coastguard Worker        M_cpu = torch.randn(3, 3)
1965*da0073e9SAndroid Build Coastguard Worker        M_mps = M_cpu.detach().clone().to("mps")
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker        output_cpu = torch.trace(M_cpu)
1968*da0073e9SAndroid Build Coastguard Worker        output_mps = torch.trace(M_mps)
1969*da0073e9SAndroid Build Coastguard Worker
1970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps)
1971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker    def test_addbmm(self):
1974*da0073e9SAndroid Build Coastguard Worker        M_cpu = torch.randn(3, 5)
1975*da0073e9SAndroid Build Coastguard Worker        batch1_cpu = torch.randn(10, 3, 4)
1976*da0073e9SAndroid Build Coastguard Worker        batch2_cpu = torch.randn(10, 4, 5)
1977*da0073e9SAndroid Build Coastguard Worker
1978*da0073e9SAndroid Build Coastguard Worker        M_mps = M_cpu.detach().clone().to("mps")
1979*da0073e9SAndroid Build Coastguard Worker        batch1_mps = batch1_cpu.detach().clone().to("mps")
1980*da0073e9SAndroid Build Coastguard Worker        batch2_mps = batch2_cpu.detach().clone().to("mps")
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker        output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
1983*da0073e9SAndroid Build Coastguard Worker        output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps)
1986*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker    def test_baddbmm(self):
1989*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, batch1_shape, batch2_shape):
1990*da0073e9SAndroid Build Coastguard Worker            M_cpu = torch.randn(input_shape)
1991*da0073e9SAndroid Build Coastguard Worker            batch1_cpu = torch.randn(batch1_shape)
1992*da0073e9SAndroid Build Coastguard Worker            batch2_cpu = torch.randn(batch2_shape)
1993*da0073e9SAndroid Build Coastguard Worker            alpha = 1.2
1994*da0073e9SAndroid Build Coastguard Worker            beta = 0.8
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker            M_mps = M_cpu.detach().clone().to("mps")
1997*da0073e9SAndroid Build Coastguard Worker            batch1_mps = batch1_cpu.detach().clone().to("mps")
1998*da0073e9SAndroid Build Coastguard Worker            batch2_mps = batch2_cpu.detach().clone().to("mps")
1999*da0073e9SAndroid Build Coastguard Worker
2000*da0073e9SAndroid Build Coastguard Worker            output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
2001*da0073e9SAndroid Build Coastguard Worker            output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
2002*da0073e9SAndroid Build Coastguard Worker
2003*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cpu, output_mps)
2004*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cpu.size(), output_mps.size())
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Worker        helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
2007*da0073e9SAndroid Build Coastguard Worker        helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
2008*da0073e9SAndroid Build Coastguard Worker        helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker    def test_local_scalar_dense_mps(self):
2011*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.randn(1)
2012*da0073e9SAndroid Build Coastguard Worker        y_mps = x_cpu.to("mps")
2013*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(x_cpu.item(), y_mps.item())
2014*da0073e9SAndroid Build Coastguard Worker
2015*da0073e9SAndroid Build Coastguard Worker    def test_linear_1d_weight(self):
2016*da0073e9SAndroid Build Coastguard Worker        device = 'cpu'
2017*da0073e9SAndroid Build Coastguard Worker        projected = torch.rand([8]).to(device)
2018*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([1, 2, 2, 8]).to(device)
2019*da0073e9SAndroid Build Coastguard Worker        x_mps = x.to('mps')
2020*da0073e9SAndroid Build Coastguard Worker        projected_mps = projected.to('mps')
2021*da0073e9SAndroid Build Coastguard Worker        linear = F.linear(x, projected)
2022*da0073e9SAndroid Build Coastguard Worker        linear_mps = F.linear(x_mps, projected_mps)
2023*da0073e9SAndroid Build Coastguard Worker
2024*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(linear, linear_mps)
2025*da0073e9SAndroid Build Coastguard Worker
2026*da0073e9SAndroid Build Coastguard Worker        projected = torch.rand([1, 8]).to(device)
2027*da0073e9SAndroid Build Coastguard Worker        x = torch.rand([1, 2, 2, 8]).to(device)
2028*da0073e9SAndroid Build Coastguard Worker        x_mps = x.to('mps')
2029*da0073e9SAndroid Build Coastguard Worker        projected_mps = projected.to('mps')
2030*da0073e9SAndroid Build Coastguard Worker        linear = F.linear(x, projected)
2031*da0073e9SAndroid Build Coastguard Worker        linear_mps = F.linear(x_mps, projected_mps)
2032*da0073e9SAndroid Build Coastguard Worker
2033*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(linear, linear_mps)
2034*da0073e9SAndroid Build Coastguard Worker
2035*da0073e9SAndroid Build Coastguard Worker    def test_linear_bias(self):
2036*da0073e9SAndroid Build Coastguard Worker        def helper(bias_shape):
2037*da0073e9SAndroid Build Coastguard Worker            device = "cpu"
2038*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 2, 2, 64, device=device)
2039*da0073e9SAndroid Build Coastguard Worker            linear = torch.nn.Linear(64, 4, device=device)
2040*da0073e9SAndroid Build Coastguard Worker            linear.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device=device))
2041*da0073e9SAndroid Build Coastguard Worker            y = linear(x)
2042*da0073e9SAndroid Build Coastguard Worker            device = "mps"
2043*da0073e9SAndroid Build Coastguard Worker            x_mps = x.to(device)
2044*da0073e9SAndroid Build Coastguard Worker            linear.to(device)
2045*da0073e9SAndroid Build Coastguard Worker            y_mps = linear(x_mps)
2046*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, y_mps)
2047*da0073e9SAndroid Build Coastguard Worker
2048*da0073e9SAndroid Build Coastguard Worker        helper(())
2049*da0073e9SAndroid Build Coastguard Worker        helper((2, 4))
2050*da0073e9SAndroid Build Coastguard Worker
2051*da0073e9SAndroid Build Coastguard Worker    def test_linear_errors(self):
2052*da0073e9SAndroid Build Coastguard Worker        # Mixed CPU<->MPS tensors
2053*da0073e9SAndroid Build Coastguard Worker        size = (3, 3)
2054*da0073e9SAndroid Build Coastguard Worker
2055*da0073e9SAndroid Build Coastguard Worker        # Unsupported dtypes
2056*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
2057*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.linear(torch.rand(size, device='mps'),
2058*da0073e9SAndroid Build Coastguard Worker                                       torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))
2059*da0073e9SAndroid Build Coastguard Worker
2060*da0073e9SAndroid Build Coastguard Worker        # Weigths on wrong device
2061*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
2062*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.linear(torch.rand(size, device='mps'),
2063*da0073e9SAndroid Build Coastguard Worker                                       torch.rand(size, device='cpu'))
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker        # Input on wrong device
2066*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
2067*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.linear(torch.rand(size, device='cpu'),
2068*da0073e9SAndroid Build Coastguard Worker                                       torch.rand(size, device='mps'))
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
2071*da0073e9SAndroid Build Coastguard Worker        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
2072*da0073e9SAndroid Build Coastguard Worker        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker        # Use the same weights and bias as the ones from the cpu
2075*da0073e9SAndroid Build Coastguard Worker        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker        if bias:
2078*da0073e9SAndroid Build Coastguard Worker            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")
2079*da0073e9SAndroid Build Coastguard Worker
2080*da0073e9SAndroid Build Coastguard Worker        linear_mps_input = torch.randn(shape).to('mps')
2081*da0073e9SAndroid Build Coastguard Worker        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')
2082*da0073e9SAndroid Build Coastguard Worker
2083*da0073e9SAndroid Build Coastguard Worker        if backward_pass:
2084*da0073e9SAndroid Build Coastguard Worker            linear_mps_input = linear_mps_input.requires_grad_()
2085*da0073e9SAndroid Build Coastguard Worker            linear_cpu_input = linear_cpu_input.requires_grad_()
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker        linear_cpu_output = cpu_linear(linear_cpu_input)
2088*da0073e9SAndroid Build Coastguard Worker        linear_mps_output = mps_linear(linear_mps_input)
2089*da0073e9SAndroid Build Coastguard Worker
2090*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
2091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())
2092*da0073e9SAndroid Build Coastguard Worker
2093*da0073e9SAndroid Build Coastguard Worker        if backward_pass:
2094*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
2095*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.detach().to('mps').requires_grad_()
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker            linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
2098*da0073e9SAndroid Build Coastguard Worker            linear_mps_output.backward(gradient=grad, create_graph=True)
2099*da0073e9SAndroid Build Coastguard Worker
2100*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
2101*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
2104*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2105*da0073e9SAndroid Build Coastguard Worker            if bias:
2106*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
2107*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2108*da0073e9SAndroid Build Coastguard Worker
2109*da0073e9SAndroid Build Coastguard Worker            # test gradgrad
2110*da0073e9SAndroid Build Coastguard Worker            x_grad_out = torch.rand_like(linear_cpu_input)
2111*da0073e9SAndroid Build Coastguard Worker            x_grad_out_mps = x_grad_out.to("mps")
2112*da0073e9SAndroid Build Coastguard Worker            w_grad_out = torch.rand_like(cpu_linear.weight)
2113*da0073e9SAndroid Build Coastguard Worker            w_grad_out_mps = w_grad_out.to("mps")
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Worker            linear_cpu_input.grad.detach().zero_()
2116*da0073e9SAndroid Build Coastguard Worker            linear_mps_input.grad.detach().zero_()
2117*da0073e9SAndroid Build Coastguard Worker            cpu_linear.weight.grad.detach().zero_()
2118*da0073e9SAndroid Build Coastguard Worker            mps_linear.weight.grad.detach().zero_()
2119*da0073e9SAndroid Build Coastguard Worker            if bias:
2120*da0073e9SAndroid Build Coastguard Worker                b_grad_out = torch.rand_like(cpu_linear.bias)
2121*da0073e9SAndroid Build Coastguard Worker                b_grad_out_mps = b_grad_out.to("mps")
2122*da0073e9SAndroid Build Coastguard Worker                cpu_linear.bias.grad.detach().zero_()
2123*da0073e9SAndroid Build Coastguard Worker                mps_linear.bias.grad.detach().zero_()
2124*da0073e9SAndroid Build Coastguard Worker
2125*da0073e9SAndroid Build Coastguard Worker            linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
2126*da0073e9SAndroid Build Coastguard Worker            linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
2127*da0073e9SAndroid Build Coastguard Worker            cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
2128*da0073e9SAndroid Build Coastguard Worker            mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
2129*da0073e9SAndroid Build Coastguard Worker            if bias:
2130*da0073e9SAndroid Build Coastguard Worker                cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
2131*da0073e9SAndroid Build Coastguard Worker                mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)
2132*da0073e9SAndroid Build Coastguard Worker
2133*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_grad.grad, grad.grad)
2134*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
2135*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
2136*da0073e9SAndroid Build Coastguard Worker            if bias:
2137*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker    def test_linear1D(self):
2140*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=False)
2141*da0073e9SAndroid Build Coastguard Worker
2142*da0073e9SAndroid Build Coastguard Worker    def test_linear1D_backward(self):
2143*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=True)
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker    def test_linear2D(self):
2146*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=False)
2147*da0073e9SAndroid Build Coastguard Worker
2148*da0073e9SAndroid Build Coastguard Worker    def test_linear2D_backward(self):
2149*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=True)
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker    def test_linear2D_no_bias(self):
2152*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=False)
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker    def test_linear2D_no_bias_backward(self):
2155*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=True)
2156*da0073e9SAndroid Build Coastguard Worker
2157*da0073e9SAndroid Build Coastguard Worker    def test_linear3D(self):
2158*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False)
2159*da0073e9SAndroid Build Coastguard Worker
2160*da0073e9SAndroid Build Coastguard Worker    def test_linear3D_backward(self):
2161*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True)
2162*da0073e9SAndroid Build Coastguard Worker
2163*da0073e9SAndroid Build Coastguard Worker    def test_linear3D_no_bias(self):
2164*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False)
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker    def test_linear3D_no_bias_backward(self):
2167*da0073e9SAndroid Build Coastguard Worker        self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True)
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker    def test_uniform(self):
2170*da0073e9SAndroid Build Coastguard Worker        low = torch.zeros(5, 5, requires_grad=True)
2171*da0073e9SAndroid Build Coastguard Worker        high = (torch.ones(5, 5) * 3).requires_grad_()
2172*da0073e9SAndroid Build Coastguard Worker        low_1d = torch.zeros(1, requires_grad=True)
2173*da0073e9SAndroid Build Coastguard Worker        high_1d = (torch.ones(1) * 3).requires_grad_()
2174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
2175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
2176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
2177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
2178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))
2179*da0073e9SAndroid Build Coastguard Worker
2180*da0073e9SAndroid Build Coastguard Worker        # Check log_prob computation when value outside range
2181*da0073e9SAndroid Build Coastguard Worker        uniform = Uniform(low_1d, high_1d, validate_args=False)
2182*da0073e9SAndroid Build Coastguard Worker        above_high = torch.tensor([4.0])
2183*da0073e9SAndroid Build Coastguard Worker        below_low = torch.tensor([-1.0])
2184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
2185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.log_prob(below_low).item(), -inf)
2186*da0073e9SAndroid Build Coastguard Worker
2187*da0073e9SAndroid Build Coastguard Worker        # check cdf computation when value outside range
2188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.cdf(below_low).item(), 0)
2189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(uniform.cdf(above_high).item(), 1)
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker        state = torch.get_rng_state()
2192*da0073e9SAndroid Build Coastguard Worker        rand = low.new(low.size()).uniform_()
2193*da0073e9SAndroid Build Coastguard Worker        torch.set_rng_state(state)
2194*da0073e9SAndroid Build Coastguard Worker        u = Uniform(low, high).rsample()
2195*da0073e9SAndroid Build Coastguard Worker        u.backward(torch.ones_like(u))
2196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(low.grad, 1 - rand)
2197*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(high.grad, rand)
2198*da0073e9SAndroid Build Coastguard Worker        low.grad.zero_()
2199*da0073e9SAndroid Build Coastguard Worker        high.grad.zero_()
2200*da0073e9SAndroid Build Coastguard Worker
2201*da0073e9SAndroid Build Coastguard Worker    def test_randperm(self, device="mps"):
2202*da0073e9SAndroid Build Coastguard Worker        rng_device = None
2203*da0073e9SAndroid Build Coastguard Worker        for n in (5, 100, 50000, 100000):
2204*da0073e9SAndroid Build Coastguard Worker            for dtype in (torch.long, torch.half, torch.float):
2205*da0073e9SAndroid Build Coastguard Worker                if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
2206*da0073e9SAndroid Build Coastguard Worker                    continue
2207*da0073e9SAndroid Build Coastguard Worker                if n > 256 and dtype == torch.bfloat16:
2208*da0073e9SAndroid Build Coastguard Worker                    continue
2209*da0073e9SAndroid Build Coastguard Worker                with torch.random.fork_rng(devices=rng_device):
2210*da0073e9SAndroid Build Coastguard Worker                    res1 = torch.randperm(n, dtype=dtype, device=device)
2211*da0073e9SAndroid Build Coastguard Worker                res2 = torch.empty(0, dtype=dtype, device=device)
2212*da0073e9SAndroid Build Coastguard Worker                torch.randperm(n, out=res2, dtype=dtype, device=device)
2213*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))
2214*da0073e9SAndroid Build Coastguard Worker
2215*da0073e9SAndroid Build Coastguard Worker        # Default type is long
2216*da0073e9SAndroid Build Coastguard Worker        for n in (100, 10000):
2217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)
2218*da0073e9SAndroid Build Coastguard Worker
2219*da0073e9SAndroid Build Coastguard Worker        # randperm of 0 elements is an empty tensor
2220*da0073e9SAndroid Build Coastguard Worker        res1 = torch.randperm(0)
2221*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor(5, dtype=dtype, device=device)
2222*da0073e9SAndroid Build Coastguard Worker        torch.randperm(0, out=res2)
2223*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1.numel(), 0)
2224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res2.numel(), 0)
2225*da0073e9SAndroid Build Coastguard Worker
2226*da0073e9SAndroid Build Coastguard Worker        # Test non-contiguous tensors
2227*da0073e9SAndroid Build Coastguard Worker        for n in (4, 5, 6, 10, 20):
2228*da0073e9SAndroid Build Coastguard Worker            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
2229*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(non_contiguous_tensor.is_contiguous())
2230*da0073e9SAndroid Build Coastguard Worker            with torch.random.fork_rng(devices=rng_device):
2231*da0073e9SAndroid Build Coastguard Worker                res = torch.randperm(n, dtype=torch.long, device=device)
2232*da0073e9SAndroid Build Coastguard Worker            torch.randperm(n, out=non_contiguous_tensor)
2233*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
2234*da0073e9SAndroid Build Coastguard Worker
2235*da0073e9SAndroid Build Coastguard Worker    # Test forward maxpool2d
2236*da0073e9SAndroid Build Coastguard Worker    def test_max_pool2d(self):
2237*da0073e9SAndroid Build Coastguard Worker        def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
2238*da0073e9SAndroid Build Coastguard Worker
2239*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
2240*da0073e9SAndroid Build Coastguard Worker            if (test_ties):
2241*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
2242*da0073e9SAndroid Build Coastguard Worker            else:
2243*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
2244*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2245*da0073e9SAndroid Build Coastguard Worker
2246*da0073e9SAndroid Build Coastguard Worker            pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
2247*da0073e9SAndroid Build Coastguard Worker                                      ceil_mode=ceil_mode, return_indices=return_indices)
2248*da0073e9SAndroid Build Coastguard Worker
2249*da0073e9SAndroid Build Coastguard Worker            if (return_indices is False):
2250*da0073e9SAndroid Build Coastguard Worker                y = pool(x)
2251*da0073e9SAndroid Build Coastguard Worker                ref_y = pool(cpu_x)
2252*da0073e9SAndroid Build Coastguard Worker
2253*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.ones_like(ref_y)
2254*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
2255*da0073e9SAndroid Build Coastguard Worker
2256*da0073e9SAndroid Build Coastguard Worker                y.backward(gradient=grad)
2257*da0073e9SAndroid Build Coastguard Worker                ref_y.backward(gradient=cpu_grad)
2258*da0073e9SAndroid Build Coastguard Worker
2259*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, ref_y)
2260*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad, cpu_x.grad)
2261*da0073e9SAndroid Build Coastguard Worker            else:
2262*da0073e9SAndroid Build Coastguard Worker                y, idx = pool(x)
2263*da0073e9SAndroid Build Coastguard Worker                ref_y, ref_idx = pool(cpu_x)
2264*da0073e9SAndroid Build Coastguard Worker
2265*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.ones_like(ref_y)
2266*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
2267*da0073e9SAndroid Build Coastguard Worker
2268*da0073e9SAndroid Build Coastguard Worker                y.backward(gradient=grad)
2269*da0073e9SAndroid Build Coastguard Worker                ref_y.backward(gradient=cpu_grad)
2270*da0073e9SAndroid Build Coastguard Worker
2271*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, ref_y)
2272*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(idx, ref_idx)
2273*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad, cpu_x.grad)
2274*da0073e9SAndroid Build Coastguard Worker
2275*da0073e9SAndroid Build Coastguard Worker        # Test with no batch dimension
2276*da0073e9SAndroid Build Coastguard Worker        helper((8, 4, 4), ks=2)
2277*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 4), ks=2)
2278*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 32, 32), ks=4)
2279*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 1, 4), ks=(1, 4))  # test for max_pool1d
2280*da0073e9SAndroid Build Coastguard Worker        # Test padding
2281*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 32, 32), ks=4, padding=1)
2282*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1))  # test for max_pool1d
2283*da0073e9SAndroid Build Coastguard Worker        # Test dilation
2284*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 32, 32), ks=4, dilation=2)
2285*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2))  # test for max_pool1d
2286*da0073e9SAndroid Build Coastguard Worker        # Test ceil mode
2287*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 32, 32), ks=4, ceil_mode=True)
2288*da0073e9SAndroid Build Coastguard Worker        helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True)  # test for max_pool1d
2289*da0073e9SAndroid Build Coastguard Worker
2290*da0073e9SAndroid Build Coastguard Worker        # Test return indices
2291*da0073e9SAndroid Build Coastguard Worker        for test_ties in [False, True]:
2292*da0073e9SAndroid Build Coastguard Worker            # Test with no batch dimension
2293*da0073e9SAndroid Build Coastguard Worker            helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
2294*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
2295*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
2296*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties)  # test for max_pool1d
2297*da0073e9SAndroid Build Coastguard Worker            # Test padding
2298*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
2299*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1),
2300*da0073e9SAndroid Build Coastguard Worker                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2301*da0073e9SAndroid Build Coastguard Worker            # Test dilation
2302*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
2303*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2),
2304*da0073e9SAndroid Build Coastguard Worker                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2305*da0073e9SAndroid Build Coastguard Worker            # Test ceil mode
2306*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
2307*da0073e9SAndroid Build Coastguard Worker            helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
2308*da0073e9SAndroid Build Coastguard Worker                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2309*da0073e9SAndroid Build Coastguard Worker
2310*da0073e9SAndroid Build Coastguard Worker    def test_adaptive_avg_pool2d_output_size_one(self):
2311*da0073e9SAndroid Build Coastguard Worker        def helper(size, memory_format):
2312*da0073e9SAndroid Build Coastguard Worker            x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
2313*da0073e9SAndroid Build Coastguard Worker            if memory_format == 'non_contiguous':
2314*da0073e9SAndroid Build Coastguard Worker                x = x[::2, ::2, ::2, ::2]
2315*da0073e9SAndroid Build Coastguard Worker            else:
2316*da0073e9SAndroid Build Coastguard Worker                x = x.to(memory_format=memory_format)
2317*da0073e9SAndroid Build Coastguard Worker
2318*da0073e9SAndroid Build Coastguard Worker            net = torch.nn.AdaptiveAvgPool2d((1, 1))
2319*da0073e9SAndroid Build Coastguard Worker            out = net(x)
2320*da0073e9SAndroid Build Coastguard Worker            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))
2321*da0073e9SAndroid Build Coastguard Worker
2322*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()    # make sure it doesn't crash
2323*da0073e9SAndroid Build Coastguard Worker
2324*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out, ref_out)
2325*da0073e9SAndroid Build Coastguard Worker            if memory_format == torch.channels_last:
2326*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
2327*da0073e9SAndroid Build Coastguard Worker                c = out.size(1)
2328*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.stride(), [c, 1, c, c])
2329*da0073e9SAndroid Build Coastguard Worker            else:
2330*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out.is_contiguous())
2331*da0073e9SAndroid Build Coastguard Worker                c = out.size(1)
2332*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.stride(), [c, 1, 1, 1])
2333*da0073e9SAndroid Build Coastguard Worker
2334*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 6, 6), torch.contiguous_format)
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker    def test_masked_scatter(self):
2337*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
2338*da0073e9SAndroid Build Coastguard Worker            x_mps = torch.randn(shape, device="mps")
2339*da0073e9SAndroid Build Coastguard Worker            x_cpu = x_mps.detach().clone().cpu()
2340*da0073e9SAndroid Build Coastguard Worker
2341*da0073e9SAndroid Build Coastguard Worker            mask_mps = torch.rand(shape, device="mps") < 0.6
2342*da0073e9SAndroid Build Coastguard Worker            mask_cpu = mask_mps.detach().clone().cpu()
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker            y_mps = torch.randn(shape, device="mps")
2345*da0073e9SAndroid Build Coastguard Worker            y_cpu = y_mps.detach().clone().cpu()
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker            y_mps.masked_scatter_(mask_mps, x_mps)
2348*da0073e9SAndroid Build Coastguard Worker            y_cpu.masked_scatter_(mask_cpu, x_cpu)
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_mps, y_cpu)
2351*da0073e9SAndroid Build Coastguard Worker        helper([2, 5])
2352*da0073e9SAndroid Build Coastguard Worker        helper([10, 10])
2353*da0073e9SAndroid Build Coastguard Worker        helper([5, 10, 3])
2354*da0073e9SAndroid Build Coastguard Worker        helper([10, 5, 10, 3])
2355*da0073e9SAndroid Build Coastguard Worker        helper([10, 5, 10, 3, 20])
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill(self):
2358*da0073e9SAndroid Build Coastguard Worker        device = "mps"
2359*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float32
2360*da0073e9SAndroid Build Coastguard Worker        mask_dtype = torch.bool
2361*da0073e9SAndroid Build Coastguard Worker        num_dest = 10
2362*da0073e9SAndroid Build Coastguard Worker
2363*da0073e9SAndroid Build Coastguard Worker        dst = torch.zeros(num_dest, dtype=dtype, device=device)
2364*da0073e9SAndroid Build Coastguard Worker        mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
2365*da0073e9SAndroid Build Coastguard Worker        val = random.random()
2366*da0073e9SAndroid Build Coastguard Worker        dst2 = torch.zeros(num_dest, dtype=dtype)
2367*da0073e9SAndroid Build Coastguard Worker        mask_cpu = mask.to("cpu")
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker        dst.masked_fill_(mask, val)
2370*da0073e9SAndroid Build Coastguard Worker        for i in range(num_dest):
2371*da0073e9SAndroid Build Coastguard Worker            if mask_cpu[i]:
2372*da0073e9SAndroid Build Coastguard Worker                dst2[i] = val
2373*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
2374*da0073e9SAndroid Build Coastguard Worker
2375*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill__non_contiguous(self):
2376*da0073e9SAndroid Build Coastguard Worker        shape = (3, 5)
2377*da0073e9SAndroid Build Coastguard Worker
2378*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.randn(shape, device="mps")
2379*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_mps.detach().clone().cpu()
2380*da0073e9SAndroid Build Coastguard Worker        mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
2381*da0073e9SAndroid Build Coastguard Worker        mask_cpu = mask_mps.detach().clone().cpu()
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker        x_mps_strided = x_mps.T
2384*da0073e9SAndroid Build Coastguard Worker        x_cpu_strided = x_cpu.T
2385*da0073e9SAndroid Build Coastguard Worker
2386*da0073e9SAndroid Build Coastguard Worker        x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
2387*da0073e9SAndroid Build Coastguard Worker        x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_mps_strided, x_cpu_strided)
2390*da0073e9SAndroid Build Coastguard Worker        self.assertFalse((x_mps_strided == float("-inf")).any())
2391*da0073e9SAndroid Build Coastguard Worker
2392*da0073e9SAndroid Build Coastguard Worker    def test_nhwc_operation(self):
2393*da0073e9SAndroid Build Coastguard Worker        def helper(shape, channels_last=False):
2394*da0073e9SAndroid Build Coastguard Worker            import numpy as np
2395*da0073e9SAndroid Build Coastguard Worker            np.random.seed(332)
2396*da0073e9SAndroid Build Coastguard Worker            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2397*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2398*da0073e9SAndroid Build Coastguard Worker            if (channels_last):
2399*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2400*da0073e9SAndroid Build Coastguard Worker                cpu_x.retain_grad()
2401*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2402*da0073e9SAndroid Build Coastguard Worker
2403*da0073e9SAndroid Build Coastguard Worker            # This passes
2404*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, cpu_x)
2405*da0073e9SAndroid Build Coastguard Worker
2406*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 2, 2), True)
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker    # Test forward batch norm
2409*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm(self):
2410*da0073e9SAndroid Build Coastguard Worker        def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
2411*da0073e9SAndroid Build Coastguard Worker                   track_running_stats=True, test_module=False):
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker            import numpy as np
2414*da0073e9SAndroid Build Coastguard Worker            np.random.seed(332)
2415*da0073e9SAndroid Build Coastguard Worker            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2416*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2417*da0073e9SAndroid Build Coastguard Worker            if (channels_last):
2418*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2419*da0073e9SAndroid Build Coastguard Worker                cpu_x.retain_grad()
2420*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2421*da0073e9SAndroid Build Coastguard Worker
2422*da0073e9SAndroid Build Coastguard Worker            mean_shape = [shape[1]]
2423*da0073e9SAndroid Build Coastguard Worker            cpu_running_mean = None
2424*da0073e9SAndroid Build Coastguard Worker            cpu_running_var = None
2425*da0073e9SAndroid Build Coastguard Worker            running_mean = None
2426*da0073e9SAndroid Build Coastguard Worker            running_var = None
2427*da0073e9SAndroid Build Coastguard Worker            if (track_running_stats):
2428*da0073e9SAndroid Build Coastguard Worker                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2429*da0073e9SAndroid Build Coastguard Worker                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2430*da0073e9SAndroid Build Coastguard Worker                var_arr = 32 * np.random.random_sample(size=mean_shape)
2431*da0073e9SAndroid Build Coastguard Worker                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2432*da0073e9SAndroid Build Coastguard Worker                running_mean = cpu_running_mean.detach().clone().to('mps')
2433*da0073e9SAndroid Build Coastguard Worker                running_var = cpu_running_var.detach().clone().to('mps')
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker            weight = None
2436*da0073e9SAndroid Build Coastguard Worker            cpu_weight = None
2437*da0073e9SAndroid Build Coastguard Worker            bias = None
2438*da0073e9SAndroid Build Coastguard Worker            cpu_bias = None
2439*da0073e9SAndroid Build Coastguard Worker            if (wts):
2440*da0073e9SAndroid Build Coastguard Worker                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2441*da0073e9SAndroid Build Coastguard Worker                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2442*da0073e9SAndroid Build Coastguard Worker                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2443*da0073e9SAndroid Build Coastguard Worker                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2444*da0073e9SAndroid Build Coastguard Worker
2445*da0073e9SAndroid Build Coastguard Worker            y = None
2446*da0073e9SAndroid Build Coastguard Worker            ref_y = None
2447*da0073e9SAndroid Build Coastguard Worker
2448*da0073e9SAndroid Build Coastguard Worker            if (not test_module):
2449*da0073e9SAndroid Build Coastguard Worker                y = torch.nn.functional.batch_norm(x, running_mean, running_var,
2450*da0073e9SAndroid Build Coastguard Worker                                                   weight=weight,
2451*da0073e9SAndroid Build Coastguard Worker                                                   bias=bias,
2452*da0073e9SAndroid Build Coastguard Worker                                                   training=training,
2453*da0073e9SAndroid Build Coastguard Worker                                                   momentum=momentum, eps=eps)
2454*da0073e9SAndroid Build Coastguard Worker                ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
2455*da0073e9SAndroid Build Coastguard Worker                                                       weight=cpu_weight,
2456*da0073e9SAndroid Build Coastguard Worker                                                       bias=cpu_bias,
2457*da0073e9SAndroid Build Coastguard Worker                                                       training=training,
2458*da0073e9SAndroid Build Coastguard Worker                                                       momentum=momentum, eps=eps)
2459*da0073e9SAndroid Build Coastguard Worker
2460*da0073e9SAndroid Build Coastguard Worker            else:
2461*da0073e9SAndroid Build Coastguard Worker
2462*da0073e9SAndroid Build Coastguard Worker                batchnorm_op = None
2463*da0073e9SAndroid Build Coastguard Worker                mps_batchnorm_op = None
2464*da0073e9SAndroid Build Coastguard Worker
2465*da0073e9SAndroid Build Coastguard Worker                if (len(shape) == 3):
2466*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op = torch.nn.BatchNorm1d(shape[1],
2467*da0073e9SAndroid Build Coastguard Worker                                                        eps=eps,
2468*da0073e9SAndroid Build Coastguard Worker                                                        momentum=momentum,
2469*da0073e9SAndroid Build Coastguard Worker                                                        affine=wts,
2470*da0073e9SAndroid Build Coastguard Worker                                                        track_running_stats=track_running_stats,
2471*da0073e9SAndroid Build Coastguard Worker                                                        device='cpu')
2472*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
2473*da0073e9SAndroid Build Coastguard Worker                                                            eps=eps,
2474*da0073e9SAndroid Build Coastguard Worker                                                            momentum=momentum,
2475*da0073e9SAndroid Build Coastguard Worker                                                            affine=wts,
2476*da0073e9SAndroid Build Coastguard Worker                                                            track_running_stats=track_running_stats,
2477*da0073e9SAndroid Build Coastguard Worker                                                            device='mps')
2478*da0073e9SAndroid Build Coastguard Worker                elif (len(shape) == 4):
2479*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op = torch.nn.BatchNorm2d(shape[1],
2480*da0073e9SAndroid Build Coastguard Worker                                                        eps=eps,
2481*da0073e9SAndroid Build Coastguard Worker                                                        momentum=momentum,
2482*da0073e9SAndroid Build Coastguard Worker                                                        affine=wts,
2483*da0073e9SAndroid Build Coastguard Worker                                                        track_running_stats=track_running_stats,
2484*da0073e9SAndroid Build Coastguard Worker                                                        device='cpu')
2485*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
2486*da0073e9SAndroid Build Coastguard Worker                                                            eps=eps,
2487*da0073e9SAndroid Build Coastguard Worker                                                            momentum=momentum,
2488*da0073e9SAndroid Build Coastguard Worker                                                            affine=wts,
2489*da0073e9SAndroid Build Coastguard Worker                                                            track_running_stats=track_running_stats,
2490*da0073e9SAndroid Build Coastguard Worker                                                            device='mps')
2491*da0073e9SAndroid Build Coastguard Worker                elif (len(shape) == 5):
2492*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op = torch.nn.BatchNorm3d(shape[1],
2493*da0073e9SAndroid Build Coastguard Worker                                                        eps=eps,
2494*da0073e9SAndroid Build Coastguard Worker                                                        momentum=momentum,
2495*da0073e9SAndroid Build Coastguard Worker                                                        affine=wts,
2496*da0073e9SAndroid Build Coastguard Worker                                                        track_running_stats=track_running_stats,
2497*da0073e9SAndroid Build Coastguard Worker                                                        device='cpu')
2498*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
2499*da0073e9SAndroid Build Coastguard Worker                                                            eps=eps,
2500*da0073e9SAndroid Build Coastguard Worker                                                            momentum=momentum,
2501*da0073e9SAndroid Build Coastguard Worker                                                            affine=wts,
2502*da0073e9SAndroid Build Coastguard Worker                                                            track_running_stats=track_running_stats,
2503*da0073e9SAndroid Build Coastguard Worker                                                            device='mps')
2504*da0073e9SAndroid Build Coastguard Worker
2505*da0073e9SAndroid Build Coastguard Worker                if (track_running_stats):
2506*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op.running_mean = cpu_running_mean
2507*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op.running_var = cpu_running_var
2508*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op.running_mean = running_mean
2509*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op.running_var = running_var
2510*da0073e9SAndroid Build Coastguard Worker                if (wts):
2511*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
2512*da0073e9SAndroid Build Coastguard Worker                    batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
2513*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op.weight = torch.nn.Parameter(weight)
2514*da0073e9SAndroid Build Coastguard Worker                    mps_batchnorm_op.bias = torch.nn.Parameter(bias)
2515*da0073e9SAndroid Build Coastguard Worker
2516*da0073e9SAndroid Build Coastguard Worker                ref_y = batchnorm_op(cpu_x)
2517*da0073e9SAndroid Build Coastguard Worker                y = mps_batchnorm_op(x)
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
2520*da0073e9SAndroid Build Coastguard Worker            if (not test_module):
2521*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(running_mean, cpu_running_mean)
2522*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(running_var, cpu_running_var)
2523*da0073e9SAndroid Build Coastguard Worker            else:
2524*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
2525*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(ref_y.shape)
2528*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
2529*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
2530*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
2531*da0073e9SAndroid Build Coastguard Worker
2532*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
2533*da0073e9SAndroid Build Coastguard Worker            if (wts):
2534*da0073e9SAndroid Build Coastguard Worker                if (not test_module):
2535*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(weight.grad, cpu_weight.grad)
2536*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(bias.grad, cpu_bias.grad)
2537*da0073e9SAndroid Build Coastguard Worker                else:
2538*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
2539*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)
2540*da0073e9SAndroid Build Coastguard Worker
2541*da0073e9SAndroid Build Coastguard Worker        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2542*da0073e9SAndroid Build Coastguard Worker            for test_module in [False, True]:
2543*da0073e9SAndroid Build Coastguard Worker                for track_running_stats in [True, False]:
2544*da0073e9SAndroid Build Coastguard Worker                    for channels_last in [False]:
2545*da0073e9SAndroid Build Coastguard Worker                        if (channels_last and len(shape) != 4):
2546*da0073e9SAndroid Build Coastguard Worker                            continue
2547*da0073e9SAndroid Build Coastguard Worker                        # Running stats must be tracked in eval mode
2548*da0073e9SAndroid Build Coastguard Worker                        if (track_running_stats):
2549*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
2550*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2551*da0073e9SAndroid Build Coastguard Worker                            helper(shape, channels_last=channels_last,
2552*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2553*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
2554*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2555*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
2556*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2557*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
2558*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2559*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
2560*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2561*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
2562*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2563*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
2564*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2565*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
2566*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2567*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
2568*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2569*da0073e9SAndroid Build Coastguard Worker
2570*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_backward(self):
2571*da0073e9SAndroid Build Coastguard Worker        inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
2572*da0073e9SAndroid Build Coastguard Worker        x = torch.nn.BatchNorm2d(8).to("mps")
2573*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.BatchNorm2d(8).to("mps")
2574*da0073e9SAndroid Build Coastguard Worker        y.weight.requires_grad = False
2575*da0073e9SAndroid Build Coastguard Worker        y.bias.requires_grad = False
2576*da0073e9SAndroid Build Coastguard Worker        outputs = y(x(inputs))
2577*da0073e9SAndroid Build Coastguard Worker        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2578*da0073e9SAndroid Build Coastguard Worker        outputs.sum().backward()
2579*da0073e9SAndroid Build Coastguard Worker
2580*da0073e9SAndroid Build Coastguard Worker    # Regression test for https://github.com/pytorch/pytorch/issues/133520
2581*da0073e9SAndroid Build Coastguard Worker    def test_batch_norm_slices(self):
2582*da0073e9SAndroid Build Coastguard Worker        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
2583*da0073e9SAndroid Build Coastguard Worker        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
2584*da0073e9SAndroid Build Coastguard Worker
2585*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
2586*da0073e9SAndroid Build Coastguard Worker        x_mps = x_cpu.to('mps')
2587*da0073e9SAndroid Build Coastguard Worker
2588*da0073e9SAndroid Build Coastguard Worker        res_cpu = bn_cpu(x_cpu[5:])
2589*da0073e9SAndroid Build Coastguard Worker        res_mps = bn_mps(x_mps[5:])
2590*da0073e9SAndroid Build Coastguard Worker
2591*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_mps)
2592*da0073e9SAndroid Build Coastguard Worker
2593*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm_backward(self):
2594*da0073e9SAndroid Build Coastguard Worker        inputs = torch.rand(4, 4, device="mps", requires_grad=True)
2595*da0073e9SAndroid Build Coastguard Worker        x = torch.nn.LayerNorm(4).to("mps")
2596*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.LayerNorm(4).to("mps")
2597*da0073e9SAndroid Build Coastguard Worker        y.weight.requires_grad = False
2598*da0073e9SAndroid Build Coastguard Worker        y.bias.requires_grad = False
2599*da0073e9SAndroid Build Coastguard Worker        outputs = y(x(inputs))
2600*da0073e9SAndroid Build Coastguard Worker        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2601*da0073e9SAndroid Build Coastguard Worker        outputs.sum().backward()
2602*da0073e9SAndroid Build Coastguard Worker
2603*da0073e9SAndroid Build Coastguard Worker    def test_norm(self):
2604*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(9, dtype=torch.float, device="mps") - 4
2605*da0073e9SAndroid Build Coastguard Worker        b = a.reshape((3, 3))
2606*da0073e9SAndroid Build Coastguard Worker
2607*da0073e9SAndroid Build Coastguard Worker        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
2608*da0073e9SAndroid Build Coastguard Worker        b_cpu = a_cpu.reshape((3, 3))
2609*da0073e9SAndroid Build Coastguard Worker
2610*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(a)
2611*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(a_cpu)
2612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2613*da0073e9SAndroid Build Coastguard Worker
2614*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(b)
2615*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(b_cpu)
2616*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2617*da0073e9SAndroid Build Coastguard Worker
2618*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(a, float('inf'))
2619*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(a_cpu, float('inf'))
2620*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2621*da0073e9SAndroid Build Coastguard Worker
2622*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(b, float('inf'))
2623*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(b_cpu, float('inf'))
2624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2625*da0073e9SAndroid Build Coastguard Worker
2626*da0073e9SAndroid Build Coastguard Worker        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
2627*da0073e9SAndroid Build Coastguard Worker        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]] , dtype=torch.float, device="cpu")
2628*da0073e9SAndroid Build Coastguard Worker
2629*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(c, dim=0)
2630*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(c_cpu, dim=0)
2631*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(c, dim=1)
2634*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(c_cpu, dim=1)
2635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2636*da0073e9SAndroid Build Coastguard Worker
2637*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(c, p=1, dim=1)
2638*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(c_cpu, p=1, dim=1)
2639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2640*da0073e9SAndroid Build Coastguard Worker
2641*da0073e9SAndroid Build Coastguard Worker        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
2642*da0073e9SAndroid Build Coastguard Worker        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
2643*da0073e9SAndroid Build Coastguard Worker
2644*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(d, dim=(1, 2))
2645*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(d_cpu, dim=(1, 2))
2646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2647*da0073e9SAndroid Build Coastguard Worker
2648*da0073e9SAndroid Build Coastguard Worker        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
2649*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
2650*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_cpu)
2651*da0073e9SAndroid Build Coastguard Worker
2652*da0073e9SAndroid Build Coastguard Worker    def test_linalg_vector_norm(self):
2653*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
2654*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_mps.detach().clone().cpu()
2655*da0073e9SAndroid Build Coastguard Worker
2656*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.linalg.vector_norm(x_mps, ord=0)
2657*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.linalg.vector_norm(x_cpu, ord=0)
2658*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
2659*da0073e9SAndroid Build Coastguard Worker
2660*da0073e9SAndroid Build Coastguard Worker        a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
2661*da0073e9SAndroid Build Coastguard Worker        a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
2662*da0073e9SAndroid Build Coastguard Worker
2663*da0073e9SAndroid Build Coastguard Worker        B_mps = a_mps.reshape(3, 3, 3)
2664*da0073e9SAndroid Build Coastguard Worker        B_cpu = a_cpu.reshape(3, 3, 3)
2665*da0073e9SAndroid Build Coastguard Worker
2666*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.linalg.vector_norm(a_mps, ord=3.5)
2667*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.linalg.vector_norm(a_cpu, ord=3.5)
2668*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.linalg.vector_norm(B_mps, ord=3.5)
2671*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5)
2672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
2673*da0073e9SAndroid Build Coastguard Worker
2674*da0073e9SAndroid Build Coastguard Worker        for dim in range(0, B_mps.dim()):
2675*da0073e9SAndroid Build Coastguard Worker            res_mps = torch.linalg.vector_norm(B_mps, ord=3.5, dim=dim)
2676*da0073e9SAndroid Build Coastguard Worker            res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5, dim=dim)
2677*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_mps, res_cpu)
2678*da0073e9SAndroid Build Coastguard Worker
2679*da0073e9SAndroid Build Coastguard Worker
2680*da0073e9SAndroid Build Coastguard Worker    def test_layer_norm(self):
2681*da0073e9SAndroid Build Coastguard Worker        # TODO: Test non-contiguous
2682*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
2683*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
2684*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2685*da0073e9SAndroid Build Coastguard Worker
2686*da0073e9SAndroid Build Coastguard Worker            cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
2687*da0073e9SAndroid Build Coastguard Worker            mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
2688*da0073e9SAndroid Build Coastguard Worker            cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2689*da0073e9SAndroid Build Coastguard Worker            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2690*da0073e9SAndroid Build Coastguard Worker            cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2691*da0073e9SAndroid Build Coastguard Worker            bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2692*da0073e9SAndroid Build Coastguard Worker
2693*da0073e9SAndroid Build Coastguard Worker            if (elementwise_affine):
2694*da0073e9SAndroid Build Coastguard Worker                cpu_op.weight = torch.nn.Parameter(cpu_wt)
2695*da0073e9SAndroid Build Coastguard Worker                mps_op.weight = torch.nn.Parameter(wt)
2696*da0073e9SAndroid Build Coastguard Worker                cpu_op.bias = torch.nn.Parameter(cpu_bias)
2697*da0073e9SAndroid Build Coastguard Worker                mps_op.bias = torch.nn.Parameter(bias)
2698*da0073e9SAndroid Build Coastguard Worker
2699*da0073e9SAndroid Build Coastguard Worker            cpu_result = cpu_op(cpu_x)
2700*da0073e9SAndroid Build Coastguard Worker            result = mps_op(x)
2701*da0073e9SAndroid Build Coastguard Worker
2702*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(cpu_result.shape)
2703*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
2704*da0073e9SAndroid Build Coastguard Worker
2705*da0073e9SAndroid Build Coastguard Worker            cpu_result.backward(cpu_grad)
2706*da0073e9SAndroid Build Coastguard Worker            result.backward(grad)
2707*da0073e9SAndroid Build Coastguard Worker
2708*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, cpu_result)
2709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
2710*da0073e9SAndroid Build Coastguard Worker            if (elementwise_affine):
2711*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
2712*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)
2713*da0073e9SAndroid Build Coastguard Worker
2714*da0073e9SAndroid Build Coastguard Worker        for elementwise_affine in [True, False]:
2715*da0073e9SAndroid Build Coastguard Worker            helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine)
2716*da0073e9SAndroid Build Coastguard Worker            helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
2717*da0073e9SAndroid Build Coastguard Worker            helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/96113
2720*da0073e9SAndroid Build Coastguard Worker        torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
2721*da0073e9SAndroid Build Coastguard Worker
2722*da0073e9SAndroid Build Coastguard Worker    @xfailIf(product_version < 14.0)
2723*da0073e9SAndroid Build Coastguard Worker    def test_ifft(self):
2724*da0073e9SAndroid Build Coastguard Worker        # See: https://github.com/pytorch/pytorch/issues/124096
2725*da0073e9SAndroid Build Coastguard Worker        device = torch.device("mps")
2726*da0073e9SAndroid Build Coastguard Worker
2727*da0073e9SAndroid Build Coastguard Worker        N = 64
2728*da0073e9SAndroid Build Coastguard Worker        signal = torch.rand(N, device=device)
2729*da0073e9SAndroid Build Coastguard Worker        fft_result = torch.fft.rfft(signal)
2730*da0073e9SAndroid Build Coastguard Worker        ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0])
2731*da0073e9SAndroid Build Coastguard Worker
2732*da0073e9SAndroid Build Coastguard Worker        # Expecting the inverted to yield the original signal
2733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ifft_result, signal)
2734*da0073e9SAndroid Build Coastguard Worker
2735*da0073e9SAndroid Build Coastguard Worker    # Regression test for https://github.com/pytorch/pytorch/issues/135223
2736*da0073e9SAndroid Build Coastguard Worker    def test_fftfreq(self):
2737*da0073e9SAndroid Build Coastguard Worker        freq_cpu = torch.fft.fftfreq(10**4, device='cpu')
2738*da0073e9SAndroid Build Coastguard Worker        freq_mps = torch.fft.fftfreq(10**4, device='mps')
2739*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(freq_cpu, freq_mps)
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker    def test_instance_norm(self):
2742*da0073e9SAndroid Build Coastguard Worker        def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
2743*da0073e9SAndroid Build Coastguard Worker
2744*da0073e9SAndroid Build Coastguard Worker            import numpy as np
2745*da0073e9SAndroid Build Coastguard Worker            np.random.seed(332)
2746*da0073e9SAndroid Build Coastguard Worker            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2747*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2748*da0073e9SAndroid Build Coastguard Worker            if (channels_last):
2749*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2750*da0073e9SAndroid Build Coastguard Worker                cpu_x.retain_grad()
2751*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2752*da0073e9SAndroid Build Coastguard Worker
2753*da0073e9SAndroid Build Coastguard Worker            mean_shape = [shape[1]]
2754*da0073e9SAndroid Build Coastguard Worker            cpu_running_mean = None
2755*da0073e9SAndroid Build Coastguard Worker            cpu_running_var = None
2756*da0073e9SAndroid Build Coastguard Worker            running_mean = None
2757*da0073e9SAndroid Build Coastguard Worker            running_var = None
2758*da0073e9SAndroid Build Coastguard Worker            if (track_running_stats):
2759*da0073e9SAndroid Build Coastguard Worker                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2760*da0073e9SAndroid Build Coastguard Worker                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2761*da0073e9SAndroid Build Coastguard Worker                var_arr = 32 * np.random.random_sample(size=mean_shape)
2762*da0073e9SAndroid Build Coastguard Worker                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2763*da0073e9SAndroid Build Coastguard Worker                running_mean = cpu_running_mean.detach().clone().to('mps')
2764*da0073e9SAndroid Build Coastguard Worker                running_var = cpu_running_var.detach().clone().to('mps')
2765*da0073e9SAndroid Build Coastguard Worker
2766*da0073e9SAndroid Build Coastguard Worker            weight = None
2767*da0073e9SAndroid Build Coastguard Worker            cpu_weight = None
2768*da0073e9SAndroid Build Coastguard Worker            bias = None
2769*da0073e9SAndroid Build Coastguard Worker            cpu_bias = None
2770*da0073e9SAndroid Build Coastguard Worker            if (wts):
2771*da0073e9SAndroid Build Coastguard Worker                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2772*da0073e9SAndroid Build Coastguard Worker                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2773*da0073e9SAndroid Build Coastguard Worker                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2774*da0073e9SAndroid Build Coastguard Worker                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2775*da0073e9SAndroid Build Coastguard Worker
2776*da0073e9SAndroid Build Coastguard Worker            y = None
2777*da0073e9SAndroid Build Coastguard Worker            ref_y = None
2778*da0073e9SAndroid Build Coastguard Worker
2779*da0073e9SAndroid Build Coastguard Worker            if (not test_module):
2780*da0073e9SAndroid Build Coastguard Worker                ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
2781*da0073e9SAndroid Build Coastguard Worker                                                          weight=cpu_weight,
2782*da0073e9SAndroid Build Coastguard Worker                                                          bias=cpu_bias,
2783*da0073e9SAndroid Build Coastguard Worker                                                          momentum=momentum, eps=eps)
2784*da0073e9SAndroid Build Coastguard Worker                y = torch.nn.functional.instance_norm(x, running_mean, running_var,
2785*da0073e9SAndroid Build Coastguard Worker                                                      weight=weight,
2786*da0073e9SAndroid Build Coastguard Worker                                                      bias=bias,
2787*da0073e9SAndroid Build Coastguard Worker                                                      momentum=momentum, eps=eps)
2788*da0073e9SAndroid Build Coastguard Worker
2789*da0073e9SAndroid Build Coastguard Worker            else:
2790*da0073e9SAndroid Build Coastguard Worker
2791*da0073e9SAndroid Build Coastguard Worker                instancenorm_op = None
2792*da0073e9SAndroid Build Coastguard Worker                mps_instancenorm_op = None
2793*da0073e9SAndroid Build Coastguard Worker
2794*da0073e9SAndroid Build Coastguard Worker                if (len(shape) == 3):
2795*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2796*da0073e9SAndroid Build Coastguard Worker                                                              eps=eps,
2797*da0073e9SAndroid Build Coastguard Worker                                                              momentum=momentum,
2798*da0073e9SAndroid Build Coastguard Worker                                                              affine=wts,
2799*da0073e9SAndroid Build Coastguard Worker                                                              track_running_stats=track_running_stats,
2800*da0073e9SAndroid Build Coastguard Worker                                                              device='cpu')
2801*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2802*da0073e9SAndroid Build Coastguard Worker                                                                  eps=eps,
2803*da0073e9SAndroid Build Coastguard Worker                                                                  momentum=momentum,
2804*da0073e9SAndroid Build Coastguard Worker                                                                  affine=wts,
2805*da0073e9SAndroid Build Coastguard Worker                                                                  track_running_stats=track_running_stats,
2806*da0073e9SAndroid Build Coastguard Worker                                                                  device='mps')
2807*da0073e9SAndroid Build Coastguard Worker                elif (len(shape) == 4):
2808*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2809*da0073e9SAndroid Build Coastguard Worker                                                              eps=eps,
2810*da0073e9SAndroid Build Coastguard Worker                                                              momentum=momentum,
2811*da0073e9SAndroid Build Coastguard Worker                                                              affine=wts,
2812*da0073e9SAndroid Build Coastguard Worker                                                              track_running_stats=track_running_stats,
2813*da0073e9SAndroid Build Coastguard Worker                                                              device='cpu')
2814*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2815*da0073e9SAndroid Build Coastguard Worker                                                                  eps=eps,
2816*da0073e9SAndroid Build Coastguard Worker                                                                  momentum=momentum,
2817*da0073e9SAndroid Build Coastguard Worker                                                                  affine=wts,
2818*da0073e9SAndroid Build Coastguard Worker                                                                  track_running_stats=track_running_stats,
2819*da0073e9SAndroid Build Coastguard Worker                                                                  device='mps')
2820*da0073e9SAndroid Build Coastguard Worker                elif (len(shape) == 5):
2821*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2822*da0073e9SAndroid Build Coastguard Worker                                                              eps=eps,
2823*da0073e9SAndroid Build Coastguard Worker                                                              momentum=momentum,
2824*da0073e9SAndroid Build Coastguard Worker                                                              affine=wts,
2825*da0073e9SAndroid Build Coastguard Worker                                                              track_running_stats=track_running_stats,
2826*da0073e9SAndroid Build Coastguard Worker                                                              device='cpu')
2827*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2828*da0073e9SAndroid Build Coastguard Worker                                                                  eps=eps,
2829*da0073e9SAndroid Build Coastguard Worker                                                                  momentum=momentum,
2830*da0073e9SAndroid Build Coastguard Worker                                                                  affine=wts,
2831*da0073e9SAndroid Build Coastguard Worker                                                                  track_running_stats=track_running_stats,
2832*da0073e9SAndroid Build Coastguard Worker                                                                  device='mps')
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker                if (track_running_stats):
2835*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op.running_mean = cpu_running_mean
2836*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op.running_var = cpu_running_var
2837*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op.running_mean = running_mean
2838*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op.running_var = running_var
2839*da0073e9SAndroid Build Coastguard Worker                if (wts):
2840*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
2841*da0073e9SAndroid Build Coastguard Worker                    instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
2842*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op.weight = torch.nn.Parameter(weight)
2843*da0073e9SAndroid Build Coastguard Worker                    mps_instancenorm_op.bias = torch.nn.Parameter(bias)
2844*da0073e9SAndroid Build Coastguard Worker
2845*da0073e9SAndroid Build Coastguard Worker                ref_y = instancenorm_op(cpu_x)
2846*da0073e9SAndroid Build Coastguard Worker                y = mps_instancenorm_op(x)
2847*da0073e9SAndroid Build Coastguard Worker
2848*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
2849*da0073e9SAndroid Build Coastguard Worker            if (not test_module):
2850*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(running_mean, cpu_running_mean)
2851*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(running_var, cpu_running_var)
2852*da0073e9SAndroid Build Coastguard Worker            else:
2853*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
2854*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
2855*da0073e9SAndroid Build Coastguard Worker
2856*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(ref_y.shape)
2857*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
2858*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
2859*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
2860*da0073e9SAndroid Build Coastguard Worker
2861*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
2862*da0073e9SAndroid Build Coastguard Worker            if (wts):
2863*da0073e9SAndroid Build Coastguard Worker                if (not test_module):
2864*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(weight.grad, cpu_weight.grad)
2865*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(bias.grad, cpu_bias.grad)
2866*da0073e9SAndroid Build Coastguard Worker                else:
2867*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
2868*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
2869*da0073e9SAndroid Build Coastguard Worker
2870*da0073e9SAndroid Build Coastguard Worker        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2871*da0073e9SAndroid Build Coastguard Worker            for test_module in [False, True]:
2872*da0073e9SAndroid Build Coastguard Worker                for track_running_stats in [True, False]:
2873*da0073e9SAndroid Build Coastguard Worker                    for channels_last in [False]:
2874*da0073e9SAndroid Build Coastguard Worker                        if (channels_last and len(shape) != 4):
2875*da0073e9SAndroid Build Coastguard Worker                            continue
2876*da0073e9SAndroid Build Coastguard Worker                        # Running stats must be tracked in eval mode
2877*da0073e9SAndroid Build Coastguard Worker                        if (track_running_stats):
2878*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
2879*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2880*da0073e9SAndroid Build Coastguard Worker                            helper(shape, channels_last=channels_last,
2881*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2882*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2883*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2884*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2885*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2886*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2887*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2888*da0073e9SAndroid Build Coastguard Worker                            helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2889*da0073e9SAndroid Build Coastguard Worker                                   track_running_stats=track_running_stats, test_module=test_module)
2890*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2891*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2892*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2893*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2894*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2895*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2896*da0073e9SAndroid Build Coastguard Worker                        helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2897*da0073e9SAndroid Build Coastguard Worker                               track_running_stats=track_running_stats, test_module=test_module)
2898*da0073e9SAndroid Build Coastguard Worker
2899*da0073e9SAndroid Build Coastguard Worker    def test_weight_norm(self):
2900*da0073e9SAndroid Build Coastguard Worker        def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
2901*da0073e9SAndroid Build Coastguard Worker            cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
2902*da0073e9SAndroid Build Coastguard Worker            norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)
2903*da0073e9SAndroid Build Coastguard Worker
2904*da0073e9SAndroid Build Coastguard Worker            cpu_out = cpu_norm(cpu_x)
2905*da0073e9SAndroid Build Coastguard Worker            out = norm(x)
2906*da0073e9SAndroid Build Coastguard Worker
2907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_out, out)
2908*da0073e9SAndroid Build Coastguard Worker
2909*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(cpu_out.shape)
2910*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
2911*da0073e9SAndroid Build Coastguard Worker            cpu_out.backward(gradient=cpu_grad)
2912*da0073e9SAndroid Build Coastguard Worker            out.backward(gradient=grad)
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad)
2915*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad)
2916*da0073e9SAndroid Build Coastguard Worker
2917*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
2918*da0073e9SAndroid Build Coastguard Worker
2919*da0073e9SAndroid Build Coastguard Worker        def helper(dim, layer='linear', dtype=torch.float32):
2920*da0073e9SAndroid Build Coastguard Worker            # linear layer
2921*da0073e9SAndroid Build Coastguard Worker            if layer == 'linear':
2922*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
2923*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
2924*da0073e9SAndroid Build Coastguard Worker
2925*da0073e9SAndroid Build Coastguard Worker                cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
2926*da0073e9SAndroid Build Coastguard Worker                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2927*da0073e9SAndroid Build Coastguard Worker
2928*da0073e9SAndroid Build Coastguard Worker                cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
2929*da0073e9SAndroid Build Coastguard Worker                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker                cpu_linear = torch.nn.Linear(5, 10, device='cpu')
2932*da0073e9SAndroid Build Coastguard Worker                linear = torch.nn.Linear(5, 10, device='mps')
2933*da0073e9SAndroid Build Coastguard Worker
2934*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2935*da0073e9SAndroid Build Coastguard Worker                    cpu_linear.weight.copy_(cpu_weight)
2936*da0073e9SAndroid Build Coastguard Worker                    cpu_linear.bias.copy_(cpu_bias)
2937*da0073e9SAndroid Build Coastguard Worker                    linear.weight.copy_(weight)
2938*da0073e9SAndroid Build Coastguard Worker                    linear.bias.copy_(bias)
2939*da0073e9SAndroid Build Coastguard Worker                validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)
2940*da0073e9SAndroid Build Coastguard Worker
2941*da0073e9SAndroid Build Coastguard Worker            # conv layer
2942*da0073e9SAndroid Build Coastguard Worker            if layer == 'conv':
2943*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
2944*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
2945*da0073e9SAndroid Build Coastguard Worker
2946*da0073e9SAndroid Build Coastguard Worker                cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
2947*da0073e9SAndroid Build Coastguard Worker                conv = torch.nn.Conv2d(3, 3, 3, device='mps')
2948*da0073e9SAndroid Build Coastguard Worker
2949*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2950*da0073e9SAndroid Build Coastguard Worker                    conv.weight.copy_(cpu_conv.weight)
2951*da0073e9SAndroid Build Coastguard Worker                    conv.bias.copy_(cpu_conv.bias)
2952*da0073e9SAndroid Build Coastguard Worker
2953*da0073e9SAndroid Build Coastguard Worker                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
2954*da0073e9SAndroid Build Coastguard Worker
2955*da0073e9SAndroid Build Coastguard Worker            # conv3d layer
2956*da0073e9SAndroid Build Coastguard Worker            if layer == 'conv3d':
2957*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
2958*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
2959*da0073e9SAndroid Build Coastguard Worker
2960*da0073e9SAndroid Build Coastguard Worker                cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
2961*da0073e9SAndroid Build Coastguard Worker                conv = torch.nn.Conv3d(3, 3, 3, device='mps')
2962*da0073e9SAndroid Build Coastguard Worker
2963*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2964*da0073e9SAndroid Build Coastguard Worker                    conv.weight.copy_(cpu_conv.weight)
2965*da0073e9SAndroid Build Coastguard Worker                    conv.bias.copy_(cpu_conv.bias)
2966*da0073e9SAndroid Build Coastguard Worker
2967*da0073e9SAndroid Build Coastguard Worker                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker        helper(0, layer='linear')
2970*da0073e9SAndroid Build Coastguard Worker        helper(1, layer='linear')
2971*da0073e9SAndroid Build Coastguard Worker        helper(-1, layer='linear')
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker        helper(0, layer='conv')
2974*da0073e9SAndroid Build Coastguard Worker        helper(1, layer='conv')
2975*da0073e9SAndroid Build Coastguard Worker        helper(2, layer='conv')
2976*da0073e9SAndroid Build Coastguard Worker        helper(3, layer='conv')
2977*da0073e9SAndroid Build Coastguard Worker        helper(-1, layer='conv')
2978*da0073e9SAndroid Build Coastguard Worker
2979*da0073e9SAndroid Build Coastguard Worker        if product_version >= 13.2:
2980*da0073e9SAndroid Build Coastguard Worker            # Conv3d is only available from MacOS 13 onwards
2981*da0073e9SAndroid Build Coastguard Worker            helper(0, layer='conv3d')
2982*da0073e9SAndroid Build Coastguard Worker            helper(1, layer='conv3d')
2983*da0073e9SAndroid Build Coastguard Worker            helper(2, layer='conv3d')
2984*da0073e9SAndroid Build Coastguard Worker            helper(3, layer='conv3d')
2985*da0073e9SAndroid Build Coastguard Worker            helper(4, layer='conv3d')
2986*da0073e9SAndroid Build Coastguard Worker            helper(-1, layer='conv3d')
2987*da0073e9SAndroid Build Coastguard Worker
2988*da0073e9SAndroid Build Coastguard Worker    # Test conv2d
2989*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_unit(self):
2990*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, wt_shape,
2991*da0073e9SAndroid Build Coastguard Worker                   stride=1, padding=0,
2992*da0073e9SAndroid Build Coastguard Worker                   dilation=1, groups=1,
2993*da0073e9SAndroid Build Coastguard Worker                   bias_shape=None):
2994*da0073e9SAndroid Build Coastguard Worker
2995*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
2996*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
2997*da0073e9SAndroid Build Coastguard Worker
2998*da0073e9SAndroid Build Coastguard Worker            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
2999*da0073e9SAndroid Build Coastguard Worker            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3000*da0073e9SAndroid Build Coastguard Worker
3001*da0073e9SAndroid Build Coastguard Worker            cpu_bias = None
3002*da0073e9SAndroid Build Coastguard Worker            bias = None
3003*da0073e9SAndroid Build Coastguard Worker
3004*da0073e9SAndroid Build Coastguard Worker            if (bias_shape is not None):
3005*da0073e9SAndroid Build Coastguard Worker                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
3006*da0073e9SAndroid Build Coastguard Worker                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3007*da0073e9SAndroid Build Coastguard Worker
3008*da0073e9SAndroid Build Coastguard Worker            y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride,
3009*da0073e9SAndroid Build Coastguard Worker                                           padding=padding, dilation=dilation, groups=groups)
3010*da0073e9SAndroid Build Coastguard Worker            ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride,
3011*da0073e9SAndroid Build Coastguard Worker                                               padding=padding, dilation=dilation, groups=groups)
3012*da0073e9SAndroid Build Coastguard Worker
3013*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(ref_y)
3014*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
3015*da0073e9SAndroid Build Coastguard Worker
3016*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
3017*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
3018*da0073e9SAndroid Build Coastguard Worker
3019*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
3020*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
3021*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
3022*da0073e9SAndroid Build Coastguard Worker            if (bias_shape is not None):
3023*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)
3024*da0073e9SAndroid Build Coastguard Worker
3025*da0073e9SAndroid Build Coastguard Worker        N = 1
3026*da0073e9SAndroid Build Coastguard Worker        C_in = 3
3027*da0073e9SAndroid Build Coastguard Worker        C_out = 64
3028*da0073e9SAndroid Build Coastguard Worker        H = 64
3029*da0073e9SAndroid Build Coastguard Worker        W = 64
3030*da0073e9SAndroid Build Coastguard Worker        kH = 4
3031*da0073e9SAndroid Build Coastguard Worker        kW = 4
3032*da0073e9SAndroid Build Coastguard Worker        stride = 2
3033*da0073e9SAndroid Build Coastguard Worker        padding = 1
3034*da0073e9SAndroid Build Coastguard Worker
3035*da0073e9SAndroid Build Coastguard Worker        helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding)
3036*da0073e9SAndroid Build Coastguard Worker
3037*da0073e9SAndroid Build Coastguard Worker        N = 4
3038*da0073e9SAndroid Build Coastguard Worker        C_in = 16
3039*da0073e9SAndroid Build Coastguard Worker        H = 32
3040*da0073e9SAndroid Build Coastguard Worker        W = 32
3041*da0073e9SAndroid Build Coastguard Worker
3042*da0073e9SAndroid Build Coastguard Worker        C_out = 8
3043*da0073e9SAndroid Build Coastguard Worker        kH = 3
3044*da0073e9SAndroid Build Coastguard Worker        kW = 3
3045*da0073e9SAndroid Build Coastguard Worker
3046*da0073e9SAndroid Build Coastguard Worker        for groups in [1, 2, 4]:
3047*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3048*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3049*da0073e9SAndroid Build Coastguard Worker
3050*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3051*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3052*da0073e9SAndroid Build Coastguard Worker
3053*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3054*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3055*da0073e9SAndroid Build Coastguard Worker
3056*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3057*da0073e9SAndroid Build Coastguard Worker                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
3058*da0073e9SAndroid Build Coastguard Worker            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3059*da0073e9SAndroid Build Coastguard Worker                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
3060*da0073e9SAndroid Build Coastguard Worker
3061*da0073e9SAndroid Build Coastguard Worker    # Test conv transpose 2d
3062*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose2d(self):
3063*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, wt_shape,
3064*da0073e9SAndroid Build Coastguard Worker                   stride=1, padding=0,
3065*da0073e9SAndroid Build Coastguard Worker                   output_padding=0,
3066*da0073e9SAndroid Build Coastguard Worker                   dilation=1, groups=1,
3067*da0073e9SAndroid Build Coastguard Worker                   bias_shape=None):
3068*da0073e9SAndroid Build Coastguard Worker
3069*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
3070*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
3071*da0073e9SAndroid Build Coastguard Worker
3072*da0073e9SAndroid Build Coastguard Worker            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
3073*da0073e9SAndroid Build Coastguard Worker            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3074*da0073e9SAndroid Build Coastguard Worker
3075*da0073e9SAndroid Build Coastguard Worker            cpu_bias = None
3076*da0073e9SAndroid Build Coastguard Worker            bias = None
3077*da0073e9SAndroid Build Coastguard Worker
3078*da0073e9SAndroid Build Coastguard Worker            if (bias_shape is not None):
3079*da0073e9SAndroid Build Coastguard Worker                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
3080*da0073e9SAndroid Build Coastguard Worker                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3081*da0073e9SAndroid Build Coastguard Worker
3082*da0073e9SAndroid Build Coastguard Worker            y = torch.nn.functional.conv_transpose2d(
3083*da0073e9SAndroid Build Coastguard Worker                x, wt, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
3084*da0073e9SAndroid Build Coastguard Worker            ref_y = torch.nn.functional.conv_transpose2d(
3085*da0073e9SAndroid Build Coastguard Worker                cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding,
3086*da0073e9SAndroid Build Coastguard Worker                output_padding=output_padding, groups=groups, dilation=dilation)
3087*da0073e9SAndroid Build Coastguard Worker
3088*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(ref_y.shape)
3089*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
3090*da0073e9SAndroid Build Coastguard Worker
3091*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
3092*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
3093*da0073e9SAndroid Build Coastguard Worker
3094*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
3095*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
3096*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
3097*da0073e9SAndroid Build Coastguard Worker
3098*da0073e9SAndroid Build Coastguard Worker            # if (bias_shape is not None):
3099*da0073e9SAndroid Build Coastguard Worker            #  print(cpu_bias.grad)
3100*da0073e9SAndroid Build Coastguard Worker            #  print(bias.grad.to('cpu'))
3101*da0073e9SAndroid Build Coastguard Worker            #  self.assertEqual(bias.grad, cpu_bias.grad)
3102*da0073e9SAndroid Build Coastguard Worker
3103*da0073e9SAndroid Build Coastguard Worker        N = 4
3104*da0073e9SAndroid Build Coastguard Worker        C_in = 2
3105*da0073e9SAndroid Build Coastguard Worker        H = 32
3106*da0073e9SAndroid Build Coastguard Worker        W = 32
3107*da0073e9SAndroid Build Coastguard Worker
3108*da0073e9SAndroid Build Coastguard Worker        C_out = 8
3109*da0073e9SAndroid Build Coastguard Worker        groups = 1
3110*da0073e9SAndroid Build Coastguard Worker        kH = 3
3111*da0073e9SAndroid Build Coastguard Worker        kW = 3
3112*da0073e9SAndroid Build Coastguard Worker
3113*da0073e9SAndroid Build Coastguard Worker        for stride in [1, 2, 3]:
3114*da0073e9SAndroid Build Coastguard Worker            for padding in [0, 1, 2]:
3115*da0073e9SAndroid Build Coastguard Worker                for output_padding in [0, 1, 2]:
3116*da0073e9SAndroid Build Coastguard Worker                    for dilation in [1, 2]:
3117*da0073e9SAndroid Build Coastguard Worker                        if (output_padding >= stride or output_padding >= dilation):
3118*da0073e9SAndroid Build Coastguard Worker                            continue
3119*da0073e9SAndroid Build Coastguard Worker                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3120*da0073e9SAndroid Build Coastguard Worker                               padding=padding, output_padding=output_padding, dilation=dilation)
3121*da0073e9SAndroid Build Coastguard Worker                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3122*da0073e9SAndroid Build Coastguard Worker                               padding=padding, output_padding=output_padding, dilation=dilation)
3123*da0073e9SAndroid Build Coastguard Worker
3124*da0073e9SAndroid Build Coastguard Worker                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3125*da0073e9SAndroid Build Coastguard Worker                               padding=padding, output_padding=output_padding, dilation=dilation)
3126*da0073e9SAndroid Build Coastguard Worker                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3127*da0073e9SAndroid Build Coastguard Worker                               padding=padding, output_padding=output_padding, dilation=dilation)
3128*da0073e9SAndroid Build Coastguard Worker
3129*da0073e9SAndroid Build Coastguard Worker    # Test sigmoid
3130*da0073e9SAndroid Build Coastguard Worker    def test_sigmoid(self):
3131*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
3132*da0073e9SAndroid Build Coastguard Worker
3133*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3134*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
3135*da0073e9SAndroid Build Coastguard Worker
3136*da0073e9SAndroid Build Coastguard Worker            sigmoid_op = torch.nn.Sigmoid()
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker            y = sigmoid_op(x)
3139*da0073e9SAndroid Build Coastguard Worker            ref_y = sigmoid_op(cpu_x)
3140*da0073e9SAndroid Build Coastguard Worker
3141*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(ref_y)
3142*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
3143*da0073e9SAndroid Build Coastguard Worker
3144*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
3145*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
3146*da0073e9SAndroid Build Coastguard Worker
3147*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
3148*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
3149*da0073e9SAndroid Build Coastguard Worker
3150*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
3151*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4))
3152*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
3153*da0073e9SAndroid Build Coastguard Worker
3154*da0073e9SAndroid Build Coastguard Worker    # Test tanh
3155*da0073e9SAndroid Build Coastguard Worker    def test_tanh(self):
3156*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
3157*da0073e9SAndroid Build Coastguard Worker
3158*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3159*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
3160*da0073e9SAndroid Build Coastguard Worker
3161*da0073e9SAndroid Build Coastguard Worker            tanh_op = torch.nn.Tanh()
3162*da0073e9SAndroid Build Coastguard Worker
3163*da0073e9SAndroid Build Coastguard Worker            y = tanh_op(x)
3164*da0073e9SAndroid Build Coastguard Worker            ref_y = tanh_op(cpu_x)
3165*da0073e9SAndroid Build Coastguard Worker
3166*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(ref_y)
3167*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
3168*da0073e9SAndroid Build Coastguard Worker
3169*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
3170*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
3171*da0073e9SAndroid Build Coastguard Worker
3172*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
3173*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
3174*da0073e9SAndroid Build Coastguard Worker
3175*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
3176*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4))
3177*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
3178*da0073e9SAndroid Build Coastguard Worker
3179*da0073e9SAndroid Build Coastguard Worker    def test_threshold(self):
3180*da0073e9SAndroid Build Coastguard Worker        def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
3181*da0073e9SAndroid Build Coastguard Worker            m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)
3182*da0073e9SAndroid Build Coastguard Worker
3183*da0073e9SAndroid Build Coastguard Worker            input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
3184*da0073e9SAndroid Build Coastguard Worker            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
3185*da0073e9SAndroid Build Coastguard Worker
3186*da0073e9SAndroid Build Coastguard Worker            output_cpu = m(input_cpu)
3187*da0073e9SAndroid Build Coastguard Worker            output_mps = m(input_mps)
3188*da0073e9SAndroid Build Coastguard Worker
3189*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(output_cpu)
3190*da0073e9SAndroid Build Coastguard Worker            mps_grad = cpu_grad.to('mps')
3191*da0073e9SAndroid Build Coastguard Worker
3192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cpu, output_mps)
3193*da0073e9SAndroid Build Coastguard Worker
3194*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
3195*da0073e9SAndroid Build Coastguard Worker                output_cpu.backward(gradient=cpu_grad)
3196*da0073e9SAndroid Build Coastguard Worker                output_mps.backward(gradient=mps_grad)
3197*da0073e9SAndroid Build Coastguard Worker
3198*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_cpu.grad, input_mps.grad)
3199*da0073e9SAndroid Build Coastguard Worker
3200*da0073e9SAndroid Build Coastguard Worker        helper(threshold=0.1, value=20, num_elems=2)
3201*da0073e9SAndroid Build Coastguard Worker        helper(threshold=-0.1, value=10, num_elems=10)
3202*da0073e9SAndroid Build Coastguard Worker        helper(threshold=0.5, value=-15, num_elems=100)
3203*da0073e9SAndroid Build Coastguard Worker        helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)
3204*da0073e9SAndroid Build Coastguard Worker
3205*da0073e9SAndroid Build Coastguard Worker    # Test pow
3206*da0073e9SAndroid Build Coastguard Worker    def test_pow(self):
3207*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
3208*da0073e9SAndroid Build Coastguard Worker            # aten::pow.Tensor_Tensor
3209*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3210*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
3211*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3212*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
3213*da0073e9SAndroid Build Coastguard Worker            z = torch.pow(x, y)
3214*da0073e9SAndroid Build Coastguard Worker            ref_z = torch.pow(cpu_x, cpu_y)
3215*da0073e9SAndroid Build Coastguard Worker
3216*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z, ref_z)
3217*da0073e9SAndroid Build Coastguard Worker
3218*da0073e9SAndroid Build Coastguard Worker            # aten::pow.Tensor_Scalar
3219*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3220*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
3221*da0073e9SAndroid Build Coastguard Worker            exp = random.random()
3222*da0073e9SAndroid Build Coastguard Worker            z = torch.pow(x, exp)
3223*da0073e9SAndroid Build Coastguard Worker            ref_z = torch.pow(cpu_x, exp)
3224*da0073e9SAndroid Build Coastguard Worker
3225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z, ref_z)
3226*da0073e9SAndroid Build Coastguard Worker
3227*da0073e9SAndroid Build Coastguard Worker            # aten::pow.Scalar
3228*da0073e9SAndroid Build Coastguard Worker            x = random.random()
3229*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3230*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
3231*da0073e9SAndroid Build Coastguard Worker            z = torch.pow(x, y)
3232*da0073e9SAndroid Build Coastguard Worker            ref_z = torch.pow(x, cpu_y)
3233*da0073e9SAndroid Build Coastguard Worker
3234*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z, ref_z)
3235*da0073e9SAndroid Build Coastguard Worker
3236*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
3237*da0073e9SAndroid Build Coastguard Worker
3238*da0073e9SAndroid Build Coastguard Worker    # Test addcmul
3239*da0073e9SAndroid Build Coastguard Worker    def test_addcmul(self):
3240*da0073e9SAndroid Build Coastguard Worker        def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
3241*da0073e9SAndroid Build Coastguard Worker            def rand_helper(dtype):
3242*da0073e9SAndroid Build Coastguard Worker                if dtype.is_floating_point:
3243*da0073e9SAndroid Build Coastguard Worker                    return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
3244*da0073e9SAndroid Build Coastguard Worker                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
3245*da0073e9SAndroid Build Coastguard Worker
3246*da0073e9SAndroid Build Coastguard Worker            cpu_x = rand_helper(xtype)
3247*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
3248*da0073e9SAndroid Build Coastguard Worker
3249*da0073e9SAndroid Build Coastguard Worker            cpu_y = rand_helper(ytype if ytype is not None else xtype)
3250*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
3251*da0073e9SAndroid Build Coastguard Worker
3252*da0073e9SAndroid Build Coastguard Worker            cpu_z = rand_helper(ztype if ztype is not None else xtype)
3253*da0073e9SAndroid Build Coastguard Worker            z = cpu_z.detach().clone().to('mps')
3254*da0073e9SAndroid Build Coastguard Worker
3255*da0073e9SAndroid Build Coastguard Worker            y = torch.addcmul(x, y, z, value=value)
3256*da0073e9SAndroid Build Coastguard Worker            ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value)
3257*da0073e9SAndroid Build Coastguard Worker
3258*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
3259*da0073e9SAndroid Build Coastguard Worker
3260*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5), 0.1)
3261*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0.1)
3262*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5), 0.2)
3263*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0.2)
3264*da0073e9SAndroid Build Coastguard Worker        # Integral types
3265*da0073e9SAndroid Build Coastguard Worker        helper((2, 2), 1.0, xtype=torch.int32)
3266*da0073e9SAndroid Build Coastguard Worker        helper((2, 2), 2.0, xtype=torch.int16)
3267*da0073e9SAndroid Build Coastguard Worker
3268*da0073e9SAndroid Build Coastguard Worker        # Mixed types
3269*da0073e9SAndroid Build Coastguard Worker        helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
3270*da0073e9SAndroid Build Coastguard Worker        helper((3, 2), 1.0, ytype=torch.float16)
3271*da0073e9SAndroid Build Coastguard Worker        helper((2, 3), 1.0, ztype=torch.float16)
3272*da0073e9SAndroid Build Coastguard Worker        helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
3273*da0073e9SAndroid Build Coastguard Worker        helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
3274*da0073e9SAndroid Build Coastguard Worker
3275*da0073e9SAndroid Build Coastguard Worker    # Test addcdiv
3276*da0073e9SAndroid Build Coastguard Worker    def test_addcdiv(self):
3277*da0073e9SAndroid Build Coastguard Worker        def helper(shape, value):
3278*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3279*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3280*da0073e9SAndroid Build Coastguard Worker            # clamp to avoid division by 0
3281*da0073e9SAndroid Build Coastguard Worker            cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
3282*da0073e9SAndroid Build Coastguard Worker            cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3283*da0073e9SAndroid Build Coastguard Worker
3284*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
3285*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
3286*da0073e9SAndroid Build Coastguard Worker            mps_z = cpu_z.detach().clone().to('mps')
3287*da0073e9SAndroid Build Coastguard Worker            mps_out = cpu_out.detach().clone().to('mps')
3288*da0073e9SAndroid Build Coastguard Worker
3289*da0073e9SAndroid Build Coastguard Worker            result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
3290*da0073e9SAndroid Build Coastguard Worker            result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
3291*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_div_mps, result_div_cpu)
3292*da0073e9SAndroid Build Coastguard Worker            # test .out variant
3293*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu)
3294*da0073e9SAndroid Build Coastguard Worker
3295*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5), 0.1)
3296*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0.2)
3297*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
3298*da0073e9SAndroid Build Coastguard Worker
3299*da0073e9SAndroid Build Coastguard Worker    def test_addcdiv_transpose(self):
3300*da0073e9SAndroid Build Coastguard Worker        # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
3301*da0073e9SAndroid Build Coastguard Worker        # Testing continuity of all input tensors
3302*da0073e9SAndroid Build Coastguard Worker
3303*da0073e9SAndroid Build Coastguard Worker        def helper(shape, value):
3304*da0073e9SAndroid Build Coastguard Worker            shape_t = shape[::-1]
3305*da0073e9SAndroid Build Coastguard Worker            for i in range(2):
3306*da0073e9SAndroid Build Coastguard Worker                for j in range(2):
3307*da0073e9SAndroid Build Coastguard Worker                    for k in range(2):
3308*da0073e9SAndroid Build Coastguard Worker                        x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t()
3309*da0073e9SAndroid Build Coastguard Worker                        y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t()
3310*da0073e9SAndroid Build Coastguard Worker                        z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t()
3311*da0073e9SAndroid Build Coastguard Worker
3312*da0073e9SAndroid Build Coastguard Worker                        x_mps = x.detach().clone().to(device="mps")
3313*da0073e9SAndroid Build Coastguard Worker                        y_mps = y.detach().clone().to(device="mps")
3314*da0073e9SAndroid Build Coastguard Worker                        z_mps = z.detach().clone().to(device="mps")
3315*da0073e9SAndroid Build Coastguard Worker
3316*da0073e9SAndroid Build Coastguard Worker                        result_cpu = x.addcdiv_(y, z, value=value)
3317*da0073e9SAndroid Build Coastguard Worker                        result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)
3318*da0073e9SAndroid Build Coastguard Worker                        result_mps_out = result_cpu.detach().clone().to('mps')
3319*da0073e9SAndroid Build Coastguard Worker                        torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)
3320*da0073e9SAndroid Build Coastguard Worker
3321*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result_cpu, result_mps)
3322*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result_cpu, result_mps_out)
3323*da0073e9SAndroid Build Coastguard Worker
3324*da0073e9SAndroid Build Coastguard Worker        helper((2, 3), 1.0)
3325*da0073e9SAndroid Build Coastguard Worker        helper((2, 3), 0.2)
3326*da0073e9SAndroid Build Coastguard Worker        helper((100, 300), 1.0)
3327*da0073e9SAndroid Build Coastguard Worker        helper((100, 300), 0.2)
3328*da0073e9SAndroid Build Coastguard Worker
3329*da0073e9SAndroid Build Coastguard Worker    def test_buffer_size_match(self):
3330*da0073e9SAndroid Build Coastguard Worker        # this test shouldn't cause any crash
3331*da0073e9SAndroid Build Coastguard Worker        size = 16
3332*da0073e9SAndroid Build Coastguard Worker        cpu_A = torch.rand(size, device='cpu')
3333*da0073e9SAndroid Build Coastguard Worker        cpu_F = torch.rand(size, size, size, device='cpu')
3334*da0073e9SAndroid Build Coastguard Worker
3335*da0073e9SAndroid Build Coastguard Worker        mps_A = cpu_A.to('mps')
3336*da0073e9SAndroid Build Coastguard Worker        mps_F = cpu_F.to('mps')
3337*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F)
3338*da0073e9SAndroid Build Coastguard Worker
3339*da0073e9SAndroid Build Coastguard Worker    def test_transpose_inplace(self):
3340*da0073e9SAndroid Build Coastguard Worker        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3341*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
3342*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
3343*da0073e9SAndroid Build Coastguard Worker
3344*da0073e9SAndroid Build Coastguard Worker        cpu_x.transpose_(0, 1)
3345*da0073e9SAndroid Build Coastguard Worker        mps_x.transpose_(0, 1)
3346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x, mps_x.to('cpu'))
3347*da0073e9SAndroid Build Coastguard Worker
3348*da0073e9SAndroid Build Coastguard Worker    def test_expand_cpu_to_mps_copy(self):
3349*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/78642
3350*da0073e9SAndroid Build Coastguard Worker
3351*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1).expand([10]).to("mps")
3352*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.tensor(1).expand([10])
3353*da0073e9SAndroid Build Coastguard Worker
3354*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_cpu, x.cpu())
3355*da0073e9SAndroid Build Coastguard Worker
3356*da0073e9SAndroid Build Coastguard Worker    def test_cpu_to_strided_mps_copy(self):
3357*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/86975
3358*da0073e9SAndroid Build Coastguard Worker
3359*da0073e9SAndroid Build Coastguard Worker        a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3360*da0073e9SAndroid Build Coastguard Worker        b1 = torch.Tensor([-1, -1])
3361*da0073e9SAndroid Build Coastguard Worker        a1[1:, 1] = b1
3362*da0073e9SAndroid Build Coastguard Worker
3363*da0073e9SAndroid Build Coastguard Worker        a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3364*da0073e9SAndroid Build Coastguard Worker        b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
3365*da0073e9SAndroid Build Coastguard Worker        a2[1:, 1] = b2
3366*da0073e9SAndroid Build Coastguard Worker
3367*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a1, a2)
3368*da0073e9SAndroid Build Coastguard Worker
3369*da0073e9SAndroid Build Coastguard Worker    def test_view_slice_reshape(self):
3370*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([1, 4, 4], device="mps")
3371*da0073e9SAndroid Build Coastguard Worker        y = x[0, :1, 1:]
3372*da0073e9SAndroid Build Coastguard Worker
3373*da0073e9SAndroid Build Coastguard Worker        x_cpu = x.to("cpu")
3374*da0073e9SAndroid Build Coastguard Worker        y_cpu = x_cpu[0, :1, 1:]
3375*da0073e9SAndroid Build Coastguard Worker
3376*da0073e9SAndroid Build Coastguard Worker        r = y + 1
3377*da0073e9SAndroid Build Coastguard Worker        r_cpu = y_cpu + 1
3378*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, r_cpu)
3379*da0073e9SAndroid Build Coastguard Worker
3380*da0073e9SAndroid Build Coastguard Worker    def test_slice_reshape(self):
3381*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
3382*da0073e9SAndroid Build Coastguard Worker        x_cpu = x.detach().clone().to("cpu")
3383*da0073e9SAndroid Build Coastguard Worker
3384*da0073e9SAndroid Build Coastguard Worker        x = x[:, 3:].view(2, 3, 4, 1)
3385*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1)
3386*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x_cpu)
3387*da0073e9SAndroid Build Coastguard Worker
3388*da0073e9SAndroid Build Coastguard Worker        x = x + 2
3389*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_cpu + 2
3390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x_cpu)
3391*da0073e9SAndroid Build Coastguard Worker
    def test_reshape_storage_offset(self):
        """Run a small Linear -> cat -> transpose -> mse_loss pipeline on both
        MPS and CPU and check that forward values and input gradients agree.

        Regression test for https://github.com/pytorch/pytorch/issues/95883.
        """
        # https://github.com/pytorch/pytorch/issues/95883
        B = 4  # batch size
        T = 1  # sequence length

        lin_cpu = nn.Linear(10, 256)
        lin_mps = nn.Linear(10, 256, device="mps")

        # Use the same weights and bias as the ones from the cpu
        lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
        x_cpu = x_mps.detach().clone().cpu().requires_grad_()
        x_mps = lin_mps(x_mps)
        x_cpu = lin_cpu(x_cpu)

        self.assertEqual(x_mps.shape, (B, T, 256))
        self.assertEqual(x_cpu.shape, (B, T, 256))

        # Prepend an identical "cls token" row on both devices.
        cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
        cls_token_cpu = cls_token_mps.detach().clone().cpu()
        x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
        x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

        # Transpose creates a non-contiguous view feeding the loss.
        x_mps = x_mps.transpose(0, 1)
        x_cpu = x_cpu.transpose(0, 1)

        target_mps = torch.rand_like(x_mps)
        target_cpu = target_mps.detach().clone().cpu()
        loss_mps = F.mse_loss(x_mps, target_mps)
        loss_cpu = F.mse_loss(x_cpu, target_cpu)
        self.assertEqual(loss_mps, loss_cpu)

        loss_mps.backward()
        loss_cpu.backward()
        # NOTE(review): x_mps/x_cpu were rebound to non-leaf tensors above, so
        # .grad is presumably None on both sides here — confirm this comparison
        # is intentional rather than a leftover from comparing the leaf inputs.
        self.assertEqual(x_mps.grad, x_cpu.grad)
3429*da0073e9SAndroid Build Coastguard Worker
3430*da0073e9SAndroid Build Coastguard Worker    def test_stack_storage_offset(self):
3431*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/87856
3432*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.tensor([[1, 2]])
3433*da0073e9SAndroid Build Coastguard Worker        x_mps = x_cpu.detach().clone().to("mps")
3434*da0073e9SAndroid Build Coastguard Worker
3435*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
3436*da0073e9SAndroid Build Coastguard Worker        y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
3437*da0073e9SAndroid Build Coastguard Worker
3438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_cpu, y_mps)
3439*da0073e9SAndroid Build Coastguard Worker
3440*da0073e9SAndroid Build Coastguard Worker        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3441*da0073e9SAndroid Build Coastguard Worker        t_cpu = t_mps.detach().cpu().detach()
3442*da0073e9SAndroid Build Coastguard Worker
3443*da0073e9SAndroid Build Coastguard Worker        x_mps = t_mps[2:]
3444*da0073e9SAndroid Build Coastguard Worker        y_mps = t_mps[:2]
3445*da0073e9SAndroid Build Coastguard Worker
3446*da0073e9SAndroid Build Coastguard Worker        x_cpu = t_cpu[2:]
3447*da0073e9SAndroid Build Coastguard Worker        y_cpu = t_cpu[:2]
3448*da0073e9SAndroid Build Coastguard Worker
3449*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.stack((y_mps, x_mps), dim=-1)
3450*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
3451*da0073e9SAndroid Build Coastguard Worker
3452*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker    def test_unsafe_chunk(self):
3455*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/91065
3456*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(5, dtype=torch.float32, device="cpu")
3457*da0073e9SAndroid Build Coastguard Worker        ret = a.unsafe_chunk(4, 0)
3458*da0073e9SAndroid Build Coastguard Worker        y = ret[0] * ret[2]
3459*da0073e9SAndroid Build Coastguard Worker        a_mps = a.to("mps")
3460*da0073e9SAndroid Build Coastguard Worker        ret_mps = a_mps.unsafe_chunk(4, 0)
3461*da0073e9SAndroid Build Coastguard Worker        y_mps = ret_mps[0] * ret_mps[2]
3462*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_mps)
3463*da0073e9SAndroid Build Coastguard Worker
3464*da0073e9SAndroid Build Coastguard Worker    def test_slice_casting(self):
3465*da0073e9SAndroid Build Coastguard Worker        # generate random binary numbers
3466*da0073e9SAndroid Build Coastguard Worker        cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
3467*da0073e9SAndroid Build Coastguard Worker        mps_in = cpu_in.detach().clone().to("mps")
3468*da0073e9SAndroid Build Coastguard Worker        # check copy_cast(unit8 -> bool) on tensors with storage offset
3469*da0073e9SAndroid Build Coastguard Worker        cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool)
3470*da0073e9SAndroid Build Coastguard Worker        mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool)
3471*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_out, mps_out)
3472*da0073e9SAndroid Build Coastguard Worker
3473*da0073e9SAndroid Build Coastguard Worker    def test_slice_reshape_contg_view(self):
3474*da0073e9SAndroid Build Coastguard Worker        import torch
3475*da0073e9SAndroid Build Coastguard Worker
3476*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.randn(1, 4800, 2, device="mps")
3477*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_mps.detach().clone().cpu()
3478*da0073e9SAndroid Build Coastguard Worker
3479*da0073e9SAndroid Build Coastguard Worker        r_mps = x_mps + 2
3480*da0073e9SAndroid Build Coastguard Worker        r_cpu = x_cpu + 2
3481*da0073e9SAndroid Build Coastguard Worker
3482*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_mps, r_cpu)
3483*da0073e9SAndroid Build Coastguard Worker
    def test_contiguous_slice_2d(self):
        """Exhaustively compare every 2-D slice pattern (open/closed on each
        end of each of the two leading dims) between MPS and CPU, over many
        tensor ranks built by the append/pop loop below.
        """
        def helper(shape):
            # shape needs at least two dims; only the first two are sliced.
            for i in range(0, shape[0]):
                for j in range(0, shape[1]):
                    t_mps = torch.randn(shape, device="mps")
                    t_cpu = t_mps.detach().clone().cpu()

                    # suffix rows x prefix cols
                    y_mps = t_mps[i:, :j]
                    y_cpu = t_cpu[i:, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    # suffix rows x single col (dim dropped)
                    y_mps = t_mps[i:, j]
                    y_cpu = t_cpu[i:, j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    # single row x prefix cols
                    y_mps = t_mps[i, :j]
                    y_cpu = t_cpu[i, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    # prefix rows x prefix cols
                    y_mps = t_mps[:i, :j]
                    y_cpu = t_cpu[:i, :j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    # prefix rows x single col
                    y_mps = t_mps[:i, j]
                    y_cpu = t_cpu[:i, j]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

                    # prefix rows x suffix cols
                    y_mps = t_mps[:i, j:]
                    y_cpu = t_cpu[:i, j:]
                    self.assertEqual(y_mps + 1, y_cpu + 1)

        # Build shapes [N, C], [N, C, D], ... [N, C, D, H, W] with each extent
        # in {1, 2}; the append/pop pairs keep `l` consistent across iterations.
        l = []
        for N in range(1, 3):
            l.append(N)
            for C in range(1, 3):
                l.append(C)
                helper(l)
                for D in range(1, 3):
                    l.append(D)
                    helper(l)
                    for H in range(1, 3):
                        l.append(H)
                        helper(l)
                        for W in range(1, 3):
                            l.append(W)
                            helper(l)
                            l.pop()
                        l.pop()
                    l.pop()
                l.pop()
            l.pop()

        # A few larger, irregular shapes for good measure.
        helper([9, 15, 4])
        helper([9, 3, 2])
        helper([3, 4, 18, 22])
        helper([3, 4, 18, 22, 150])
3540*da0073e9SAndroid Build Coastguard Worker
3541*da0073e9SAndroid Build Coastguard Worker    def test_contiguous_slice_3d(self):
3542*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 3, device="mps")
3543*da0073e9SAndroid Build Coastguard Worker        x_cpu = x.detach().clone().cpu()
3544*da0073e9SAndroid Build Coastguard Worker        x = x[:1]
3545*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_cpu[:1]
3546*da0073e9SAndroid Build Coastguard Worker        out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
3547*da0073e9SAndroid Build Coastguard Worker        out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
3548*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out_cpu)
3549*da0073e9SAndroid Build Coastguard Worker
3550*da0073e9SAndroid Build Coastguard Worker    def test_view_slice(self):
3551*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/83995
3552*da0073e9SAndroid Build Coastguard Worker        NUM_SAMPLES = 60
3553*da0073e9SAndroid Build Coastguard Worker        s = (0, 1)
3554*da0073e9SAndroid Build Coastguard Worker
3555*da0073e9SAndroid Build Coastguard Worker        X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
3556*da0073e9SAndroid Build Coastguard Worker        X_mps = X.detach().clone().to("cpu")
3557*da0073e9SAndroid Build Coastguard Worker
3558*da0073e9SAndroid Build Coastguard Worker        idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
3559*da0073e9SAndroid Build Coastguard Worker        pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
3560*da0073e9SAndroid Build Coastguard Worker        idx_mps = idx.to("mps")
3561*da0073e9SAndroid Build Coastguard Worker        pts_mps = pts.to("mps")
3562*da0073e9SAndroid Build Coastguard Worker        pts[:, s] = idx
3563*da0073e9SAndroid Build Coastguard Worker        pts_mps[:, s] = idx_mps
3564*da0073e9SAndroid Build Coastguard Worker
3565*da0073e9SAndroid Build Coastguard Worker        actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
3566*da0073e9SAndroid Build Coastguard Worker        actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")
3567*da0073e9SAndroid Build Coastguard Worker
3568*da0073e9SAndroid Build Coastguard Worker        for i in range(NUM_SAMPLES):
3569*da0073e9SAndroid Build Coastguard Worker            for j in range(X.shape[1]):
3570*da0073e9SAndroid Build Coastguard Worker                actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
3571*da0073e9SAndroid Build Coastguard Worker                actual_pts[i, j] = X[pts[i, j], j]
3572*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
3573*da0073e9SAndroid Build Coastguard Worker
3574*da0073e9SAndroid Build Coastguard Worker    def test_slice_scatter(self):
3575*da0073e9SAndroid Build Coastguard Worker        shape = (4, 4)
3576*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randint(10, shape, device="mps")
3577*da0073e9SAndroid Build Coastguard Worker        tensor_before = tensor.clone()
3578*da0073e9SAndroid Build Coastguard Worker        torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor)
3579*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(tensor, tensor_before)
3580*da0073e9SAndroid Build Coastguard Worker
3581*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
3582*da0073e9SAndroid Build Coastguard Worker        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3583*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
3584*da0073e9SAndroid Build Coastguard Worker        mps_x = (torch.tensor(values, device='mps', dtype=torch.float))
3585*da0073e9SAndroid Build Coastguard Worker
3586*da0073e9SAndroid Build Coastguard Worker        cpu_slice1 = cpu_x[:2, :]
3587*da0073e9SAndroid Build Coastguard Worker        mps_slice1 = mps_x[:2, :]
3588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_slice1, mps_slice1)
3589*da0073e9SAndroid Build Coastguard Worker
3590*da0073e9SAndroid Build Coastguard Worker        cpu_slice2 = cpu_x[:, :1]
3591*da0073e9SAndroid Build Coastguard Worker        mps_slice2 = mps_x[:, :1]
3592*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_slice2, mps_slice2)
3593*da0073e9SAndroid Build Coastguard Worker
3594*da0073e9SAndroid Build Coastguard Worker        cpu_slice3 = cpu_x[1:2, :]
3595*da0073e9SAndroid Build Coastguard Worker        mps_slice3 = mps_x[1:2, :]
3596*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_slice3, mps_slice3.to('cpu'))
3597*da0073e9SAndroid Build Coastguard Worker
3598*da0073e9SAndroid Build Coastguard Worker        cpu_slice4 = cpu_x[1, :]
3599*da0073e9SAndroid Build Coastguard Worker        mps_slice4 = mps_x[1, :].to('cpu')
3600*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_slice4, mps_slice4)
3601*da0073e9SAndroid Build Coastguard Worker
    @parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
    def test_slice_view_api(self, torch_type: torch.dtype):
        """Chain view-producing ops (index, reshape, transpose, permute, expand)
        on MPS and check values and storage offsets against CPU at each step.
        """

        def helper(x_tensor, y_func, z_func, r_func=None):
            # x -> y -> z (-> r): each stage is compared against the CPU result;
            # the z stage additionally checks the storage offset matches.
            x_mps = x_tensor.detach().clone().to("mps")

            y = y_func(x_tensor)
            y_mps = y_func(x_mps)
            self.assertEqual(y, y_mps)

            z = z_func(y)
            z_mps = z_func(y_mps)
            self.assertEqual(z, z_mps)
            self.assertEqual(z.storage_offset(), z_mps.storage_offset())

            if r_func:
                r = r_func(z)
                r_mps = r_func(z_mps)
                self.assertEqual(r, r_mps)

        # Skip bfloat16 before MacOS15
        # (product_version is a module-level float of the macOS version —
        # defined elsewhere in this file.)
        if not (product_version < 15.0 and torch_type == torch.bfloat16):
            # Tests for previously encountered MPS bugs
            helper(
                torch.randn(4, 4, dtype=torch_type),
                lambda x: x[1],
                lambda y: y.reshape(2, 2),
                lambda z: z + 1
            )
            helper(
                torch.randn(2, 4, dtype=torch_type),
                lambda x: x[1],
                lambda y: y + torch.ones(4, device=y.device)
            )
            helper(
                torch.randn(4, 6, dtype=torch_type),
                lambda x: x[1],
                lambda y: y.reshape(3, 2).t(),
                lambda z: z + 1
            )
            helper(
                torch.arange(4, dtype=torch_type).resize(1, 2, 2),
                lambda x: x.permute(2, 0, 1),
                lambda y: y + 1
            )
            helper(
                torch.randn(4, 8, dtype=torch_type),
                lambda x: x.transpose(0, 1).reshape(-1),
                lambda y: y[:2],
                lambda z: z + 1
            )
            helper(
                torch.randn(1, dtype=torch_type),
                lambda x: x.expand(2, 3),
                lambda y: y + torch.ones(2, 3, device=y.device)
            )
3658*da0073e9SAndroid Build Coastguard Worker
3659*da0073e9SAndroid Build Coastguard Worker    def test_slice_reshape_contiguous(self):
3660*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 4)
3661*da0073e9SAndroid Build Coastguard Worker        x_mps = x.detach().clone().to("mps")
3662*da0073e9SAndroid Build Coastguard Worker
3663*da0073e9SAndroid Build Coastguard Worker        y = x[1]
3664*da0073e9SAndroid Build Coastguard Worker        y_mps = x_mps[1]
3665*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_mps)
3666*da0073e9SAndroid Build Coastguard Worker
3667*da0073e9SAndroid Build Coastguard Worker        z = y.reshape(2, 2)
3668*da0073e9SAndroid Build Coastguard Worker        z_mps = y_mps.reshape(2, 2)
3669*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z, z_mps)
3670*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.storage_offset(), z_mps.storage_offset())
3671*da0073e9SAndroid Build Coastguard Worker
3672*da0073e9SAndroid Build Coastguard Worker    def test_scalar_from_slice_unary(self):
3673*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/82543
3674*da0073e9SAndroid Build Coastguard Worker        tensor_list = torch.tensor([1.0, 1.2], device="mps")
3675*da0073e9SAndroid Build Coastguard Worker
3676*da0073e9SAndroid Build Coastguard Worker        for scalar in tensor_list:
3677*da0073e9SAndroid Build Coastguard Worker            r_mps = torch.ceil(scalar)
3678*da0073e9SAndroid Build Coastguard Worker            r_cpu = torch.ceil(scalar.to("cpu"))
3679*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r_mps.cpu(), r_cpu)
3680*da0073e9SAndroid Build Coastguard Worker
3681*da0073e9SAndroid Build Coastguard Worker    def test_scalar_from_slice_binary(self):
3682*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/82543
3683*da0073e9SAndroid Build Coastguard Worker        def helper(binary_op):
3684*da0073e9SAndroid Build Coastguard Worker            tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
3685*da0073e9SAndroid Build Coastguard Worker
3686*da0073e9SAndroid Build Coastguard Worker            for scalar in tensor_list:
3687*da0073e9SAndroid Build Coastguard Worker                r_mps = binary_op(scalar, 1.0)
3688*da0073e9SAndroid Build Coastguard Worker                r_cpu = binary_op(scalar.cpu(), 1.0)
3689*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(r_mps.cpu(), r_cpu)
3690*da0073e9SAndroid Build Coastguard Worker        helper(torch.sub)
3691*da0073e9SAndroid Build Coastguard Worker        helper(torch.add)
3692*da0073e9SAndroid Build Coastguard Worker        helper(torch.not_equal)
3693*da0073e9SAndroid Build Coastguard Worker        helper(torch.eq)
3694*da0073e9SAndroid Build Coastguard Worker
3695*da0073e9SAndroid Build Coastguard Worker    def test_slice_contiguous_view(self):
3696*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/77750
3697*da0073e9SAndroid Build Coastguard Worker
3698*da0073e9SAndroid Build Coastguard Worker        def helper(operator):
3699*da0073e9SAndroid Build Coastguard Worker            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3700*da0073e9SAndroid Build Coastguard Worker            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")
3701*da0073e9SAndroid Build Coastguard Worker
3702*da0073e9SAndroid Build Coastguard Worker            # contiguous view
3703*da0073e9SAndroid Build Coastguard Worker            x_mps = t_mps[2:]  # 3, 4
3704*da0073e9SAndroid Build Coastguard Worker            y_mps = t_mps[:2]  # 1, 2
3705*da0073e9SAndroid Build Coastguard Worker
3706*da0073e9SAndroid Build Coastguard Worker            x_cpu = t_cpu[2:]
3707*da0073e9SAndroid Build Coastguard Worker            y_cpu = t_cpu[:2]
3708*da0073e9SAndroid Build Coastguard Worker
3709*da0073e9SAndroid Build Coastguard Worker            res_mps = res_cpu = None
3710*da0073e9SAndroid Build Coastguard Worker            if operator == "<=":
3711*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps <= y_mps
3712*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu <= y_cpu
3713*da0073e9SAndroid Build Coastguard Worker            elif operator == "<":
3714*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps < y_mps
3715*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu < y_cpu
3716*da0073e9SAndroid Build Coastguard Worker            elif operator == ">=":
3717*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps >= y_mps
3718*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu >= y_cpu
3719*da0073e9SAndroid Build Coastguard Worker            elif operator == ">":
3720*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps >= y_mps
3721*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu >= y_cpu
3722*da0073e9SAndroid Build Coastguard Worker            elif operator == "==":
3723*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps == y_mps
3724*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu == y_cpu
3725*da0073e9SAndroid Build Coastguard Worker            elif operator == "!=":
3726*da0073e9SAndroid Build Coastguard Worker                res_mps = x_mps != y_mps
3727*da0073e9SAndroid Build Coastguard Worker                res_cpu = x_cpu != y_cpu
3728*da0073e9SAndroid Build Coastguard Worker            elif operator == "stack":
3729*da0073e9SAndroid Build Coastguard Worker                res_mps = torch.stack((y_mps, x_mps), dim=-1)
3730*da0073e9SAndroid Build Coastguard Worker                res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
3731*da0073e9SAndroid Build Coastguard Worker
3732*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_mps, res_cpu)
3733*da0073e9SAndroid Build Coastguard Worker
3734*da0073e9SAndroid Build Coastguard Worker        for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
3735*da0073e9SAndroid Build Coastguard Worker            helper(op)
3736*da0073e9SAndroid Build Coastguard Worker
3737*da0073e9SAndroid Build Coastguard Worker    def test_slice_of_slice(self):
3738*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.5, 0.5], device="cpu")
3739*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.tensor([0.5, 0.5], device="mps")
3740*da0073e9SAndroid Build Coastguard Worker
3741*da0073e9SAndroid Build Coastguard Worker        tensor = x[1][None]
3742*da0073e9SAndroid Build Coastguard Worker        tensor_mps = x_mps[1][None]
3743*da0073e9SAndroid Build Coastguard Worker
3744*da0073e9SAndroid Build Coastguard Worker        res = tensor.ne(0)
3745*da0073e9SAndroid Build Coastguard Worker        res_mps = tensor_mps.ne(0)
3746*da0073e9SAndroid Build Coastguard Worker
3747*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_mps)
3748*da0073e9SAndroid Build Coastguard Worker
3749*da0073e9SAndroid Build Coastguard Worker    def test_index_storage_offset(self):
3750*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/78107
3751*da0073e9SAndroid Build Coastguard Worker
3752*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([8.2670e-01, -1.0293e+00])
3753*da0073e9SAndroid Build Coastguard Worker        b_cpu = a[0]
3754*da0073e9SAndroid Build Coastguard Worker        c_cpu = a[1]
3755*da0073e9SAndroid Build Coastguard Worker
3756*da0073e9SAndroid Build Coastguard Worker        # both 'b' and 'c' are views of 'a'
3757*da0073e9SAndroid Build Coastguard Worker        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
3758*da0073e9SAndroid Build Coastguard Worker        # when copying from 'cpu' to 'mps', c will have a storage_offset of 1 which needs to be taking into account,
3759*da0073e9SAndroid Build Coastguard Worker        # otherwise it ends with same value as 'b'
3760*da0073e9SAndroid Build Coastguard Worker        b = b_cpu.to('mps')
3761*da0073e9SAndroid Build Coastguard Worker        c = c_cpu.to('mps')
3762*da0073e9SAndroid Build Coastguard Worker
3763*da0073e9SAndroid Build Coastguard Worker        res_mps = b > c
3764*da0073e9SAndroid Build Coastguard Worker        res_cpu = b_cpu > c_cpu
3765*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
3766*da0073e9SAndroid Build Coastguard Worker
3767*da0073e9SAndroid Build Coastguard Worker        res_mps = c > b
3768*da0073e9SAndroid Build Coastguard Worker        res_cpu = c_cpu > b_cpu
3769*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
3770*da0073e9SAndroid Build Coastguard Worker
3771*da0073e9SAndroid Build Coastguard Worker    def test_flatten(self):
3772*da0073e9SAndroid Build Coastguard Worker        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
3773*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
3774*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
3775*da0073e9SAndroid Build Coastguard Worker
3776*da0073e9SAndroid Build Coastguard Worker        cpu_flatten1 = cpu_x.flatten()
3777*da0073e9SAndroid Build Coastguard Worker        mps_flatten1 = mps_x.flatten().to('cpu')
3778*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_flatten1, mps_flatten1)
3779*da0073e9SAndroid Build Coastguard Worker
3780*da0073e9SAndroid Build Coastguard Worker        cpu_flatten2 = cpu_x.flatten(start_dim=1)
3781*da0073e9SAndroid Build Coastguard Worker        mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu')
3782*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_flatten2, mps_flatten2)
3783*da0073e9SAndroid Build Coastguard Worker
3784*da0073e9SAndroid Build Coastguard Worker        cpu_flatten3 = cpu_x.flatten(end_dim=1)
3785*da0073e9SAndroid Build Coastguard Worker        mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu')
3786*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_flatten3, mps_flatten3)
3787*da0073e9SAndroid Build Coastguard Worker
3788*da0073e9SAndroid Build Coastguard Worker    # Test repeat
3789*da0073e9SAndroid Build Coastguard Worker    def test_repeat(self):
3790*da0073e9SAndroid Build Coastguard Worker        def helper(shape, repeats):
3791*da0073e9SAndroid Build Coastguard Worker
3792*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3793*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
3794*da0073e9SAndroid Build Coastguard Worker
3795*da0073e9SAndroid Build Coastguard Worker            y = x.repeat(repeats)
3796*da0073e9SAndroid Build Coastguard Worker            ref_y = cpu_x.repeat(repeats)
3797*da0073e9SAndroid Build Coastguard Worker
3798*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(ref_y.shape)
3799*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
3800*da0073e9SAndroid Build Coastguard Worker
3801*da0073e9SAndroid Build Coastguard Worker            y.backward(gradient=grad)
3802*da0073e9SAndroid Build Coastguard Worker            ref_y.backward(gradient=cpu_grad)
3803*da0073e9SAndroid Build Coastguard Worker
3804*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
3805*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
3806*da0073e9SAndroid Build Coastguard Worker
3807*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5), (2, 3, 4, 5))
3808*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4), (4, 3, 2, 5, 7, 2))
3809*da0073e9SAndroid Build Coastguard Worker        helper((3, 4, 5), (2, 3, 4, 5))
3810*da0073e9SAndroid Build Coastguard Worker        helper((3, 4, 5), (2, 2, 2))
3811*da0073e9SAndroid Build Coastguard Worker
3812*da0073e9SAndroid Build Coastguard Worker    def test_torch_repeat_interleave(self, device="mps"):
3813*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([[1, 2], [3, 4]], device=device)
3814*da0073e9SAndroid Build Coastguard Worker        # exercise single argument function signature
3815*da0073e9SAndroid Build Coastguard Worker        temp = y.repeat_interleave(2)
3816*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.Size([8]), temp.size())
3817*da0073e9SAndroid Build Coastguard Worker
3818*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.int, torch.long]:
3819*da0073e9SAndroid Build Coastguard Worker            lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
3820*da0073e9SAndroid Build Coastguard Worker            output_size = torch.sum(lengths)
3821*da0073e9SAndroid Build Coastguard Worker            a = torch.repeat_interleave(
3822*da0073e9SAndroid Build Coastguard Worker                y,
3823*da0073e9SAndroid Build Coastguard Worker                lengths,
3824*da0073e9SAndroid Build Coastguard Worker                dim=0,
3825*da0073e9SAndroid Build Coastguard Worker            )
3826*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.dtype, y.dtype)
3827*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.size(), torch.Size([3, 2]))
3828*da0073e9SAndroid Build Coastguard Worker
3829*da0073e9SAndroid Build Coastguard Worker            a_with_output = torch.repeat_interleave(
3830*da0073e9SAndroid Build Coastguard Worker                y,
3831*da0073e9SAndroid Build Coastguard Worker                lengths,
3832*da0073e9SAndroid Build Coastguard Worker                dim=0,
3833*da0073e9SAndroid Build Coastguard Worker                output_size=output_size,
3834*da0073e9SAndroid Build Coastguard Worker            )
3835*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_with_output.dtype, y.dtype)
3836*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
3837*da0073e9SAndroid Build Coastguard Worker
3838*da0073e9SAndroid Build Coastguard Worker    def test_repeat_interleave(self, device="mps"):
3839*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0, 1, 2, 3], device=device)
3840*da0073e9SAndroid Build Coastguard Worker        expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
3841*da0073e9SAndroid Build Coastguard Worker        # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32
3842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=product_version >= 13.3)
3843*da0073e9SAndroid Build Coastguard Worker
3844*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3845*da0073e9SAndroid Build Coastguard Worker            torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))
3846*da0073e9SAndroid Build Coastguard Worker
3847*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3848*da0073e9SAndroid Build Coastguard Worker            torch.repeat_interleave(torch.arange(4.0, device=device))
3849*da0073e9SAndroid Build Coastguard Worker
3850*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3851*da0073e9SAndroid Build Coastguard Worker            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))
3852*da0073e9SAndroid Build Coastguard Worker
3853*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([[1, 2], [3, 4]], device=device)
3854*da0073e9SAndroid Build Coastguard Worker
3855*da0073e9SAndroid Build Coastguard Worker        y1_v1 = torch.repeat_interleave(y, 2)
3856*da0073e9SAndroid Build Coastguard Worker        y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
3857*da0073e9SAndroid Build Coastguard Worker        y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
3858*da0073e9SAndroid Build Coastguard Worker        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
3859*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1_v1, y1_expect)
3860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1_v2, y1_expect)
3861*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1_v3, y1_expect)
3862*da0073e9SAndroid Build Coastguard Worker
3863*da0073e9SAndroid Build Coastguard Worker        y2 = torch.repeat_interleave(y, 3, dim=1)
3864*da0073e9SAndroid Build Coastguard Worker        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
3865*da0073e9SAndroid Build Coastguard Worker                                  [3, 3, 3, 4, 4, 4]], device=device)
3866*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y2, y2_expect)
3867*da0073e9SAndroid Build Coastguard Worker
3868*da0073e9SAndroid Build Coastguard Worker        y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
3869*da0073e9SAndroid Build Coastguard Worker        y3_expect = torch.tensor([[1, 2],
3870*da0073e9SAndroid Build Coastguard Worker                                  [3, 4],
3871*da0073e9SAndroid Build Coastguard Worker                                  [3, 4]], device=device)
3872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y3, y3_expect)
3873*da0073e9SAndroid Build Coastguard Worker
3874*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3875*da0073e9SAndroid Build Coastguard Worker            torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)
3876*da0073e9SAndroid Build Coastguard Worker
3877*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3878*da0073e9SAndroid Build Coastguard Worker            torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)
3879*da0073e9SAndroid Build Coastguard Worker
3880*da0073e9SAndroid Build Coastguard Worker        # test zero sized dimension
3881*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros((5, 0), device=device)
3882*da0073e9SAndroid Build Coastguard Worker        y = torch.repeat_interleave(x, repeats=3, dim=1)
3883*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.new_zeros(5, 0, device=device))
3884*da0073e9SAndroid Build Coastguard Worker
3885*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([], dtype=torch.int64, device=device)
3886*da0073e9SAndroid Build Coastguard Worker        y = torch.repeat_interleave(x, x)
3887*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x)
3888*da0073e9SAndroid Build Coastguard Worker
3889*da0073e9SAndroid Build Coastguard Worker    def test_repeat_interleave_simple(self):
3890*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
3891*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(shape, dtype=dtype, device="mps")
3892*da0073e9SAndroid Build Coastguard Worker            x_cpu = x.detach().clone().cpu()
3893*da0073e9SAndroid Build Coastguard Worker
3894*da0073e9SAndroid Build Coastguard Worker            num_repeats_cpu = num_repeats.detach().clone().cpu()
3895*da0073e9SAndroid Build Coastguard Worker
3896*da0073e9SAndroid Build Coastguard Worker            repeats = torch.repeat_interleave(x, num_repeats, dim)
3897*da0073e9SAndroid Build Coastguard Worker            repeats_cpu = torch.repeat_interleave(x_cpu, num_repeats_cpu, dim)
3898*da0073e9SAndroid Build Coastguard Worker
3899*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(repeats, repeats_cpu)
3900*da0073e9SAndroid Build Coastguard Worker        helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
3901*da0073e9SAndroid Build Coastguard Worker        helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
3902*da0073e9SAndroid Build Coastguard Worker        helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
3903*da0073e9SAndroid Build Coastguard Worker        helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
3904*da0073e9SAndroid Build Coastguard Worker        helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
3905*da0073e9SAndroid Build Coastguard Worker
3906*da0073e9SAndroid Build Coastguard Worker    def test_count_nonzero(self):
3907*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
3908*da0073e9SAndroid Build Coastguard Worker            n = [
3909*da0073e9SAndroid Build Coastguard Worker                [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
3910*da0073e9SAndroid Build Coastguard Worker                [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
3911*da0073e9SAndroid Build Coastguard Worker            ]
3912*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(n, dtype=dtype)
3913*da0073e9SAndroid Build Coastguard Worker            mps_x = torch.tensor(n, dtype=dtype).to('mps')
3914*da0073e9SAndroid Build Coastguard Worker
3915*da0073e9SAndroid Build Coastguard Worker            # All non-zeros
3916*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3917*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(cpu_x),
3918*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(mps_x)
3919*da0073e9SAndroid Build Coastguard Worker            )
3920*da0073e9SAndroid Build Coastguard Worker
3921*da0073e9SAndroid Build Coastguard Worker            # dim=1
3922*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3923*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(cpu_x, dim=1),
3924*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(mps_x, dim=1)
3925*da0073e9SAndroid Build Coastguard Worker            )
3926*da0073e9SAndroid Build Coastguard Worker
3927*da0073e9SAndroid Build Coastguard Worker            # dim=(0, 1)
3928*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
3929*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(cpu_x, dim=(0, 1)),
3930*da0073e9SAndroid Build Coastguard Worker                torch.count_nonzero(mps_x, dim=(0, 1))
3931*da0073e9SAndroid Build Coastguard Worker            )
3932*da0073e9SAndroid Build Coastguard Worker        helper(torch.int32)
3933*da0073e9SAndroid Build Coastguard Worker        helper(torch.int64)
3934*da0073e9SAndroid Build Coastguard Worker        helper(torch.float16)
3935*da0073e9SAndroid Build Coastguard Worker        helper(torch.float32)
3936*da0073e9SAndroid Build Coastguard Worker
3937*da0073e9SAndroid Build Coastguard Worker    def _test_module_empty_input(self, module, inp, check_size=True):
3938*da0073e9SAndroid Build Coastguard Worker        inp.requires_grad_(True)
3939*da0073e9SAndroid Build Coastguard Worker        out = module(inp)
3940*da0073e9SAndroid Build Coastguard Worker        gO = torch.rand_like(out)
3941*da0073e9SAndroid Build Coastguard Worker        out.backward(gO)
3942*da0073e9SAndroid Build Coastguard Worker        if check_size:
3943*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.size(), inp.size())
3944*da0073e9SAndroid Build Coastguard Worker        for p in module.parameters():
3945*da0073e9SAndroid Build Coastguard Worker            if p.requires_grad:
3946*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p.grad, torch.zeros_like(p.grad))
3947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inp.grad, torch.zeros_like(inp))
3948*da0073e9SAndroid Build Coastguard Worker
3949*da0073e9SAndroid Build Coastguard Worker    # Test dtype casting, with and without simultaneous device change
3950*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
3951*da0073e9SAndroid Build Coastguard Worker        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
3952*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
3953*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
3954*da0073e9SAndroid Build Coastguard Worker
3955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
3956*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
3957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.float(), mps_x.float().cpu())
3958*da0073e9SAndroid Build Coastguard Worker
3959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
3960*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(1, dtype=torch.int32))
3961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
3962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
3963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
3964*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(1, dtype=torch.int32))
3965*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
3966*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(1.0))
3967*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
3968*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(4, dtype=torch.int32))
3969*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
3970*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(4, dtype=torch.int32))
3971*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
3972*da0073e9SAndroid Build Coastguard Worker                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
3973*da0073e9SAndroid Build Coastguard Worker        # Cast int8 and uint8 to float and compare results
3974*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/80009 for more details
3975*da0073e9SAndroid Build Coastguard Worker        cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
3976*da0073e9SAndroid Build Coastguard Worker        cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.uint8)
3977*da0073e9SAndroid Build Coastguard Worker        for x_cpu in [cpu_byte, cpu_char]:
3978*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.to('mps')
3979*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))
3980*da0073e9SAndroid Build Coastguard Worker
3981*da0073e9SAndroid Build Coastguard Worker
3982*da0073e9SAndroid Build Coastguard Worker    def test_setitem_scalar(self) -> None:
3983*da0073e9SAndroid Build Coastguard Worker        device = 'mps'
3984*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.int32, torch.float32, torch.int64]:
3985*da0073e9SAndroid Build Coastguard Worker            for i in range(3, 6):
3986*da0073e9SAndroid Build Coastguard Worker                for j in range(3, 6):
3987*da0073e9SAndroid Build Coastguard Worker                    t = torch.zeros(i, j, dtype=dtype, device=device)
3988*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t.sum(), 0)
3989*da0073e9SAndroid Build Coastguard Worker                    t[1, 1] = 1
3990*da0073e9SAndroid Build Coastguard Worker                    t[2, 1] = j
3991*da0073e9SAndroid Build Coastguard Worker                    t[1, 2] = i
3992*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t[1, 1], 1)
3993*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t[1, 2], i)
3994*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t[2, 1], j)
3995*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(t.sum(), 1 + i + j)
3996*da0073e9SAndroid Build Coastguard Worker
3997*da0073e9SAndroid Build Coastguard Worker    def test_stride_of_strides(self) -> None:
3998*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(32, 1, device='mps')
3999*da0073e9SAndroid Build Coastguard Worker        y = x.as_strided(size=(32, 2), stride=(1, 0))
4000*da0073e9SAndroid Build Coastguard Worker        # Casting stride of strided tensor to CPU use to crash with "buffer is not large enough." assert
4001*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
4002*da0073e9SAndroid Build Coastguard Worker        z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
4003*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker    def test_type_casting(self):
4006*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/81567
4007*da0073e9SAndroid Build Coastguard Worker        def helper(data, to_dtype):
4008*da0073e9SAndroid Build Coastguard Worker            a_cpu = torch.tensor(data)
4009*da0073e9SAndroid Build Coastguard Worker            a_mps = a_cpu.to(torch.device('mps'))
4010*da0073e9SAndroid Build Coastguard Worker
4011*da0073e9SAndroid Build Coastguard Worker            res_cpu = a_cpu.type(to_dtype)
4012*da0073e9SAndroid Build Coastguard Worker            res_mps = a_mps.type(to_dtype)
4013*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_cpu, res_mps)
4014*da0073e9SAndroid Build Coastguard Worker
4015*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
4016*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
4017*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
4018*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
4019*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
4020*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
4021*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)
4022*da0073e9SAndroid Build Coastguard Worker
4023*da0073e9SAndroid Build Coastguard Worker    def test_to_casting(self):
4024*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/81567
4025*da0073e9SAndroid Build Coastguard Worker        def helper(data, to_dtype):
4026*da0073e9SAndroid Build Coastguard Worker            a_cpu = torch.tensor(data)
4027*da0073e9SAndroid Build Coastguard Worker            a_mps = a_cpu.to(torch.device('mps'))
4028*da0073e9SAndroid Build Coastguard Worker
4029*da0073e9SAndroid Build Coastguard Worker            res_cpu = a_cpu.to(to_dtype)
4030*da0073e9SAndroid Build Coastguard Worker            res_mps = a_mps.to(to_dtype)
4031*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_cpu, res_mps)
4032*da0073e9SAndroid Build Coastguard Worker
4033*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
4034*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.float)
4035*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
4036*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.short)
4037*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.half)
4038*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
4039*da0073e9SAndroid Build Coastguard Worker        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)
4040*da0073e9SAndroid Build Coastguard Worker
4041*da0073e9SAndroid Build Coastguard Worker    def test_storage_offset_greater_than_src_nbytes(self):
4042*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/80844
4043*da0073e9SAndroid Build Coastguard Worker        n_tensors = 100
4044*da0073e9SAndroid Build Coastguard Worker        n_tensor_elems = 784
4045*da0073e9SAndroid Build Coastguard Worker        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)
4046*da0073e9SAndroid Build Coastguard Worker
4047*da0073e9SAndroid Build Coastguard Worker        tensor_list = []
4048*da0073e9SAndroid Build Coastguard Worker        for i in range(0, n_tensors - 1):
4049*da0073e9SAndroid Build Coastguard Worker            # create a list of contiguous view tensors (view tensor created by the slice op)
4050*da0073e9SAndroid Build Coastguard Worker            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
4051*da0073e9SAndroid Build Coastguard Worker            tensor_list.append(t)
4052*da0073e9SAndroid Build Coastguard Worker
4053*da0073e9SAndroid Build Coastguard Worker        for i in range(0, n_tensors - 1):
4054*da0073e9SAndroid Build Coastguard Worker            t = tensor_list[i].view(1, n_tensor_elems)
4055*da0073e9SAndroid Build Coastguard Worker            t_mps = t.to("mps")
4056*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, t_mps.cpu(), f"i={i}")
4057*da0073e9SAndroid Build Coastguard Worker
4058*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/82427
4059*da0073e9SAndroid Build Coastguard Worker    # and https://github.com/pytorch/pytorch/issues/83692
4060*da0073e9SAndroid Build Coastguard Worker    def test_full_bugs(self):
4061*da0073e9SAndroid Build Coastguard Worker        # Test should not crash
4062*da0073e9SAndroid Build Coastguard Worker        x = torch.full((3, 3), True, device='mps')
4063*da0073e9SAndroid Build Coastguard Worker        # torch.full should work for uint8
4064*da0073e9SAndroid Build Coastguard Worker        y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
4065*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
4066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_mps, y_cpu)
4067*da0073e9SAndroid Build Coastguard Worker
4068*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
4069*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/84995
4070*da0073e9SAndroid Build Coastguard Worker    def test_div_bugs(self):
4071*da0073e9SAndroid Build Coastguard Worker        for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
4072*da0073e9SAndroid Build Coastguard Worker            if dtype != torch.int64:
4073*da0073e9SAndroid Build Coastguard Worker                x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
4074*da0073e9SAndroid Build Coastguard Worker                y = torch.div(x, 101, rounding_mode=mode)
4075*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y.sum(), 0)
4076*da0073e9SAndroid Build Coastguard Worker
4077*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/82663
4078*da0073e9SAndroid Build Coastguard Worker    def test_bool_expand(self):
4079*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
4080*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
4081*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
4082*da0073e9SAndroid Build Coastguard Worker
4083*da0073e9SAndroid Build Coastguard Worker    def test_int_expand(self):
4084*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
4085*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([0, 1], dtype=torch.int8, device='mps')
4086*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
4087*da0073e9SAndroid Build Coastguard Worker
4088*da0073e9SAndroid Build Coastguard Worker    # Empty unary op should return tensor of the same size
4089*da0073e9SAndroid Build Coastguard Worker    def test_empty_neg(self):
4090*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[]], device='mps')
4091*da0073e9SAndroid Build Coastguard Worker        y = -x
4092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
4093*da0073e9SAndroid Build Coastguard Worker
4094*da0073e9SAndroid Build Coastguard Worker    def _test_unique_scalar_empty(self, dtype, device, f):
4095*da0073e9SAndroid Build Coastguard Worker        # test scalar
4096*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0, dtype=dtype, device=device)
4097*da0073e9SAndroid Build Coastguard Worker        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
4098*da0073e9SAndroid Build Coastguard Worker        expected_unique = torch.tensor([0], dtype=dtype, device=device)
4099*da0073e9SAndroid Build Coastguard Worker        expected_inverse = torch.tensor(0, device=device)
4100*da0073e9SAndroid Build Coastguard Worker        expected_counts = torch.tensor([1], device=device)
4101*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unique, expected_unique)
4102*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inverse, expected_inverse)
4103*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts, expected_counts)
4104*da0073e9SAndroid Build Coastguard Worker
4105*da0073e9SAndroid Build Coastguard Worker        # test zero sized tensor
4106*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
4107*da0073e9SAndroid Build Coastguard Worker        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
4108*da0073e9SAndroid Build Coastguard Worker        expected_unique = torch.tensor([], dtype=dtype, device=device)
4109*da0073e9SAndroid Build Coastguard Worker        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
4110*da0073e9SAndroid Build Coastguard Worker        expected_counts = torch.tensor([], dtype=torch.long, device=device)
4111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unique, expected_unique)
4112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inverse, expected_inverse)
4113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counts, expected_counts)
4114*da0073e9SAndroid Build Coastguard Worker
4115*da0073e9SAndroid Build Coastguard Worker    def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
4116*da0073e9SAndroid Build Coastguard Worker        def ensure_tuple(x):
4117*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, torch.Tensor):
4118*da0073e9SAndroid Build Coastguard Worker                return (x,)
4119*da0073e9SAndroid Build Coastguard Worker            return x
4120*da0073e9SAndroid Build Coastguard Worker
4121*da0073e9SAndroid Build Coastguard Worker        for return_inverse in [True, False]:
4122*da0073e9SAndroid Build Coastguard Worker            for return_counts in [True, False]:
4123*da0073e9SAndroid Build Coastguard Worker                # test with expected
4124*da0073e9SAndroid Build Coastguard Worker                ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
4125*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
4126*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_unique, ret[0])
4127*da0073e9SAndroid Build Coastguard Worker                if return_inverse:
4128*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_inverse, ret[1])
4129*da0073e9SAndroid Build Coastguard Worker                if return_counts:
4130*da0073e9SAndroid Build Coastguard Worker                    count_index = 1 + int(return_inverse)
4131*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_counts, ret[count_index])
4132*da0073e9SAndroid Build Coastguard Worker
4133*da0073e9SAndroid Build Coastguard Worker                # tests per-element unique on a higher rank tensor.
4134*da0073e9SAndroid Build Coastguard Worker                y = x.view(additional_shape)
4135*da0073e9SAndroid Build Coastguard Worker                y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
4136*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_unique, y_unique)
4137*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
4138*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_counts, y_counts)
4139*da0073e9SAndroid Build Coastguard Worker
4140*da0073e9SAndroid Build Coastguard Worker    def test_unique_all_dtypes(self, device="mps"):
4141*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
4142*da0073e9SAndroid Build Coastguard Worker            def ensure_tuple(x):
4143*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, torch.Tensor):
4144*da0073e9SAndroid Build Coastguard Worker                    return (x,)
4145*da0073e9SAndroid Build Coastguard Worker                return x
4146*da0073e9SAndroid Build Coastguard Worker
4147*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bool:
4148*da0073e9SAndroid Build Coastguard Worker                x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
4149*da0073e9SAndroid Build Coastguard Worker                expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
4150*da0073e9SAndroid Build Coastguard Worker                expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
4151*da0073e9SAndroid Build Coastguard Worker                expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
4152*da0073e9SAndroid Build Coastguard Worker            else:
4153*da0073e9SAndroid Build Coastguard Worker                x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
4154*da0073e9SAndroid Build Coastguard Worker                expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
4155*da0073e9SAndroid Build Coastguard Worker                expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
4156*da0073e9SAndroid Build Coastguard Worker                expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)
4157*da0073e9SAndroid Build Coastguard Worker
4158*da0073e9SAndroid Build Coastguard Worker            # test sorted unique
4159*da0073e9SAndroid Build Coastguard Worker            fs = (
4160*da0073e9SAndroid Build Coastguard Worker                lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
4161*da0073e9SAndroid Build Coastguard Worker                lambda x, **kwargs: x.unique(sorted=True, **kwargs),
4162*da0073e9SAndroid Build Coastguard Worker            )
4163*da0073e9SAndroid Build Coastguard Worker            x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
4164*da0073e9SAndroid Build Coastguard Worker            xs = (x, x_sliced)
4165*da0073e9SAndroid Build Coastguard Worker            for f, x in product(fs, xs):
4166*da0073e9SAndroid Build Coastguard Worker                self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
4167*da0073e9SAndroid Build Coastguard Worker                self._test_unique_scalar_empty(dtype, device, f)
4168*da0073e9SAndroid Build Coastguard Worker
4169*da0073e9SAndroid Build Coastguard Worker            # test unsorted unique
4170*da0073e9SAndroid Build Coastguard Worker            fs = (
4171*da0073e9SAndroid Build Coastguard Worker                lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
4172*da0073e9SAndroid Build Coastguard Worker                lambda x, **kwargs: x.unique(sorted=False, **kwargs)
4173*da0073e9SAndroid Build Coastguard Worker            )
4174*da0073e9SAndroid Build Coastguard Worker            for f, x in product(fs, xs):
4175*da0073e9SAndroid Build Coastguard Worker                self._test_unique_scalar_empty(dtype, device, f)
4176*da0073e9SAndroid Build Coastguard Worker                for return_inverse, return_counts in product((True, False), repeat=2):
4177*da0073e9SAndroid Build Coastguard Worker                    ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
4178*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
4179*da0073e9SAndroid Build Coastguard Worker                    x_list = x.tolist()
4180*da0073e9SAndroid Build Coastguard Worker                    x_unique_list = ret[0].tolist()
4181*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
4182*da0073e9SAndroid Build Coastguard Worker                    if return_inverse:
4183*da0073e9SAndroid Build Coastguard Worker                        x_inverse_list = ret[1].tolist()
4184*da0073e9SAndroid Build Coastguard Worker                        for i, j in enumerate(x_inverse_list):
4185*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(x_list[i], x_unique_list[j])
4186*da0073e9SAndroid Build Coastguard Worker                    if return_counts:
4187*da0073e9SAndroid Build Coastguard Worker                        count_index = 1 + int(return_inverse)
4188*da0073e9SAndroid Build Coastguard Worker                        x_counts_list = ret[count_index].tolist()
4189*da0073e9SAndroid Build Coastguard Worker                        for i, j in zip(x_unique_list, x_counts_list):
4190*da0073e9SAndroid Build Coastguard Worker                            count = 0
4191*da0073e9SAndroid Build Coastguard Worker                            for k in x_list:
4192*da0073e9SAndroid Build Coastguard Worker                                if k == i:
4193*da0073e9SAndroid Build Coastguard Worker                                    count += 1
4194*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(j, count)
4195*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]]
4196*da0073e9SAndroid Build Coastguard Worker
4197*da0073e9SAndroid Build Coastguard Worker    def test_unique(self):
4198*da0073e9SAndroid Build Coastguard Worker        def helper(x, return_inverse, return_counts):
4199*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
4200*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
4201*da0073e9SAndroid Build Coastguard Worker
4202*da0073e9SAndroid Build Coastguard Worker            result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts)
4203*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts)
4204*da0073e9SAndroid Build Coastguard Worker
4205*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
4206*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([1, 2, 4, 2, 1]), False, False)
4207*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), False, False)
4208*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), True, False)
4209*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), False, True)
4210*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), True, True)
4211*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (1, )), True, True)
4212*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (0, )), True, True)
4213*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/104879
4214*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(2, device="mps")
4215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.reshape(1, 1, 2).unique(), x)
4216*da0073e9SAndroid Build Coastguard Worker
4217*da0073e9SAndroid Build Coastguard Worker    def test_unique_consecutive(self):
4218*da0073e9SAndroid Build Coastguard Worker        def helper(x, dim, return_inverse, return_counts):
4219*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
4220*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
4221*da0073e9SAndroid Build Coastguard Worker
4222*da0073e9SAndroid Build Coastguard Worker            result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)
4223*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)
4224*da0073e9SAndroid Build Coastguard Worker
4225*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
4226*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
4227*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), 0, False, False)
4228*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), 0, True, False)
4229*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), 0, False, True)
4230*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), 0, True, True)
4231*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (10, )), 0, True, True)
4232*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (1, )), 0, True, True)
4233*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(3, (0, )), 0, True, True)
4234*da0073e9SAndroid Build Coastguard Worker
4235*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False)
4236*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True)
4237*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (20, 2)), 0, True, True)
4238*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (1, 2)), 0, True, True)
4239*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (0, 2)), 0, True, True)
4240*da0073e9SAndroid Build Coastguard Worker
4241*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False)
4242*da0073e9SAndroid Build Coastguard Worker        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True)
4243*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (2, 20)), 1, True, True)
4244*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (2, 1)), 1, True, True)
4245*da0073e9SAndroid Build Coastguard Worker        helper(torch.randint(2, (2, 0)), 1, True, True)
4246*da0073e9SAndroid Build Coastguard Worker
4247*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/85675
4248*da0073e9SAndroid Build Coastguard Worker    def test_cat_non_contiguous(self):
4249*da0073e9SAndroid Build Coastguard Worker        def rotate_subset(data, dim):
4250*da0073e9SAndroid Build Coastguard Worker            x1 = data[:, :, :2, :]
4251*da0073e9SAndroid Build Coastguard Worker            x2 = data[:, :, 2:, :]
4252*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x1.is_contiguous())
4253*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x2.is_contiguous())
4254*da0073e9SAndroid Build Coastguard Worker            return torch.concat((x1, x2), dim=dim)
4255*da0073e9SAndroid Build Coastguard Worker        for dtype in MPS_DTYPES:
4256*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.bool:
4257*da0073e9SAndroid Build Coastguard Worker                continue
4258*da0073e9SAndroid Build Coastguard Worker            data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
4259*da0073e9SAndroid Build Coastguard Worker            data = data.to(memory_format=torch.channels_last)
4260*da0073e9SAndroid Build Coastguard Worker            mps_data = data.to("mps")
4261*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(data, mps_data)
4262*da0073e9SAndroid Build Coastguard Worker            for dim in range(data.dim()):
4263*da0073e9SAndroid Build Coastguard Worker                cpu_result = rotate_subset(data, dim)
4264*da0073e9SAndroid Build Coastguard Worker                mps_result = rotate_subset(mps_data, dim)
4265*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_result, mps_result.to("cpu"))
4266*da0073e9SAndroid Build Coastguard Worker                # TODO: enable memory format test
4267*da0073e9SAndroid Build Coastguard Worker                # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
4268*da0073e9SAndroid Build Coastguard Worker
4269*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/85967
4270*da0073e9SAndroid Build Coastguard Worker    def test_from_numpy_non_contiguous(self):
4271*da0073e9SAndroid Build Coastguard Worker        a = np.arange(9).reshape(3, 3)[:, :2]
4272*da0073e9SAndroid Build Coastguard Worker        t_cpu = torch.tensor(a, device="cpu")
4273*da0073e9SAndroid Build Coastguard Worker        t_mps = torch.tensor(a, device="mps")
4274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_cpu, t_mps.to("cpu"))
4275*da0073e9SAndroid Build Coastguard Worker
4276*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/86954
4277*da0073e9SAndroid Build Coastguard Worker    def test_copy_non_contiguous(self):
4278*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
4279*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.is_contiguous())
4280*da0073e9SAndroid Build Coastguard Worker        y = x.to('mps')
4281*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.is_contiguous())
4282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y.to('cpu'))
4283*da0073e9SAndroid Build Coastguard Worker
4284*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
4285*da0073e9SAndroid Build Coastguard Worker        y = x.to('mps')
4286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y.to('cpu'))
4287*da0073e9SAndroid Build Coastguard Worker
4288*da0073e9SAndroid Build Coastguard Worker        x = torch.full((4, 4, 4, 4), 13, device="cpu")
4289*da0073e9SAndroid Build Coastguard Worker        y = torch.full((4, 4, 4, 4), 13, device="mps")
4290*da0073e9SAndroid Build Coastguard Worker        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
4291*da0073e9SAndroid Build Coastguard Worker        x.permute(3, 2, 1, 0)[1::, ::2] = z
4292*da0073e9SAndroid Build Coastguard Worker        # As y is on MPS and z on CPU, this dispatches to a copy operator
4293*da0073e9SAndroid Build Coastguard Worker        y.permute(3, 2, 1, 0)[1::, ::2] = z
4294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y.to('cpu'))
4295*da0073e9SAndroid Build Coastguard Worker
4296*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/95417
4297*da0073e9SAndroid Build Coastguard Worker    def test_copy_storage_offset(self):
4298*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
4299*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
4300*da0073e9SAndroid Build Coastguard Worker        update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
4301*da0073e9SAndroid Build Coastguard Worker        update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
4302*da0073e9SAndroid Build Coastguard Worker        x_cpu[2:4] = update_cpu
4303*da0073e9SAndroid Build Coastguard Worker        x_mps[2:4] = update_mps  # implicit type casting and copy
4304*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_cpu, x_mps)
4305*da0073e9SAndroid Build Coastguard Worker
4306*da0073e9SAndroid Build Coastguard Worker        x_cpu[2:4] = update_mps  # implicit device moving and copy
4307*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_cpu, x_mps)
4308*da0073e9SAndroid Build Coastguard Worker
4309*da0073e9SAndroid Build Coastguard Worker    def test_copy_broadcasting(self):
4310*da0073e9SAndroid Build Coastguard Worker        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
4311*da0073e9SAndroid Build Coastguard Worker            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
4312*da0073e9SAndroid Build Coastguard Worker            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
4313*da0073e9SAndroid Build Coastguard Worker            cpu_result = cpu_dst.copy_(cpu_src)
4314*da0073e9SAndroid Build Coastguard Worker            mps_src = cpu_src.to("mps")
4315*da0073e9SAndroid Build Coastguard Worker            mps_dst = cpu_dst.to("mps")
4316*da0073e9SAndroid Build Coastguard Worker            mps_result = mps_dst.copy_(mps_src)
4317*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_result, mps_result)
4318*da0073e9SAndroid Build Coastguard Worker
4319*da0073e9SAndroid Build Coastguard Worker        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]
4320*da0073e9SAndroid Build Coastguard Worker
4321*da0073e9SAndroid Build Coastguard Worker        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
4322*da0073e9SAndroid Build Coastguard Worker            helper((2, 1), (2, 3), src_dtype, dst_dtype)
4323*da0073e9SAndroid Build Coastguard Worker            helper((2, 1), (2, 2), src_dtype, dst_dtype)
4324*da0073e9SAndroid Build Coastguard Worker            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
4325*da0073e9SAndroid Build Coastguard Worker            helper((3,), (2, 3), src_dtype, dst_dtype)
4326*da0073e9SAndroid Build Coastguard Worker            helper((2,), (2, 2), src_dtype, dst_dtype)
4327*da0073e9SAndroid Build Coastguard Worker            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
4328*da0073e9SAndroid Build Coastguard Worker            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
4329*da0073e9SAndroid Build Coastguard Worker            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
4330*da0073e9SAndroid Build Coastguard Worker            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
4331*da0073e9SAndroid Build Coastguard Worker            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
4332*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/107867
4333*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
4334*da0073e9SAndroid Build Coastguard Worker
4335*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/pull/84742
4336*da0073e9SAndroid Build Coastguard Worker    # and https://github.com/pytorch/pytorch/pull/78319
4337*da0073e9SAndroid Build Coastguard Worker    def test_binops_dtype_precedence(self):
4338*da0073e9SAndroid Build Coastguard Worker        # Test dtype precedence (casting order) in binary operations by comparing to CPU result
4339*da0073e9SAndroid Build Coastguard Worker        # Example values for all dtypes supported on the MPS backend
4340*da0073e9SAndroid Build Coastguard Worker        sample_vals = {
4341*da0073e9SAndroid Build Coastguard Worker            torch.bool: [False, True],
4342*da0073e9SAndroid Build Coastguard Worker            torch.int16: [-15, 0, 1, 10],
4343*da0073e9SAndroid Build Coastguard Worker            torch.int32: [-376, 0, 1, 13],
4344*da0073e9SAndroid Build Coastguard Worker            torch.int64: [-8, 0, 1, 77],
4345*da0073e9SAndroid Build Coastguard Worker            torch.float16: [-234.5, 0.0, 1.0, 2.0],
4346*da0073e9SAndroid Build Coastguard Worker            torch.float32: [-1.0, 0.0, 0.1, 111.99],
4347*da0073e9SAndroid Build Coastguard Worker        }
4348*da0073e9SAndroid Build Coastguard Worker        # Test all combinations of dtypes, operations, dimensionality
4349*da0073e9SAndroid Build Coastguard Worker        for dtype1, dtype2, binop in itertools.product(
4350*da0073e9SAndroid Build Coastguard Worker                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
4351*da0073e9SAndroid Build Coastguard Worker            # bool minus bool is generally unsupported, so skip
4352*da0073e9SAndroid Build Coastguard Worker            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
4353*da0073e9SAndroid Build Coastguard Worker                continue
4354*da0073e9SAndroid Build Coastguard Worker            full_shape = (10,)
4355*da0073e9SAndroid Build Coastguard Worker            for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
4356*da0073e9SAndroid Build Coastguard Worker                # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
4357*da0073e9SAndroid Build Coastguard Worker                # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4358*da0073e9SAndroid Build Coastguard Worker                #            (torch.tensor(val2, dtype=dtype2, device='mps')))
4359*da0073e9SAndroid Build Coastguard Worker                # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
4360*da0073e9SAndroid Build Coastguard Worker                #            (torch.tensor(val2, dtype=dtype2, device='cpu')))
4361*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4362*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4363*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor(val2, dtype=dtype2, device='mps')),
4364*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
4365*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor(val2, dtype=dtype2, device='cpu')))
4366*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4367*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
4368*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor([val2], dtype=dtype2, device='mps')),
4369*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
4370*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor([val2], dtype=dtype2, device='cpu')))
4371*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4372*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4373*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor([val2], dtype=dtype2, device='mps')),
4374*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
4375*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor([val2], dtype=dtype2, device='cpu')))
4376*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4377*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
4378*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor(val2, dtype=dtype2, device='mps')),
4379*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
4380*da0073e9SAndroid Build Coastguard Worker                           (torch.tensor(val2, dtype=dtype2, device='cpu')))
4381*da0073e9SAndroid Build Coastguard Worker                # Test tensors created with torch.full
4382*da0073e9SAndroid Build Coastguard Worker                x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
4383*da0073e9SAndroid Build Coastguard Worker                y1 = torch.tensor(val2, dtype=dtype2, device='mps')
4384*da0073e9SAndroid Build Coastguard Worker                x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu')
4385*da0073e9SAndroid Build Coastguard Worker                y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
4386*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
4387*da0073e9SAndroid Build Coastguard Worker                x3 = torch.tensor(val1, dtype=dtype1, device='mps')
4388*da0073e9SAndroid Build Coastguard Worker                y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
4389*da0073e9SAndroid Build Coastguard Worker                x4 = torch.tensor(val1, dtype=dtype1, device='cpu')
4390*da0073e9SAndroid Build Coastguard Worker                y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu')
4391*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
4392*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
4393*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4394*da0073e9SAndroid Build Coastguard Worker                           (torch.full(full_shape, val2, dtype=dtype2, device='mps')),
4395*da0073e9SAndroid Build Coastguard Worker                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
4396*da0073e9SAndroid Build Coastguard Worker                           (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
4397*da0073e9SAndroid Build Coastguard Worker
4398*da0073e9SAndroid Build Coastguard Worker    def test_nansum(self):
4399*da0073e9SAndroid Build Coastguard Worker        def helper(dtype, noncontiguous, dim):
4400*da0073e9SAndroid Build Coastguard Worker            zero_cpu = torch.zeros((), dtype=dtype)
4401*da0073e9SAndroid Build Coastguard Worker
4402*da0073e9SAndroid Build Coastguard Worker            # Randomly scale the values
4403*da0073e9SAndroid Build Coastguard Worker            scale = random.randint(10, 100)
4404*da0073e9SAndroid Build Coastguard Worker            x_cpu: torch.Tensor = make_tensor(
4405*da0073e9SAndroid Build Coastguard Worker                (5, 5), dtype=dtype, device='cpu',
4406*da0073e9SAndroid Build Coastguard Worker                low=-scale, high=scale, noncontiguous=noncontiguous)
4407*da0073e9SAndroid Build Coastguard Worker
4408*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
4409*da0073e9SAndroid Build Coastguard Worker                nan_mask_cpu = x_cpu < (0.2 * scale)
4410*da0073e9SAndroid Build Coastguard Worker                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
4411*da0073e9SAndroid Build Coastguard Worker                x_cpu[nan_mask_cpu] = np.nan
4412*da0073e9SAndroid Build Coastguard Worker            else:
4413*da0073e9SAndroid Build Coastguard Worker                x_no_nan_cpu = x_cpu
4414*da0073e9SAndroid Build Coastguard Worker
4415*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.to('mps')
4416*da0073e9SAndroid Build Coastguard Worker            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
4417*da0073e9SAndroid Build Coastguard Worker            expect_out_cpu = torch.empty(0, dtype=dtype)
4418*da0073e9SAndroid Build Coastguard Worker            dim_kwargs = {"dim": dim} if dim is not None else {}
4419*da0073e9SAndroid Build Coastguard Worker            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
4420*da0073e9SAndroid Build Coastguard Worker
4421*da0073e9SAndroid Build Coastguard Worker            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
4422*da0073e9SAndroid Build Coastguard Worker            # Sanity check on CPU
4423*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual_cpu)
4424*da0073e9SAndroid Build Coastguard Worker
4425*da0073e9SAndroid Build Coastguard Worker            # Test MPS
4426*da0073e9SAndroid Build Coastguard Worker            actual_mps = torch.nansum(x_mps, **dim_kwargs)
4427*da0073e9SAndroid Build Coastguard Worker            # Test out= variant
4428*da0073e9SAndroid Build Coastguard Worker            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
4429*da0073e9SAndroid Build Coastguard Worker            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
4430*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual_mps)
4431*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect_out_cpu, actual_out_mps)
4432*da0073e9SAndroid Build Coastguard Worker
4433*da0073e9SAndroid Build Coastguard Worker        args = itertools.product(
4434*da0073e9SAndroid Build Coastguard Worker            (torch.float16, torch.float32, torch.int32, torch.int64),   # dtype
4435*da0073e9SAndroid Build Coastguard Worker            (True, False),                                              # noncontiguous
4436*da0073e9SAndroid Build Coastguard Worker            (0, 1, None),                                               # dim
4437*da0073e9SAndroid Build Coastguard Worker        )
4438*da0073e9SAndroid Build Coastguard Worker
4439*da0073e9SAndroid Build Coastguard Worker        for dtype, noncontiguous, dim in args:
4440*da0073e9SAndroid Build Coastguard Worker            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
4441*da0073e9SAndroid Build Coastguard Worker                helper(dtype, noncontiguous, dim)
4442*da0073e9SAndroid Build Coastguard Worker
4443*da0073e9SAndroid Build Coastguard Worker    def test_cumsum_all_dtypes(self):
4444*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
4445*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
4446*da0073e9SAndroid Build Coastguard Worker            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")
4447*da0073e9SAndroid Build Coastguard Worker
4448*da0073e9SAndroid Build Coastguard Worker            a = t.cumsum(0, dtype=dtype)
4449*da0073e9SAndroid Build Coastguard Worker            a_cpu = t_cpu.cumsum(0, dtype=dtype)
4450*da0073e9SAndroid Build Coastguard Worker
4451*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.cpu(), a_cpu)
4452*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]
4453*da0073e9SAndroid Build Coastguard Worker
4454*da0073e9SAndroid Build Coastguard Worker        try:
4455*da0073e9SAndroid Build Coastguard Worker            helper(torch.int64)
4456*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
4457*da0073e9SAndroid Build Coastguard Worker            e_string = str(e)
4458*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
4459*da0073e9SAndroid Build Coastguard Worker                             " Support has been added in macOS 13.3")
4460*da0073e9SAndroid Build Coastguard Worker
4461*da0073e9SAndroid Build Coastguard Worker    def test_cumsum_bool(self):
4462*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2**16, dtype=torch.bool)
4463*da0073e9SAndroid Build Coastguard Worker        t_cpu = a.cumsum(0)
4464*da0073e9SAndroid Build Coastguard Worker        t_mps = a.to("mps").cumsum(0)
4465*da0073e9SAndroid Build Coastguard Worker
4466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_cpu, t_mps)
4467*da0073e9SAndroid Build Coastguard Worker
4468*da0073e9SAndroid Build Coastguard Worker    def test_cumsum_minus_one_axis(self):
4469*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
4470*da0073e9SAndroid Build Coastguard Worker            # Test with axis -1
4471*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
4472*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.float32:
4473*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
4474*da0073e9SAndroid Build Coastguard Worker            else:
4475*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
4476*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
4477*da0073e9SAndroid Build Coastguard Worker
4478*da0073e9SAndroid Build Coastguard Worker            cpu_y = cpu_x.cumsum(-1)
4479*da0073e9SAndroid Build Coastguard Worker            y = x.cumsum(-1)
4480*da0073e9SAndroid Build Coastguard Worker
4481*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, cpu_y)
4482*da0073e9SAndroid Build Coastguard Worker
4483*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]
4484*da0073e9SAndroid Build Coastguard Worker
4485*da0073e9SAndroid Build Coastguard Worker    def test_cumprod_all_dtypes(self):
4486*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
4487*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
4488*da0073e9SAndroid Build Coastguard Worker            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")
4489*da0073e9SAndroid Build Coastguard Worker
4490*da0073e9SAndroid Build Coastguard Worker            a = t.cumprod(0, dtype=dtype)
4491*da0073e9SAndroid Build Coastguard Worker            a_cpu = t_cpu.cumprod(0, dtype=dtype)
4492*da0073e9SAndroid Build Coastguard Worker
4493*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.cpu(), a_cpu)
4494*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]
4495*da0073e9SAndroid Build Coastguard Worker
4496*da0073e9SAndroid Build Coastguard Worker        try:
4497*da0073e9SAndroid Build Coastguard Worker            helper(torch.int64)
4498*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
4499*da0073e9SAndroid Build Coastguard Worker            e_string = str(e)
4500*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
4501*da0073e9SAndroid Build Coastguard Worker                             + " Support has been added in macOS 13.3")
4502*da0073e9SAndroid Build Coastguard Worker
4503*da0073e9SAndroid Build Coastguard Worker    def test_cumprod_minus_one_axis(self):
4504*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
4505*da0073e9SAndroid Build Coastguard Worker            # Test with axis -1
4506*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
4507*da0073e9SAndroid Build Coastguard Worker            if dtype == torch.float32:
4508*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
4509*da0073e9SAndroid Build Coastguard Worker            else:
4510*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
4511*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
4512*da0073e9SAndroid Build Coastguard Worker
4513*da0073e9SAndroid Build Coastguard Worker            cpu_y = cpu_x.cumprod(-1)
4514*da0073e9SAndroid Build Coastguard Worker            y = x.cumprod(-1)
4515*da0073e9SAndroid Build Coastguard Worker
4516*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, cpu_y)
4517*da0073e9SAndroid Build Coastguard Worker
4518*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]
4519*da0073e9SAndroid Build Coastguard Worker
4520*da0073e9SAndroid Build Coastguard Worker    def test_median_int16(self):
4521*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype):
4522*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
4523*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
4524*da0073e9SAndroid Build Coastguard Worker
4525*da0073e9SAndroid Build Coastguard Worker            median_result = torch.median(x)
4526*da0073e9SAndroid Build Coastguard Worker            median_result_cpu = torch.median(cpu_x)
4527*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(median_result, median_result_cpu)
4528*da0073e9SAndroid Build Coastguard Worker
4529*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), torch.int16)
4530*da0073e9SAndroid Build Coastguard Worker
4531*da0073e9SAndroid Build Coastguard Worker    def test_activation_checkpoint_does_not_error(self):
4532*da0073e9SAndroid Build Coastguard Worker        from torch.utils.checkpoint import checkpoint
4533*da0073e9SAndroid Build Coastguard Worker
4534*da0073e9SAndroid Build Coastguard Worker        for use_reentrant in (True, False):
4535*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1., device="mps", requires_grad=True)
4536*da0073e9SAndroid Build Coastguard Worker
4537*da0073e9SAndroid Build Coastguard Worker            def fn(x):
4538*da0073e9SAndroid Build Coastguard Worker                return x.sin().cos().exp()
4539*da0073e9SAndroid Build Coastguard Worker
4540*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=use_reentrant)
4541*da0073e9SAndroid Build Coastguard Worker            out.backward()
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker    def test_as_strided(self):
4544*da0073e9SAndroid Build Coastguard Worker        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
4545*da0073e9SAndroid Build Coastguard Worker        values_1 = [[1.0, 1.0], [1.0, 1.0]]
4546*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
4547*da0073e9SAndroid Build Coastguard Worker        ones1 = torch.tensor(values_1, device='mps')
4548*da0073e9SAndroid Build Coastguard Worker        x = cpu_x.detach().clone().to('mps').requires_grad_()
4549*da0073e9SAndroid Build Coastguard Worker        strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2))
4550*da0073e9SAndroid Build Coastguard Worker        strided_mps = torch.as_strided(x, (2, 2), (1, 2))
4551*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strided_mps, strided_cpu)
4552*da0073e9SAndroid Build Coastguard Worker        strided_cpu_out = strided_cpu + ones1.to('cpu')
4553*da0073e9SAndroid Build Coastguard Worker        strided_mps_out = strided_mps + ones1
4554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strided_cpu_out, strided_mps_out)
4555*da0073e9SAndroid Build Coastguard Worker
4556*da0073e9SAndroid Build Coastguard Worker        # test with storage offsets
4557*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.rand(3, 3, device='cpu')
4558*da0073e9SAndroid Build Coastguard Worker        mps_x = cpu_x.to('mps')
4559*da0073e9SAndroid Build Coastguard Worker        strided_cpu1 = torch.as_strided(cpu_x, (2, 2), (1, 2), 0)
4560*da0073e9SAndroid Build Coastguard Worker        strided_mps1 = torch.as_strided(mps_x, (2, 2), (1, 2), 0)
4561*da0073e9SAndroid Build Coastguard Worker        strided_cpu2 = torch.as_strided(cpu_x, (2, 2), (1, 2), 1)
4562*da0073e9SAndroid Build Coastguard Worker        strided_mps2 = torch.as_strided(mps_x, (2, 2), (1, 2), 1)
4563*da0073e9SAndroid Build Coastguard Worker        strided_cpu_out = strided_cpu1 - strided_cpu2
4564*da0073e9SAndroid Build Coastguard Worker        strided_mps_out = strided_mps1 - strided_mps2
4565*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strided_cpu_out, strided_mps_out)
4566*da0073e9SAndroid Build Coastguard Worker
4567*da0073e9SAndroid Build Coastguard Worker    def test_unfold(self):
4568*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1., 8)
4569*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.arange(1., 8, device="mps")
4570*da0073e9SAndroid Build Coastguard Worker
4571*da0073e9SAndroid Build Coastguard Worker        y = x.unfold(0, 2, 1)
4572*da0073e9SAndroid Build Coastguard Worker        y_mps = x_mps.unfold(0, 2, 1)
4573*da0073e9SAndroid Build Coastguard Worker
4574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_mps)
4575*da0073e9SAndroid Build Coastguard Worker
4576*da0073e9SAndroid Build Coastguard Worker    def test_unfold_all_devices_and_dtypes(self):
4577*da0073e9SAndroid Build Coastguard Worker        supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
4578*da0073e9SAndroid Build Coastguard Worker        for dt in supported_dtypes:
4579*da0073e9SAndroid Build Coastguard Worker            x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
4580*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
4581*da0073e9SAndroid Build Coastguard Worker
4582*da0073e9SAndroid Build Coastguard Worker    def test_unfold_scalars(self):
4583*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.5, device="mps")
4584*da0073e9SAndroid Build Coastguard Worker        # unfold on a 0-dimensional tensor should always return a 1-d dimensional
4585*da0073e9SAndroid Build Coastguard Worker        # tensor of shape [size] (i.e., the second parameter to unfold)
4586*da0073e9SAndroid Build Coastguard Worker
4587*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
4588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
4589*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
4590*da0073e9SAndroid Build Coastguard Worker
4591*da0073e9SAndroid Build Coastguard Worker    def test_bincount_simple(self):
4592*da0073e9SAndroid Build Coastguard Worker        input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
4593*da0073e9SAndroid Build Coastguard Worker        input_cpu = input.to("cpu")
4594*da0073e9SAndroid Build Coastguard Worker        weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
4595*da0073e9SAndroid Build Coastguard Worker        weights_cpu = weights.to("cpu")
4596*da0073e9SAndroid Build Coastguard Worker
4597*da0073e9SAndroid Build Coastguard Worker        x = torch.bincount(input)
4598*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.bincount(input_cpu)
4599*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x_cpu)
4600*da0073e9SAndroid Build Coastguard Worker
4601*da0073e9SAndroid Build Coastguard Worker        y = input.bincount(weights)
4602*da0073e9SAndroid Build Coastguard Worker        y_cpu = input_cpu.bincount(weights_cpu)
4603*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_cpu)
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker    def test_bincount_reduction(self):
4606*da0073e9SAndroid Build Coastguard Worker        device = "mps"
4607*da0073e9SAndroid Build Coastguard Worker        # negative input throws
4608*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
4609*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
4610*da0073e9SAndroid Build Coastguard Worker        # n-d input, with n > 1 throws
4611*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
4612*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
4613*da0073e9SAndroid Build Coastguard Worker        # minlength < 0 throws
4614*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
4615*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 3], device=device),
4616*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([.2, .2], device=device),
4617*da0073e9SAndroid Build Coastguard Worker                           minlength=-1)
4618*da0073e9SAndroid Build Coastguard Worker        # n-d weights, with n > 1 throws
4619*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d'):
4620*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
4621*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
4622*da0073e9SAndroid Build Coastguard Worker        # input and weights dim mismatch
4623*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'same length'):
4624*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
4625*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))
4626*da0073e9SAndroid Build Coastguard Worker        # 1-d input with no elements and default minlength
4627*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
4628*da0073e9SAndroid Build Coastguard Worker                         torch.zeros(0, dtype=torch.long, device=device))
4629*da0073e9SAndroid Build Coastguard Worker        # 1-d input with no elements and specified minlength
4630*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
4631*da0073e9SAndroid Build Coastguard Worker                         torch.zeros(10, dtype=torch.long, device=device))
4632*da0073e9SAndroid Build Coastguard Worker
4633*da0073e9SAndroid Build Coastguard Worker        # test tensor method without weights
4634*da0073e9SAndroid Build Coastguard Worker        long_counts = torch.tensor(
4635*da0073e9SAndroid Build Coastguard Worker            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
4636*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4637*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
4638*da0073e9SAndroid Build Coastguard Worker            long_counts)
4639*da0073e9SAndroid Build Coastguard Worker        # test avoiding overflow for uint8 (#76979)
4640*da0073e9SAndroid Build Coastguard Worker        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
4641*da0073e9SAndroid Build Coastguard Worker        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
4642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_uint8, count_int16)
4643*da0073e9SAndroid Build Coastguard Worker        # test minlength functionality
4644*da0073e9SAndroid Build Coastguard Worker        int_counts = torch.bincount(
4645*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
4646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4647*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
4648*da0073e9SAndroid Build Coastguard Worker            int_counts)
4649*da0073e9SAndroid Build Coastguard Worker        # test weights
4650*da0073e9SAndroid Build Coastguard Worker        byte_counts = torch.bincount(
4651*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
4652*da0073e9SAndroid Build Coastguard Worker            torch.tensor([.1, .2, .3, .4, .5], device=device))
4653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4654*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
4655*da0073e9SAndroid Build Coastguard Worker        byte_counts = torch.bincount(
4656*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
4657*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
4658*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4659*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts)
4660*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous inputs and weights
4661*da0073e9SAndroid Build Coastguard Worker        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
4662*da0073e9SAndroid Build Coastguard Worker        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
4663*da0073e9SAndroid Build Coastguard Worker        for i in [0, 1]:
4664*da0073e9SAndroid Build Coastguard Worker            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
4665*da0073e9SAndroid Build Coastguard Worker            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
4666*da0073e9SAndroid Build Coastguard Worker        # inputs are non-contiguous but weights are contiguous
4667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
4668*da0073e9SAndroid Build Coastguard Worker        # inputs and weights are non-contiguous
4669*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4670*da0073e9SAndroid Build Coastguard Worker            inputs[:, 1].bincount(weights[:, 1]),
4671*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
4672*da0073e9SAndroid Build Coastguard Worker        # weights are non-contiguous but inputs are contiguous
4673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
4674*da0073e9SAndroid Build Coastguard Worker                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
4675*da0073e9SAndroid Build Coastguard Worker
4676*da0073e9SAndroid Build Coastguard Worker        # test bincount on non-contiguous slices
4677*da0073e9SAndroid Build Coastguard Worker        all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
4678*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))
4679*da0073e9SAndroid Build Coastguard Worker
4680*da0073e9SAndroid Build Coastguard Worker        all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
4681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))
4682*da0073e9SAndroid Build Coastguard Worker
4683*da0073e9SAndroid Build Coastguard Worker        # test large number of bins - global memory use
4684*da0073e9SAndroid Build Coastguard Worker        big_exp = torch.zeros(100, device=device)
4685*da0073e9SAndroid Build Coastguard Worker        big_exp[-1] = 50.0
4686*da0073e9SAndroid Build Coastguard Worker        big_w = torch.tensor([.5] * 100, device=device)
4687*da0073e9SAndroid Build Coastguard Worker        big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
4688*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(big_exp, big_out)
4689*da0073e9SAndroid Build Coastguard Worker        # test large input size
4690*da0073e9SAndroid Build Coastguard Worker        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
4691*da0073e9SAndroid Build Coastguard Worker        big_exp[1] = 10
4692*da0073e9SAndroid Build Coastguard Worker        big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
4693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(big_exp, big_out)
4694*da0073e9SAndroid Build Coastguard Worker
4695*da0073e9SAndroid Build Coastguard Worker    def test_bincount(self):
4696*da0073e9SAndroid Build Coastguard Worker        device = "mps"
4697*da0073e9SAndroid Build Coastguard Worker        input_size = (5000,)
4698*da0073e9SAndroid Build Coastguard Worker        w = torch.randn(input_size, dtype=torch.float, device=device)
4699*da0073e9SAndroid Build Coastguard Worker        w_cpu = w.cpu()
4700*da0073e9SAndroid Build Coastguard Worker
4701*da0073e9SAndroid Build Coastguard Worker        t = torch.randint(50, input_size, dtype=torch.int8, device=device)
4702*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(), t.bincount())
4703*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
4704*da0073e9SAndroid Build Coastguard Worker
4705*da0073e9SAndroid Build Coastguard Worker        t = torch.randint(500, input_size, dtype=torch.int32, device=device)
4706*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(), t.bincount())
4707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
4708*da0073e9SAndroid Build Coastguard Worker
4709*da0073e9SAndroid Build Coastguard Worker        t = torch.randint(2000, input_size, dtype=torch.int32, device=device)
4710*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(), t.bincount())
4711*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
4712*da0073e9SAndroid Build Coastguard Worker
4713*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros([10], dtype=torch.int32, device=device)
4714*da0073e9SAndroid Build Coastguard Worker        t[0] = 35488
4715*da0073e9SAndroid Build Coastguard Worker        counted = t.bincount(minlength=65536)
4716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.sum(counted), 10)
4717*da0073e9SAndroid Build Coastguard Worker
4718*da0073e9SAndroid Build Coastguard Worker    def test_sum_backward(self):
4719*da0073e9SAndroid Build Coastguard Worker        def helper(n, c):
4720*da0073e9SAndroid Build Coastguard Worker            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
4721*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
4722*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
4723*da0073e9SAndroid Build Coastguard Worker
4724*da0073e9SAndroid Build Coastguard Worker            all_sum = torch.sum(x)
4725*da0073e9SAndroid Build Coastguard Worker            all_sum_cpu = torch.sum(cpu_x)
4726*da0073e9SAndroid Build Coastguard Worker
4727*da0073e9SAndroid Build Coastguard Worker            all_sum.backward()
4728*da0073e9SAndroid Build Coastguard Worker            all_sum_cpu.backward()
4729*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_sum, all_sum_cpu)
4730*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
4731*da0073e9SAndroid Build Coastguard Worker
4732*da0073e9SAndroid Build Coastguard Worker        helper(3, 3)
4733*da0073e9SAndroid Build Coastguard Worker
4734*da0073e9SAndroid Build Coastguard Worker    # L1 loss
4735*da0073e9SAndroid Build Coastguard Worker    def test_l1_loss(self):
4736*da0073e9SAndroid Build Coastguard Worker        def helper(shape, reduction):
4737*da0073e9SAndroid Build Coastguard Worker            # create the criterion
4738*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.L1Loss(reduction=reduction)
4739*da0073e9SAndroid Build Coastguard Worker
4740*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
4741*da0073e9SAndroid Build Coastguard Worker            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
4742*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4743*da0073e9SAndroid Build Coastguard Worker            targetMPS = targetCPU.detach().clone().to('mps')
4744*da0073e9SAndroid Build Coastguard Worker
4745*da0073e9SAndroid Build Coastguard Worker            # forward pass
4746*da0073e9SAndroid Build Coastguard Worker            outputCPU = loss(inputCPU, targetCPU)
4747*da0073e9SAndroid Build Coastguard Worker            outputMPS = loss(inputMPS, targetMPS)
4748*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputCPU, outputMPS)
4749*da0073e9SAndroid Build Coastguard Worker
4750*da0073e9SAndroid Build Coastguard Worker            # backward pass
4751*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
4752*da0073e9SAndroid Build Coastguard Worker                # chose 2 just to make the grad_output > 1 in backward pass
4753*da0073e9SAndroid Build Coastguard Worker                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
4754*da0073e9SAndroid Build Coastguard Worker                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
4755*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inputCPU.grad, inputMPS.grad)
4756*da0073e9SAndroid Build Coastguard Worker
4757*da0073e9SAndroid Build Coastguard Worker        helper([8, 5, 4], 'none')
4758*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4], 'sum')
4759*da0073e9SAndroid Build Coastguard Worker        # verify if changes in shape would cause cached graph lookup problems
4760*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4, 6], 'sum')
4761*da0073e9SAndroid Build Coastguard Worker        helper([8, 4, 5, 7, 6], 'mean')
4762*da0073e9SAndroid Build Coastguard Worker
4763*da0073e9SAndroid Build Coastguard Worker    # Mean Squared Error
4764*da0073e9SAndroid Build Coastguard Worker    def test_mse_loss(self):
4765*da0073e9SAndroid Build Coastguard Worker        def helper(shape, reduction):
4766*da0073e9SAndroid Build Coastguard Worker            # create the criterion
4767*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.MSELoss(reduction=reduction)
4768*da0073e9SAndroid Build Coastguard Worker
4769*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
4770*da0073e9SAndroid Build Coastguard Worker            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
4771*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4772*da0073e9SAndroid Build Coastguard Worker            targetMPS = targetCPU.detach().clone().to('mps')
4773*da0073e9SAndroid Build Coastguard Worker
4774*da0073e9SAndroid Build Coastguard Worker            # forward pass
4775*da0073e9SAndroid Build Coastguard Worker            outputCPU = loss(inputCPU, targetCPU)
4776*da0073e9SAndroid Build Coastguard Worker            outputMPS = loss(inputMPS, targetMPS)
4777*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputCPU, outputMPS)
4778*da0073e9SAndroid Build Coastguard Worker
4779*da0073e9SAndroid Build Coastguard Worker            # backward pass
4780*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
4781*da0073e9SAndroid Build Coastguard Worker                # chose 2 just to make the grad_output > 1 in backward pass
4782*da0073e9SAndroid Build Coastguard Worker                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
4783*da0073e9SAndroid Build Coastguard Worker                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
4784*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inputCPU.grad, inputMPS.grad)
4785*da0073e9SAndroid Build Coastguard Worker
4786*da0073e9SAndroid Build Coastguard Worker        helper([8, 5, 4], 'none')
4787*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4], 'sum')
4788*da0073e9SAndroid Build Coastguard Worker        # verify if changes in shape would cause cached graph lookup problems
4789*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4, 6], 'sum')
4790*da0073e9SAndroid Build Coastguard Worker        helper([8, 4, 5, 7, 6], 'mean')
4791*da0073e9SAndroid Build Coastguard Worker
    def test_mse_loss_strided_output(self):
        # Regression test: MSELoss(reduction='none') on non-contiguous (strided)
        # MPS tensors must match CPU results.
        # https://github.com/pytorch/pytorch/issues/124621
        lf = nn.MSELoss(reduction='none')
        model_cpu = nn.Sequential(
            nn.Conv1d(3, 3, 1),
        )
        model_mps = copy.deepcopy(model_cpu).to("mps")

        x = torch.randn(128, 10, 3)
        x = x.permute(0, 2, 1)  # (N, C, L) for Conv1d; now non-contiguous

        # the double permute ends up at the same shape as `x` but goes through
        # an extra stride change — presumably to force a strided MPS layout
        # distinct from a plain copy; keep as-is (stride layout is the test)
        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
        x_mps = x_mps.permute(0, 2, 1)

        y = model_cpu(x)
        y_mps = model_mps(x_mps)

        # slicing after permute yields a strided (non-contiguous) loss input
        y = y.permute(0, 2, 1)[:, :5, :]
        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

        y_hat = torch.randn(128, 5, 3)
        y_hat_mps = y_hat.detach().clone().to("mps")

        loss = lf(y, y_hat)
        loss_mps = lf(y_mps, y_hat_mps)
        self.assertEqual(loss, loss_mps)
4818*da0073e9SAndroid Build Coastguard Worker
    # Binary Cross Entropy
4820*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_simple(self):
4821*da0073e9SAndroid Build Coastguard Worker        def helper(shape, reduction):
4822*da0073e9SAndroid Build Coastguard Worker            # create the criterion
4823*da0073e9SAndroid Build Coastguard Worker            loss = torch.nn.BCELoss(reduction=reduction)
4824*da0073e9SAndroid Build Coastguard Worker
4825*da0073e9SAndroid Build Coastguard Worker            # input and target must be within [0..1]
4826*da0073e9SAndroid Build Coastguard Worker            input_t = np.random.random_sample(size=shape).astype(np.float32)
4827*da0073e9SAndroid Build Coastguard Worker            target_t = np.random.random_sample(size=shape).astype(np.float32)
4828*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
4829*da0073e9SAndroid Build Coastguard Worker            targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
4830*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4831*da0073e9SAndroid Build Coastguard Worker            targetMPS = targetCPU.detach().clone().to('mps')
4832*da0073e9SAndroid Build Coastguard Worker
4833*da0073e9SAndroid Build Coastguard Worker            # forward pass
4834*da0073e9SAndroid Build Coastguard Worker            outputCPU = loss(inputCPU, targetCPU)
4835*da0073e9SAndroid Build Coastguard Worker            outputMPS = loss(inputMPS, targetMPS)
4836*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputCPU, outputMPS)
4837*da0073e9SAndroid Build Coastguard Worker
4838*da0073e9SAndroid Build Coastguard Worker            # backward pass
4839*da0073e9SAndroid Build Coastguard Worker            if reduction != 'none':
4840*da0073e9SAndroid Build Coastguard Worker                # chose 0.6 just to have the grad_output != 1
4841*da0073e9SAndroid Build Coastguard Worker                outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
4842*da0073e9SAndroid Build Coastguard Worker                outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
4843*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inputCPU.grad, inputMPS.grad)
4844*da0073e9SAndroid Build Coastguard Worker
4845*da0073e9SAndroid Build Coastguard Worker        helper([8, 5, 4], 'none')
4846*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4], 'sum')
4847*da0073e9SAndroid Build Coastguard Worker        # verify if changes in shape would cause cached graph lookup problems
4848*da0073e9SAndroid Build Coastguard Worker        helper([7, 5, 2, 4, 6], 'sum')
4849*da0073e9SAndroid Build Coastguard Worker        helper([8, 4, 5, 7, 6], 'mean')
4850*da0073e9SAndroid Build Coastguard Worker        helper([1, 1, 32, 32], 'mean')
4851*da0073e9SAndroid Build Coastguard Worker
4852*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_always_nonnegative(self):
4853*da0073e9SAndroid Build Coastguard Worker        target = torch.ones(5, device='mps')
4854*da0073e9SAndroid Build Coastguard Worker        input = torch.ones(5, device='mps')
4855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4856*da0073e9SAndroid Build Coastguard Worker
4857*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(5, device='mps')
4858*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros(5, device='mps')
4859*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)
4860*da0073e9SAndroid Build Coastguard Worker
4861*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_size_mismatch(self):
4862*da0073e9SAndroid Build Coastguard Worker        bceloss = nn.BCELoss()
4863*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(25, device='mps')
4864*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(25, 1, device='mps')
4865*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
4866*da0073e9SAndroid Build Coastguard Worker            bceloss(a, b)
4867*da0073e9SAndroid Build Coastguard Worker
4868*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
4869*da0073e9SAndroid Build Coastguard Worker        x_size = 1024
4870*da0073e9SAndroid Build Coastguard Worker        y_size = 256
4871*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(x_size, y_size, device='mps')
4872*da0073e9SAndroid Build Coastguard Worker
4873*da0073e9SAndroid Build Coastguard Worker        for reduction in ['none', 'mean', 'sum']:
4874*da0073e9SAndroid Build Coastguard Worker            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
4875*da0073e9SAndroid Build Coastguard Worker            output_logits = output_sig.clone().detach()
4876*da0073e9SAndroid Build Coastguard Worker
4877*da0073e9SAndroid Build Coastguard Worker            output_sig.requires_grad = True
4878*da0073e9SAndroid Build Coastguard Worker            output_logits.requires_grad = True
4879*da0073e9SAndroid Build Coastguard Worker            weight = torch.rand(y_size, device='mps')
4880*da0073e9SAndroid Build Coastguard Worker
4881*da0073e9SAndroid Build Coastguard Worker            loss_sig = nn.BCELoss(weight, reduction=reduction)(
4882*da0073e9SAndroid Build Coastguard Worker                torch.sigmoid(output_sig), target
4883*da0073e9SAndroid Build Coastguard Worker            )
4884*da0073e9SAndroid Build Coastguard Worker            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
4885*da0073e9SAndroid Build Coastguard Worker                output_logits, target
4886*da0073e9SAndroid Build Coastguard Worker            )
4887*da0073e9SAndroid Build Coastguard Worker
4888*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(loss_logits, loss_sig)
4889*da0073e9SAndroid Build Coastguard Worker
4890*da0073e9SAndroid Build Coastguard Worker            if reduction == 'none':
4891*da0073e9SAndroid Build Coastguard Worker                grad = torch.rand(x_size, y_size, device='mps')
4892*da0073e9SAndroid Build Coastguard Worker                loss_sig.backward(grad)
4893*da0073e9SAndroid Build Coastguard Worker                loss_logits.backward(grad)
4894*da0073e9SAndroid Build Coastguard Worker            else:
4895*da0073e9SAndroid Build Coastguard Worker                loss_sig.backward()
4896*da0073e9SAndroid Build Coastguard Worker                loss_logits.backward()
4897*da0073e9SAndroid Build Coastguard Worker
4898*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_sig.grad, output_logits.grad)
4899*da0073e9SAndroid Build Coastguard Worker
4900*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_has_correct_grad_at_zero(self):
4901*da0073e9SAndroid Build Coastguard Worker        output = torch.zeros(3, 1, requires_grad=True, device='mps')
4902*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(3, 1, device='mps')
4903*da0073e9SAndroid Build Coastguard Worker        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
4904*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4905*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.grad, expected_grad)
4906*da0073e9SAndroid Build Coastguard Worker
4907*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_broadcasts_weights(self):
4908*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(16, 4, device='mps')
4909*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(16, 4, device='mps') - 0.5
4910*da0073e9SAndroid Build Coastguard Worker
4911*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4, device='mps')
4912*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4913*da0073e9SAndroid Build Coastguard Worker
4914*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4915*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4916*da0073e9SAndroid Build Coastguard Worker
4917*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4918*da0073e9SAndroid Build Coastguard Worker
4919*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(16, 1, device='mps')
4920*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4921*da0073e9SAndroid Build Coastguard Worker
4922*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4923*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4924*da0073e9SAndroid Build Coastguard Worker
4925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4926*da0073e9SAndroid Build Coastguard Worker
4927*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
4928*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(64, 4, device='mps')
4929*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(64, 4, device='mps') - 0.5
4930*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.ones(64, 4, device='mps')
4931*da0073e9SAndroid Build Coastguard Worker
4932*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
4933*da0073e9SAndroid Build Coastguard Worker                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
4934*da0073e9SAndroid Build Coastguard Worker
4935*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_broadcasts_pos_weights(self):
4936*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(64, 4, device='mps')
4937*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(64, 4, device='mps') - 0.5
4938*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.rand(4, device='mps')
4939*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4940*da0073e9SAndroid Build Coastguard Worker
4941*da0073e9SAndroid Build Coastguard Worker        pos_weight1 = pos_weight.expand(1, 4)
4942*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
4943*da0073e9SAndroid Build Coastguard Worker
4944*da0073e9SAndroid Build Coastguard Worker        pos_weight2 = pos_weight.expand(64, 4)
4945*da0073e9SAndroid Build Coastguard Worker        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
4946*da0073e9SAndroid Build Coastguard Worker
4947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4948*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out3)
4949*da0073e9SAndroid Build Coastguard Worker
4950*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
4951*da0073e9SAndroid Build Coastguard Worker        output = torch.zeros(3, 1, requires_grad=True, device='mps')
4952*da0073e9SAndroid Build Coastguard Worker        target = torch.zeros(3, 1, device='mps')
4953*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.ones(3, 1, device='mps')
4954*da0073e9SAndroid Build Coastguard Worker        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
4955*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4956*da0073e9SAndroid Build Coastguard Worker        grad = output.grad
4957*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad, expected_grad)
4958*da0073e9SAndroid Build Coastguard Worker
4959*da0073e9SAndroid Build Coastguard Worker    def test_bce_with_logits_stability(self):
4960*da0073e9SAndroid Build Coastguard Worker        output = torch.tensor([0., -120.], device='mps')
4961*da0073e9SAndroid Build Coastguard Worker        target = torch.tensor([0., 1.], device='mps')
4962*da0073e9SAndroid Build Coastguard Worker        pos_weight = torch.tensor([1., 1.], device='mps')
4963*da0073e9SAndroid Build Coastguard Worker
4964*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCEWithLogitsLoss()(output, target)
4965*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isfinite(out1).all().item())
4966*da0073e9SAndroid Build Coastguard Worker
4967*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4968*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.isfinite(out2).all().item())
4969*da0073e9SAndroid Build Coastguard Worker
4970*da0073e9SAndroid Build Coastguard Worker    def test_bce_loss_broadcasts_weights(self):
4971*da0073e9SAndroid Build Coastguard Worker        sigmoid = nn.Sigmoid()
4972*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(16, 4, device='mps')
4973*da0073e9SAndroid Build Coastguard Worker        output = torch.rand(16, 4, device='mps') - 0.5
4974*da0073e9SAndroid Build Coastguard Worker
4975*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(4, device='mps')
4976*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4977*da0073e9SAndroid Build Coastguard Worker
4978*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4979*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4980*da0073e9SAndroid Build Coastguard Worker
4981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4982*da0073e9SAndroid Build Coastguard Worker
4983*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(16, 1, device='mps')
4984*da0073e9SAndroid Build Coastguard Worker        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4985*da0073e9SAndroid Build Coastguard Worker
4986*da0073e9SAndroid Build Coastguard Worker        weight = weight.expand(16, 4).contiguous()
4987*da0073e9SAndroid Build Coastguard Worker        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4988*da0073e9SAndroid Build Coastguard Worker
4989*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1, out2)
4990*da0073e9SAndroid Build Coastguard Worker
4991*da0073e9SAndroid Build Coastguard Worker    def test_cross_entropy_loss(self):
4992*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/116095
4993*da0073e9SAndroid Build Coastguard Worker        loss = nn.CrossEntropyLoss()
4994*da0073e9SAndroid Build Coastguard Worker        pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
4995*da0073e9SAndroid Build Coastguard Worker        target = torch.ones(3, dtype=torch.long, device='mps')
4996*da0073e9SAndroid Build Coastguard Worker        output = loss(pred, target)
4997*da0073e9SAndroid Build Coastguard Worker        output.backward()
4998*da0073e9SAndroid Build Coastguard Worker
4999*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax(self):
5000*da0073e9SAndroid Build Coastguard Worker        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
5001*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
5002*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps', requires_grad=True)
5003*da0073e9SAndroid Build Coastguard Worker
5004*da0073e9SAndroid Build Coastguard Worker        cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
5005*da0073e9SAndroid Build Coastguard Worker        mps_log_softmax = F.log_softmax(mps_x, dim=0)
5006*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
5007*da0073e9SAndroid Build Coastguard Worker
5008*da0073e9SAndroid Build Coastguard Worker        cpu_grad = torch.ones_like(cpu_log_softmax)
5009*da0073e9SAndroid Build Coastguard Worker        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5010*da0073e9SAndroid Build Coastguard Worker
5011*da0073e9SAndroid Build Coastguard Worker        cpu_log_softmax.backward(gradient=cpu_grad)
5012*da0073e9SAndroid Build Coastguard Worker        mps_log_softmax.backward(gradient=mps_grad)
5013*da0073e9SAndroid Build Coastguard Worker
5014*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
5015*da0073e9SAndroid Build Coastguard Worker
5016*da0073e9SAndroid Build Coastguard Worker    def test_log_softmax_large_numbers(self):
5017*da0073e9SAndroid Build Coastguard Worker        values = [
5018*da0073e9SAndroid Build Coastguard Worker            [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
5019*da0073e9SAndroid Build Coastguard Worker            [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
5020*da0073e9SAndroid Build Coastguard Worker        ]
5021*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
5022*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps', requires_grad=True)
5023*da0073e9SAndroid Build Coastguard Worker
5024*da0073e9SAndroid Build Coastguard Worker        cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
5025*da0073e9SAndroid Build Coastguard Worker        mps_log_softmax = F.log_softmax(mps_x, dim=-1)
5026*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
5027*da0073e9SAndroid Build Coastguard Worker
5028*da0073e9SAndroid Build Coastguard Worker        cpu_grad = torch.ones_like(cpu_log_softmax)
5029*da0073e9SAndroid Build Coastguard Worker        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5030*da0073e9SAndroid Build Coastguard Worker
5031*da0073e9SAndroid Build Coastguard Worker        cpu_log_softmax.backward(gradient=cpu_grad)
5032*da0073e9SAndroid Build Coastguard Worker        mps_log_softmax.backward(gradient=mps_grad)
5033*da0073e9SAndroid Build Coastguard Worker
5034*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
5035*da0073e9SAndroid Build Coastguard Worker
5036*da0073e9SAndroid Build Coastguard Worker    def test_eq(self):
5037*da0073e9SAndroid Build Coastguard Worker        values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
5038*da0073e9SAndroid Build Coastguard Worker        values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
5039*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values1, device='mps')
5040*da0073e9SAndroid Build Coastguard Worker        mps_y = torch.tensor(values2, device='mps')
5041*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values1, device='cpu')
5042*da0073e9SAndroid Build Coastguard Worker        cpu_y = torch.tensor(values2, device='cpu')
5043*da0073e9SAndroid Build Coastguard Worker        result_mps = torch.eq(mps_x, mps_y)
5044*da0073e9SAndroid Build Coastguard Worker        result_cpu = torch.eq(cpu_x, cpu_y)
5045*da0073e9SAndroid Build Coastguard Worker
5046*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_cpu, result_mps.to('cpu'))
5047*da0073e9SAndroid Build Coastguard Worker
5048*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
5049*da0073e9SAndroid Build Coastguard Worker    def test_signed_vs_unsigned_comparison(self):
5050*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
5051*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
5052*da0073e9SAndroid Build Coastguard Worker        # in the comparison of signed vs. unsigned we should always cast to unsigned
5053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x == -1, mps_x == -1)
5054*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x > -1, mps_x > -1)
5055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x < -1, mps_x < -1)
5056*da0073e9SAndroid Build Coastguard Worker
5057*da0073e9SAndroid Build Coastguard Worker    def test_eq_int64(self):
5058*da0073e9SAndroid Build Coastguard Worker        values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
5059*da0073e9SAndroid Build Coastguard Worker        values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
5060*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values1, device='mps')
5061*da0073e9SAndroid Build Coastguard Worker        mps_y = torch.tensor(values2, device='mps')
5062*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values1, device='cpu')
5063*da0073e9SAndroid Build Coastguard Worker        cpu_y = torch.tensor(values2, device='cpu')
5064*da0073e9SAndroid Build Coastguard Worker        result_mps = torch.eq(mps_x, mps_y)
5065*da0073e9SAndroid Build Coastguard Worker        result_cpu = torch.eq(cpu_x, cpu_y)
5066*da0073e9SAndroid Build Coastguard Worker
5067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_cpu, result_mps.to('cpu'))
5068*da0073e9SAndroid Build Coastguard Worker
5069*da0073e9SAndroid Build Coastguard Worker    def test_ne(self):
5070*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5071*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5072*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5073*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5074*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
5075*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.ne(mps_x, mps_y)
5076*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.ne(cpu_x, cpu_y)
5077*da0073e9SAndroid Build Coastguard Worker
5078*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5079*da0073e9SAndroid Build Coastguard Worker
5080*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5081*da0073e9SAndroid Build Coastguard Worker
5082*da0073e9SAndroid Build Coastguard Worker    def test_ne_scalar(self):
5083*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5084*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5085*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5086*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.ne(mps_x, 0.0)
5087*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.ne(cpu_x, 0.0)
5088*da0073e9SAndroid Build Coastguard Worker
5089*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5090*da0073e9SAndroid Build Coastguard Worker
5091*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5092*da0073e9SAndroid Build Coastguard Worker
5093*da0073e9SAndroid Build Coastguard Worker    def test_lt(self):
5094*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5095*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5096*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5097*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5098*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
5099*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.lt(mps_x, mps_y)
5100*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.lt(cpu_x, cpu_y)
5101*da0073e9SAndroid Build Coastguard Worker
5102*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5103*da0073e9SAndroid Build Coastguard Worker
5104*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5105*da0073e9SAndroid Build Coastguard Worker
5106*da0073e9SAndroid Build Coastguard Worker    def test_lt_scalar(self):
5107*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5108*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5109*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5110*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.lt(mps_x, 0.0)
5111*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.lt(cpu_x, 0.0)
5112*da0073e9SAndroid Build Coastguard Worker
5113*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5114*da0073e9SAndroid Build Coastguard Worker
5115*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5116*da0073e9SAndroid Build Coastguard Worker
5117*da0073e9SAndroid Build Coastguard Worker    def test_le(self):
5118*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5119*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5120*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5121*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5122*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
5123*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.le(mps_x, mps_y)
5124*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.le(cpu_x, cpu_y)
5125*da0073e9SAndroid Build Coastguard Worker
5126*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5127*da0073e9SAndroid Build Coastguard Worker
5128*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5129*da0073e9SAndroid Build Coastguard Worker
5130*da0073e9SAndroid Build Coastguard Worker    def test_le_scalar(self):
5131*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5132*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5133*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5134*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.le(mps_x, 0.0)
5135*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.le(cpu_x, 0.0)
5136*da0073e9SAndroid Build Coastguard Worker
5137*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5138*da0073e9SAndroid Build Coastguard Worker
5139*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5140*da0073e9SAndroid Build Coastguard Worker
5141*da0073e9SAndroid Build Coastguard Worker    def test_ge(self):
5142*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5143*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5144*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5145*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5146*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
5147*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.ge(mps_x, mps_y)
5148*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.ge(cpu_x, cpu_y)
5149*da0073e9SAndroid Build Coastguard Worker
5150*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5151*da0073e9SAndroid Build Coastguard Worker
5152*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5153*da0073e9SAndroid Build Coastguard Worker
5154*da0073e9SAndroid Build Coastguard Worker    def test_ge_scalar(self):
5155*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5156*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5157*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5158*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.ge(mps_x, 0.0)
5159*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.ge(cpu_x, 0.0)
5160*da0073e9SAndroid Build Coastguard Worker
5161*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5162*da0073e9SAndroid Build Coastguard Worker
5163*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5164*da0073e9SAndroid Build Coastguard Worker
5165*da0073e9SAndroid Build Coastguard Worker    def test_gt(self):
5166*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5167*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5168*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5169*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5170*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
5171*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.gt(mps_x, mps_y)
5172*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.gt(cpu_x, cpu_y)
5173*da0073e9SAndroid Build Coastguard Worker
5174*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5175*da0073e9SAndroid Build Coastguard Worker
5176*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5177*da0073e9SAndroid Build Coastguard Worker
5178*da0073e9SAndroid Build Coastguard Worker    def test_gt_scalar(self):
5179*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5180*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5181*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5182*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.gt(mps_x, 0.0)
5183*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.gt(cpu_x, 0.0)
5184*da0073e9SAndroid Build Coastguard Worker
5185*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps.to('cpu'))
5186*da0073e9SAndroid Build Coastguard Worker
5187*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4, 5))
5188*da0073e9SAndroid Build Coastguard Worker
5189*da0073e9SAndroid Build Coastguard Worker    def test_argmax(self):
5190*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/98191
5191*da0073e9SAndroid Build Coastguard Worker        cpu_tensor = torch.tensor([[0, 1], [2, 1], [1, 0]])
5192*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.argmax(cpu_tensor, dim=1)
5193*da0073e9SAndroid Build Coastguard Worker
5194*da0073e9SAndroid Build Coastguard Worker        mps_tensor = cpu_tensor.to(torch.device('mps'))
5195*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.argmax(mps_tensor, dim=1)
5196*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_mps)
5197*da0073e9SAndroid Build Coastguard Worker
5198*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/92311
5199*da0073e9SAndroid Build Coastguard Worker        mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
5200*da0073e9SAndroid Build Coastguard Worker        cpu_tensor = mps_tensor.detach().clone().cpu()
5201*da0073e9SAndroid Build Coastguard Worker
5202*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.argmax(mps_tensor, dim=1)
5203*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.argmax(cpu_tensor, dim=1)
5204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_mps)
5205*da0073e9SAndroid Build Coastguard Worker
5206*da0073e9SAndroid Build Coastguard Worker    # Test forward argmin argmax
5207*da0073e9SAndroid Build Coastguard Worker    def test_argmin_argmax(self):
5208*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
5209*da0073e9SAndroid Build Coastguard Worker            if reduction_type == "max":
5210*da0073e9SAndroid Build Coastguard Worker                arg_reduction_fn = torch.argmax
5211*da0073e9SAndroid Build Coastguard Worker            else:
5212*da0073e9SAndroid Build Coastguard Worker                arg_reduction_fn = torch.argmin
5213*da0073e9SAndroid Build Coastguard Worker
5214*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
5215*da0073e9SAndroid Build Coastguard Worker            x = None
5216*da0073e9SAndroid Build Coastguard Worker            if (dtype not in [torch.float32, torch.bool]):
5217*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5218*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5219*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bool):
5220*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5221*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5222*da0073e9SAndroid Build Coastguard Worker            else:
5223*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5224*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
5225*da0073e9SAndroid Build Coastguard Worker
5226*da0073e9SAndroid Build Coastguard Worker            y = arg_reduction_fn(x)
5227*da0073e9SAndroid Build Coastguard Worker            ref_y = arg_reduction_fn(cpu_x)
5228*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
5229*da0073e9SAndroid Build Coastguard Worker
5230*da0073e9SAndroid Build Coastguard Worker            y_0 = arg_reduction_fn(x, dim=0)
5231*da0073e9SAndroid Build Coastguard Worker            refy_0 = arg_reduction_fn(cpu_x, dim=0)
5232*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0, refy_0)
5233*da0073e9SAndroid Build Coastguard Worker
5234*da0073e9SAndroid Build Coastguard Worker            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
5235*da0073e9SAndroid Build Coastguard Worker            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
5236*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0dim, refy_0dim)
5237*da0073e9SAndroid Build Coastguard Worker
5238*da0073e9SAndroid Build Coastguard Worker            y_1 = arg_reduction_fn(x, dim=1)
5239*da0073e9SAndroid Build Coastguard Worker            refy_1 = arg_reduction_fn(cpu_x, dim=1)
5240*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1, refy_1)
5241*da0073e9SAndroid Build Coastguard Worker
5242*da0073e9SAndroid Build Coastguard Worker            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
5243*da0073e9SAndroid Build Coastguard Worker            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
5244*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1dim, refy_1dim)
5245*da0073e9SAndroid Build Coastguard Worker
5246*da0073e9SAndroid Build Coastguard Worker            y_2 = arg_reduction_fn(x, dim=2)
5247*da0073e9SAndroid Build Coastguard Worker            refy_2 = arg_reduction_fn(cpu_x, dim=2)
5248*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2, refy_2)
5249*da0073e9SAndroid Build Coastguard Worker
5250*da0073e9SAndroid Build Coastguard Worker            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
5251*da0073e9SAndroid Build Coastguard Worker            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
5252*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2dim, refy_2dim)
5253*da0073e9SAndroid Build Coastguard Worker
5254*da0073e9SAndroid Build Coastguard Worker            y_3 = arg_reduction_fn(x, dim=3)
5255*da0073e9SAndroid Build Coastguard Worker            refy_3 = arg_reduction_fn(cpu_x, dim=3)
5256*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3, refy_3)
5257*da0073e9SAndroid Build Coastguard Worker
5258*da0073e9SAndroid Build Coastguard Worker            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
5259*da0073e9SAndroid Build Coastguard Worker            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
5260*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3dim, refy_3dim)
5261*da0073e9SAndroid Build Coastguard Worker
5262*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "max", torch.float32)
5263*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "max", torch.int32)
5264*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "max", torch.float16)
5265*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "max", torch.int64)
5266*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "min", torch.float32)
5267*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "min", torch.int32)
5268*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "min", torch.float16)
5269*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 4, "min", torch.int64)
5270*da0073e9SAndroid Build Coastguard Worker
5271*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
5272*da0073e9SAndroid Build Coastguard Worker    def test_reduction_sum_max_long_val(self):
5273*da0073e9SAndroid Build Coastguard Worker        x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
5274*da0073e9SAndroid Build Coastguard Worker        x_cpu = x_mps.detach().clone().cpu()
5275*da0073e9SAndroid Build Coastguard Worker
5276*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.sum(x_mps)
5277*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.sum(x_cpu)
5278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_mps, res_cpu)
5279*da0073e9SAndroid Build Coastguard Worker
5280*da0073e9SAndroid Build Coastguard Worker    # Test forward max
5281*da0073e9SAndroid Build Coastguard Worker    # Note - don't test grad now
5282*da0073e9SAndroid Build Coastguard Worker    def test_max_el(self):
5283*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w, dtype=torch.float32):
5284*da0073e9SAndroid Build Coastguard Worker
5285*da0073e9SAndroid Build Coastguard Worker            if (dtype not in [torch.float32, torch.bool]):
5286*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5287*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5288*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bool):
5289*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5290*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5291*da0073e9SAndroid Build Coastguard Worker            else:
5292*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5293*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5294*da0073e9SAndroid Build Coastguard Worker
5295*da0073e9SAndroid Build Coastguard Worker            ref_y = torch.max(cpu_x)
5296*da0073e9SAndroid Build Coastguard Worker            y = torch.max(x)
5297*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
5298*da0073e9SAndroid Build Coastguard Worker
5299*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 1, 2, 3]:
5300*da0073e9SAndroid Build Coastguard Worker                for keepdim in [True, False]:
5301*da0073e9SAndroid Build Coastguard Worker                    y, idx = torch.max(x, dim=dim, keepdim=keepdim)
5302*da0073e9SAndroid Build Coastguard Worker                    refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
5303*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y, refy)
5304*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(idx, refidx)
5305*da0073e9SAndroid Build Coastguard Worker
5306*da0073e9SAndroid Build Coastguard Worker            y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
5307*da0073e9SAndroid Build Coastguard Worker            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5308*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=0, out=(y_0, idx_0))
5309*da0073e9SAndroid Build Coastguard Worker            refy_0, refidx_0 = torch.max(cpu_x, dim=0)
5310*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0, refy_0)
5311*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0, refidx_0)
5312*da0073e9SAndroid Build Coastguard Worker
5313*da0073e9SAndroid Build Coastguard Worker            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
5314*da0073e9SAndroid Build Coastguard Worker            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5315*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5316*da0073e9SAndroid Build Coastguard Worker            refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
5317*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0dim, refy_0dim)
5318*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0dim, refidx_0dim)
5319*da0073e9SAndroid Build Coastguard Worker
5320*da0073e9SAndroid Build Coastguard Worker            y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
5321*da0073e9SAndroid Build Coastguard Worker            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5322*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=1, out=(y_1, idx_1))
5323*da0073e9SAndroid Build Coastguard Worker            refy_1, refidx_1 = torch.max(cpu_x, dim=1)
5324*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1, refy_1)
5325*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1, refidx_1)
5326*da0073e9SAndroid Build Coastguard Worker
5327*da0073e9SAndroid Build Coastguard Worker            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
5328*da0073e9SAndroid Build Coastguard Worker            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5329*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5330*da0073e9SAndroid Build Coastguard Worker            refy_1dim, refidx_1dim = torch.max(cpu_x, keepdim=True, dim=1)
5331*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1dim, refy_1dim)
5332*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1dim, refidx_1dim)
5333*da0073e9SAndroid Build Coastguard Worker
5334*da0073e9SAndroid Build Coastguard Worker            y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
5335*da0073e9SAndroid Build Coastguard Worker            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5336*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=2, out=(y_2, idx_2))
5337*da0073e9SAndroid Build Coastguard Worker            refy_2, refidx_2 = torch.max(cpu_x, dim=2)
5338*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2, refy_2)
5339*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2, refidx_2)
5340*da0073e9SAndroid Build Coastguard Worker
5341*da0073e9SAndroid Build Coastguard Worker            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
5342*da0073e9SAndroid Build Coastguard Worker            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5343*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5344*da0073e9SAndroid Build Coastguard Worker            refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True,)
5345*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2dim, refy_2dim)
5346*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2dim, refidx_2dim)
5347*da0073e9SAndroid Build Coastguard Worker
5348*da0073e9SAndroid Build Coastguard Worker            y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
5349*da0073e9SAndroid Build Coastguard Worker            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5350*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=3, out=(y_3, idx_3))
5351*da0073e9SAndroid Build Coastguard Worker            refy_3, refidx_3 = torch.max(cpu_x, dim=3)
5352*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3, refy_3)
5353*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3, refidx_3)
5354*da0073e9SAndroid Build Coastguard Worker
5355*da0073e9SAndroid Build Coastguard Worker            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
5356*da0073e9SAndroid Build Coastguard Worker            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5357*da0073e9SAndroid Build Coastguard Worker            torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5358*da0073e9SAndroid Build Coastguard Worker            refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True,)
5359*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3dim, refy_3dim)
5360*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3dim, refidx_3dim)
5361*da0073e9SAndroid Build Coastguard Worker
5362*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5, torch.float32)
5363*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5, torch.int32)
5364*da0073e9SAndroid Build Coastguard Worker        # helper(2, 8, 4, 5, torch.int64)
5365*da0073e9SAndroid Build Coastguard Worker
5366*da0073e9SAndroid Build Coastguard Worker    def test_median(self):
5367*da0073e9SAndroid Build Coastguard Worker        def helper_dtype_int32(n1, n2, n3):
5368*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
5369*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5370*da0073e9SAndroid Build Coastguard Worker
5371*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.median(cpu_x)
5372*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.median(mps_x)
5373*da0073e9SAndroid Build Coastguard Worker
5374*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps)
5375*da0073e9SAndroid Build Coastguard Worker
5376*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 1, 2]:
5377*da0073e9SAndroid Build Coastguard Worker                for keepdim in [True, False]:
5378*da0073e9SAndroid Build Coastguard Worker                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5379*da0073e9SAndroid Build Coastguard Worker                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5380*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y, refy)
5381*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(idx, refidx)
5382*da0073e9SAndroid Build Coastguard Worker
5383*da0073e9SAndroid Build Coastguard Worker        def helper_dtype_float32(n1, n2, n3):
5384*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
5385*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
5386*da0073e9SAndroid Build Coastguard Worker
5387*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.median(cpu_x)
5388*da0073e9SAndroid Build Coastguard Worker            result_mps = torch.median(mps_x)
5389*da0073e9SAndroid Build Coastguard Worker
5390*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_cpu, result_mps)
5391*da0073e9SAndroid Build Coastguard Worker
5392*da0073e9SAndroid Build Coastguard Worker            for dim in [0, 1, 2]:
5393*da0073e9SAndroid Build Coastguard Worker                for keepdim in [True, False]:
5394*da0073e9SAndroid Build Coastguard Worker                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5395*da0073e9SAndroid Build Coastguard Worker                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5396*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y, refy)
5397*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(idx, refidx)
5398*da0073e9SAndroid Build Coastguard Worker
5399*da0073e9SAndroid Build Coastguard Worker        helper_dtype_int32(10, 10, 10)  # median at even place
5400*da0073e9SAndroid Build Coastguard Worker        helper_dtype_int32(3, 3, 3)  # median at odd place
5401*da0073e9SAndroid Build Coastguard Worker        helper_dtype_int32(1, 1, 1)
5402*da0073e9SAndroid Build Coastguard Worker        helper_dtype_int32(1, 2, 3)
5403*da0073e9SAndroid Build Coastguard Worker        helper_dtype_float32(10, 10, 10)
5404*da0073e9SAndroid Build Coastguard Worker        helper_dtype_float32(3, 3, 3)
5405*da0073e9SAndroid Build Coastguard Worker        helper_dtype_float32(1, 1, 1)
5406*da0073e9SAndroid Build Coastguard Worker
5407*da0073e9SAndroid Build Coastguard Worker    def test_any(self):
5408*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5409*da0073e9SAndroid Build Coastguard Worker            input_xs = []
5410*da0073e9SAndroid Build Coastguard Worker            prod = 1
5411*da0073e9SAndroid Build Coastguard Worker
5412*da0073e9SAndroid Build Coastguard Worker            for i in range(len(shape)):
5413*da0073e9SAndroid Build Coastguard Worker                prod *= shape[i]
5414*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5415*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5416*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5417*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5418*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5419*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5420*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5421*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5422*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5423*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5424*da0073e9SAndroid Build Coastguard Worker
5425*da0073e9SAndroid Build Coastguard Worker            for i, cpu_x in enumerate(input_xs):
5426*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5427*da0073e9SAndroid Build Coastguard Worker                y = torch.any(x)
5428*da0073e9SAndroid Build Coastguard Worker                ref_y = torch.any(cpu_x)
5429*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, ref_y)
5430*da0073e9SAndroid Build Coastguard Worker
5431*da0073e9SAndroid Build Coastguard Worker                y_0 = torch.any(x, dim=0)
5432*da0073e9SAndroid Build Coastguard Worker                refy_0 = torch.any(cpu_x, dim=0)
5433*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0, refy_0)
5434*da0073e9SAndroid Build Coastguard Worker
5435*da0073e9SAndroid Build Coastguard Worker                y_0dim = torch.any(x, dim=0, keepdim=True)
5436*da0073e9SAndroid Build Coastguard Worker                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
5437*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0dim, refy_0dim)
5438*da0073e9SAndroid Build Coastguard Worker
5439*da0073e9SAndroid Build Coastguard Worker                y_0dim = torch.any(x, dim=0, keepdim=True)
5440*da0073e9SAndroid Build Coastguard Worker                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
5441*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0dim, refy_0dim)
5442*da0073e9SAndroid Build Coastguard Worker
5443*da0073e9SAndroid Build Coastguard Worker                y_1 = torch.any(x, dim=1)
5444*da0073e9SAndroid Build Coastguard Worker                refy_1 = torch.any(cpu_x, dim=1)
5445*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_1, refy_1)
5446*da0073e9SAndroid Build Coastguard Worker
5447*da0073e9SAndroid Build Coastguard Worker                y_1dim = torch.any(x, dim=1, keepdim=True)
5448*da0073e9SAndroid Build Coastguard Worker                refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
5449*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_1dim, refy_1dim)
5450*da0073e9SAndroid Build Coastguard Worker
5451*da0073e9SAndroid Build Coastguard Worker                if (len(shape) > 2):
5452*da0073e9SAndroid Build Coastguard Worker                    y_2 = torch.any(x, dim=2)
5453*da0073e9SAndroid Build Coastguard Worker                    refy_2 = torch.any(cpu_x, dim=2)
5454*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_2, refy_2)
5455*da0073e9SAndroid Build Coastguard Worker
5456*da0073e9SAndroid Build Coastguard Worker                    y_2dim = torch.any(x, dim=2, keepdim=True)
5457*da0073e9SAndroid Build Coastguard Worker                    refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
5458*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_2dim, refy_2dim)
5459*da0073e9SAndroid Build Coastguard Worker
5460*da0073e9SAndroid Build Coastguard Worker                    y_3 = torch.any(x, dim=3)
5461*da0073e9SAndroid Build Coastguard Worker                    refy_3 = torch.any(cpu_x, dim=3)
5462*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_3, refy_3)
5463*da0073e9SAndroid Build Coastguard Worker
5464*da0073e9SAndroid Build Coastguard Worker                    y_3dim = torch.any(x, dim=3, keepdim=True)
5465*da0073e9SAndroid Build Coastguard Worker                    refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
5466*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_3dim, refy_3dim)
5467*da0073e9SAndroid Build Coastguard Worker        helper((1, 1, 1, 1))
5468*da0073e9SAndroid Build Coastguard Worker        helper((1, 1, 3, 3))
5469*da0073e9SAndroid Build Coastguard Worker        helper((7, 13))
5470*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
5471*da0073e9SAndroid Build Coastguard Worker
5472*da0073e9SAndroid Build Coastguard Worker    def test_reduction_ops_5D(self):
5473*da0073e9SAndroid Build Coastguard Worker        def helper(fn, dim):
5474*da0073e9SAndroid Build Coastguard Worker            shape = (1, 1, 2, 1, 1)
5475*da0073e9SAndroid Build Coastguard Worker            x_cpu = fn(torch.zeros(shape), dim=dim)
5476*da0073e9SAndroid Build Coastguard Worker            x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
5477*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_mps.cpu())
5478*da0073e9SAndroid Build Coastguard Worker        for fn in [torch.any, torch.all]:
5479*da0073e9SAndroid Build Coastguard Worker            for dim in range(0, 4):
5480*da0073e9SAndroid Build Coastguard Worker                helper(fn, dim)
5481*da0073e9SAndroid Build Coastguard Worker
5482*da0073e9SAndroid Build Coastguard Worker        # 6D tensor reductions
5483*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/95538
5484*da0073e9SAndroid Build Coastguard Worker        x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
5485*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.all(), x.cpu().all())
5486*da0073e9SAndroid Build Coastguard Worker        for i in range(-5, 6):
5487*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.all(dim=i), x.cpu().all(dim=i))
5488*da0073e9SAndroid Build Coastguard Worker
5489*da0073e9SAndroid Build Coastguard Worker    def test_all(self):
5490*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5491*da0073e9SAndroid Build Coastguard Worker            input_xs = []
5492*da0073e9SAndroid Build Coastguard Worker            prod = 1
5493*da0073e9SAndroid Build Coastguard Worker
5494*da0073e9SAndroid Build Coastguard Worker            for i in range(len(shape)):
5495*da0073e9SAndroid Build Coastguard Worker                prod *= shape[i]
5496*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5497*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5498*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5499*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5500*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5501*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5502*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5503*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5504*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5505*da0073e9SAndroid Build Coastguard Worker            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5506*da0073e9SAndroid Build Coastguard Worker
5507*da0073e9SAndroid Build Coastguard Worker            for i, cpu_x in enumerate(input_xs):
5508*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5509*da0073e9SAndroid Build Coastguard Worker                y = torch.all(x)
5510*da0073e9SAndroid Build Coastguard Worker                ref_y = torch.all(cpu_x)
5511*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, ref_y)
5512*da0073e9SAndroid Build Coastguard Worker
5513*da0073e9SAndroid Build Coastguard Worker                y_0 = torch.all(x, dim=0)
5514*da0073e9SAndroid Build Coastguard Worker                refy_0 = torch.all(cpu_x, dim=0)
5515*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0, refy_0)
5516*da0073e9SAndroid Build Coastguard Worker
5517*da0073e9SAndroid Build Coastguard Worker                y_0dim = torch.all(x, dim=0, keepdim=True)
5518*da0073e9SAndroid Build Coastguard Worker                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
5519*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0dim, refy_0dim)
5520*da0073e9SAndroid Build Coastguard Worker
5521*da0073e9SAndroid Build Coastguard Worker                y_0dim = torch.all(x, dim=0, keepdim=True)
5522*da0073e9SAndroid Build Coastguard Worker                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
5523*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_0dim, refy_0dim)
5524*da0073e9SAndroid Build Coastguard Worker
5525*da0073e9SAndroid Build Coastguard Worker                y_1 = torch.all(x, dim=1)
5526*da0073e9SAndroid Build Coastguard Worker                refy_1 = torch.all(cpu_x, dim=1)
5527*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_1, refy_1)
5528*da0073e9SAndroid Build Coastguard Worker
5529*da0073e9SAndroid Build Coastguard Worker                y_1dim = torch.all(x, dim=1, keepdim=True)
5530*da0073e9SAndroid Build Coastguard Worker                refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
5531*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y_1dim, refy_1dim)
5532*da0073e9SAndroid Build Coastguard Worker                if (len(shape) > 2):
5533*da0073e9SAndroid Build Coastguard Worker                    y_2 = torch.all(x, dim=2)
5534*da0073e9SAndroid Build Coastguard Worker                    refy_2 = torch.all(cpu_x, dim=2)
5535*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_2, refy_2)
5536*da0073e9SAndroid Build Coastguard Worker
5537*da0073e9SAndroid Build Coastguard Worker                    y_2dim = torch.all(x, dim=2, keepdim=True)
5538*da0073e9SAndroid Build Coastguard Worker                    refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
5539*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_2dim, refy_2dim)
5540*da0073e9SAndroid Build Coastguard Worker
5541*da0073e9SAndroid Build Coastguard Worker                    y_3 = torch.all(x, dim=3)
5542*da0073e9SAndroid Build Coastguard Worker                    refy_3 = torch.all(cpu_x, dim=3)
5543*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_3, refy_3)
5544*da0073e9SAndroid Build Coastguard Worker
5545*da0073e9SAndroid Build Coastguard Worker                    y_3dim = torch.all(x, dim=3, keepdim=True)
5546*da0073e9SAndroid Build Coastguard Worker                    refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
5547*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(y_3dim, refy_3dim)
5548*da0073e9SAndroid Build Coastguard Worker
5549*da0073e9SAndroid Build Coastguard Worker        helper((1, 1, 1, 1))
5550*da0073e9SAndroid Build Coastguard Worker        helper((1, 1, 3, 3))
5551*da0073e9SAndroid Build Coastguard Worker        helper((7, 13))
5552*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
5553*da0073e9SAndroid Build Coastguard Worker        # Empty tensor
5554*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.tensor([], dtype=torch.bool)
5555*da0073e9SAndroid Build Coastguard Worker        x_mps = x_cpu.to("mps")
5556*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_cpu.all(), x_mps.all().cpu())
5557*da0073e9SAndroid Build Coastguard Worker
5558*da0073e9SAndroid Build Coastguard Worker    # Test forward min
5559*da0073e9SAndroid Build Coastguard Worker    def test_min_el(self):
5560*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
5561*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5562*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
5563*da0073e9SAndroid Build Coastguard Worker
5564*da0073e9SAndroid Build Coastguard Worker            y = torch.min(x)
5565*da0073e9SAndroid Build Coastguard Worker            ref_y = torch.min(cpu_x)
5566*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, ref_y)
5567*da0073e9SAndroid Build Coastguard Worker
5568*da0073e9SAndroid Build Coastguard Worker            y_0, idx_0 = torch.min(x, dim=0)
5569*da0073e9SAndroid Build Coastguard Worker            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5570*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0, refy_0)
5571*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0, refidx_0)
5572*da0073e9SAndroid Build Coastguard Worker
5573*da0073e9SAndroid Build Coastguard Worker            y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
5574*da0073e9SAndroid Build Coastguard Worker            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5575*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=0, out=(y_0, idx_0))
5576*da0073e9SAndroid Build Coastguard Worker            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5577*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0, refy_0)
5578*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0, refidx_0)
5579*da0073e9SAndroid Build Coastguard Worker
5580*da0073e9SAndroid Build Coastguard Worker            y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
5581*da0073e9SAndroid Build Coastguard Worker            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5582*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0dim, refy_0dim)
5583*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0dim, refidx_0dim)
5584*da0073e9SAndroid Build Coastguard Worker
5585*da0073e9SAndroid Build Coastguard Worker            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
5586*da0073e9SAndroid Build Coastguard Worker            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5587*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5588*da0073e9SAndroid Build Coastguard Worker            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5589*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_0dim, refy_0dim)
5590*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_0dim, refidx_0dim)
5591*da0073e9SAndroid Build Coastguard Worker
5592*da0073e9SAndroid Build Coastguard Worker            y_1, idx_1 = torch.min(x, dim=1)
5593*da0073e9SAndroid Build Coastguard Worker            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5594*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1, refy_1)
5595*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1, refidx_1)
5596*da0073e9SAndroid Build Coastguard Worker
5597*da0073e9SAndroid Build Coastguard Worker            y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
5598*da0073e9SAndroid Build Coastguard Worker            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5599*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=1, out=(y_1, idx_1))
5600*da0073e9SAndroid Build Coastguard Worker            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5601*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1, refy_1)
5602*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1, refidx_1)
5603*da0073e9SAndroid Build Coastguard Worker
5604*da0073e9SAndroid Build Coastguard Worker            y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
5605*da0073e9SAndroid Build Coastguard Worker            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
5606*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1dim, refy_1dim)
5607*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1dim, refidx_1dim)
5608*da0073e9SAndroid Build Coastguard Worker
5609*da0073e9SAndroid Build Coastguard Worker            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
5610*da0073e9SAndroid Build Coastguard Worker            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5611*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5612*da0073e9SAndroid Build Coastguard Worker            refy_1dim, refidx_1dim = torch.min(cpu_x, keepdim=True, dim=1)
5613*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_1dim, refy_1dim)
5614*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_1dim, refidx_1dim)
5615*da0073e9SAndroid Build Coastguard Worker
5616*da0073e9SAndroid Build Coastguard Worker            y_2, idx_2 = torch.min(x, dim=2)
5617*da0073e9SAndroid Build Coastguard Worker            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5618*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2, refy_2)
5619*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2, refidx_2)
5620*da0073e9SAndroid Build Coastguard Worker
5621*da0073e9SAndroid Build Coastguard Worker            y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
5622*da0073e9SAndroid Build Coastguard Worker            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5623*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=2, out=(y_2, idx_2))
5624*da0073e9SAndroid Build Coastguard Worker            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5625*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2, refy_2)
5626*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2, refidx_2)
5627*da0073e9SAndroid Build Coastguard Worker
5628*da0073e9SAndroid Build Coastguard Worker            y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
5629*da0073e9SAndroid Build Coastguard Worker            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
5630*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2dim, refy_2dim)
5631*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2dim, refidx_2dim)
5632*da0073e9SAndroid Build Coastguard Worker
5633*da0073e9SAndroid Build Coastguard Worker            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
5634*da0073e9SAndroid Build Coastguard Worker            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5635*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5636*da0073e9SAndroid Build Coastguard Worker            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True,)
5637*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_2dim, refy_2dim)
5638*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_2dim, refidx_2dim)
5639*da0073e9SAndroid Build Coastguard Worker
5640*da0073e9SAndroid Build Coastguard Worker            y_3, idx_3 = torch.min(x, dim=3)
5641*da0073e9SAndroid Build Coastguard Worker            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5642*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3, refy_3)
5643*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3, refidx_3)
5644*da0073e9SAndroid Build Coastguard Worker
5645*da0073e9SAndroid Build Coastguard Worker            y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
5646*da0073e9SAndroid Build Coastguard Worker            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5647*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=3, out=(y_3, idx_3))
5648*da0073e9SAndroid Build Coastguard Worker            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5649*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3, refy_3)
5650*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3, refidx_3)
5651*da0073e9SAndroid Build Coastguard Worker
5652*da0073e9SAndroid Build Coastguard Worker            y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
5653*da0073e9SAndroid Build Coastguard Worker            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
5654*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3dim, refy_3dim)
5655*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3dim, refidx_3dim)
5656*da0073e9SAndroid Build Coastguard Worker
5657*da0073e9SAndroid Build Coastguard Worker            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
5658*da0073e9SAndroid Build Coastguard Worker            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5659*da0073e9SAndroid Build Coastguard Worker            torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5660*da0073e9SAndroid Build Coastguard Worker            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True,)
5661*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_3dim, refy_3dim)
5662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_3dim, refidx_3dim)
5663*da0073e9SAndroid Build Coastguard Worker
5664*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
5665*da0073e9SAndroid Build Coastguard Worker
5666*da0073e9SAndroid Build Coastguard Worker    # Test forward sum
5667*da0073e9SAndroid Build Coastguard Worker    def test_sum(self):
5668*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w, dtype=torch.float32):
5669*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
5670*da0073e9SAndroid Build Coastguard Worker            x = None
5671*da0073e9SAndroid Build Coastguard Worker            if (dtype not in [torch.float32, torch.bool]):
5672*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5673*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5674*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bool):
5675*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5676*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5677*da0073e9SAndroid Build Coastguard Worker            else:
5678*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5679*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
5680*da0073e9SAndroid Build Coastguard Worker
5681*da0073e9SAndroid Build Coastguard Worker            all_sum = torch.sum(x)
5682*da0073e9SAndroid Build Coastguard Worker            all_sum_cpu = torch.sum(cpu_x)
5683*da0073e9SAndroid Build Coastguard Worker
5684*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_sum, all_sum_cpu)
5685*da0073e9SAndroid Build Coastguard Worker
5686*da0073e9SAndroid Build Coastguard Worker            nil_dim_sum = torch.sum(x, dim=[])
5687*da0073e9SAndroid Build Coastguard Worker            nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])
5688*da0073e9SAndroid Build Coastguard Worker
5689*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)
5690*da0073e9SAndroid Build Coastguard Worker
5691*da0073e9SAndroid Build Coastguard Worker            nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
5692*da0073e9SAndroid Build Coastguard Worker            nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)
5693*da0073e9SAndroid Build Coastguard Worker
5694*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)
5695*da0073e9SAndroid Build Coastguard Worker
5696*da0073e9SAndroid Build Coastguard Worker            zero_dim_sum = torch.sum(x, dim=[0])
5697*da0073e9SAndroid Build Coastguard Worker            zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])
5698*da0073e9SAndroid Build Coastguard Worker
5699*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)
5700*da0073e9SAndroid Build Coastguard Worker
5701*da0073e9SAndroid Build Coastguard Worker            zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
5702*da0073e9SAndroid Build Coastguard Worker            zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)
5703*da0073e9SAndroid Build Coastguard Worker
5704*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)
5705*da0073e9SAndroid Build Coastguard Worker
5706*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_sum = torch.sum(x, dim=[0, 1])
5707*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])
5708*da0073e9SAndroid Build Coastguard Worker
5709*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)
5710*da0073e9SAndroid Build Coastguard Worker
5711*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
5712*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)
5713*da0073e9SAndroid Build Coastguard Worker
5714*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)
5715*da0073e9SAndroid Build Coastguard Worker
5716*da0073e9SAndroid Build Coastguard Worker            two_three_dim_sum = torch.sum(x, dim=[2, 3])
5717*da0073e9SAndroid Build Coastguard Worker            two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])
5718*da0073e9SAndroid Build Coastguard Worker
5719*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)
5720*da0073e9SAndroid Build Coastguard Worker
5721*da0073e9SAndroid Build Coastguard Worker            two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
5722*da0073e9SAndroid Build Coastguard Worker            two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)
5723*da0073e9SAndroid Build Coastguard Worker
5724*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)
5725*da0073e9SAndroid Build Coastguard Worker
5726*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
5727*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5, dtype=torch.int32)
5728*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5, dtype=torch.int64)
5729*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5, dtype=torch.bool)
5730*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/136132
5731*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
5732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.numel(), 8)
5733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.max().item(), 30.0)
5734*da0073e9SAndroid Build Coastguard Worker
5735*da0073e9SAndroid Build Coastguard Worker    # Test forward prod
5736*da0073e9SAndroid Build Coastguard Worker    def test_prod(self):
5737*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype=torch.float32):
5738*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
5739*da0073e9SAndroid Build Coastguard Worker            x = None
5740*da0073e9SAndroid Build Coastguard Worker            if (dtype not in [torch.float32, torch.bool]):
5741*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
5742*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5743*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bool):
5744*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
5745*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
5746*da0073e9SAndroid Build Coastguard Worker            else:
5747*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
5748*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
5749*da0073e9SAndroid Build Coastguard Worker
5750*da0073e9SAndroid Build Coastguard Worker            all_prod = torch.prod(x)
5751*da0073e9SAndroid Build Coastguard Worker            all_prod_cpu = torch.prod(cpu_x)
5752*da0073e9SAndroid Build Coastguard Worker
5753*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_prod, all_prod_cpu)
5754*da0073e9SAndroid Build Coastguard Worker
5755*da0073e9SAndroid Build Coastguard Worker            for dim in range(len(shape)):
5756*da0073e9SAndroid Build Coastguard Worker                dim_prod = torch.prod(x, dim=dim)
5757*da0073e9SAndroid Build Coastguard Worker                dim_prod_cpu = torch.prod(cpu_x, dim=dim)
5758*da0073e9SAndroid Build Coastguard Worker
5759*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dim_prod, dim_prod_cpu)
5760*da0073e9SAndroid Build Coastguard Worker
5761*da0073e9SAndroid Build Coastguard Worker                dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
5762*da0073e9SAndroid Build Coastguard Worker                dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)
5763*da0073e9SAndroid Build Coastguard Worker
5764*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)
5765*da0073e9SAndroid Build Coastguard Worker
5766*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
5767*da0073e9SAndroid Build Coastguard Worker            helper((2, 3), dtype)
5768*da0073e9SAndroid Build Coastguard Worker
5769*da0073e9SAndroid Build Coastguard Worker    # Test forward mean
5770*da0073e9SAndroid Build Coastguard Worker    def test_mean(self):
5771*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
5772*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
5773*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
5774*da0073e9SAndroid Build Coastguard Worker
5775*da0073e9SAndroid Build Coastguard Worker            all_mean = torch.mean(x)
5776*da0073e9SAndroid Build Coastguard Worker            all_mean_cpu = torch.mean(cpu_x)
5777*da0073e9SAndroid Build Coastguard Worker
5778*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_mean, all_mean_cpu)
5779*da0073e9SAndroid Build Coastguard Worker
5780*da0073e9SAndroid Build Coastguard Worker            nil_dim_mean = torch.mean(x, dim=[])
5781*da0073e9SAndroid Build Coastguard Worker            nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])
5782*da0073e9SAndroid Build Coastguard Worker
5783*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)
5784*da0073e9SAndroid Build Coastguard Worker
5785*da0073e9SAndroid Build Coastguard Worker            nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
5786*da0073e9SAndroid Build Coastguard Worker            nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)
5787*da0073e9SAndroid Build Coastguard Worker
5788*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)
5789*da0073e9SAndroid Build Coastguard Worker
5790*da0073e9SAndroid Build Coastguard Worker            zero_dim_mean = torch.mean(x, dim=[0])
5791*da0073e9SAndroid Build Coastguard Worker            zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])
5792*da0073e9SAndroid Build Coastguard Worker
5793*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)
5794*da0073e9SAndroid Build Coastguard Worker
5795*da0073e9SAndroid Build Coastguard Worker            zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
5796*da0073e9SAndroid Build Coastguard Worker            zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)
5797*da0073e9SAndroid Build Coastguard Worker
5798*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)
5799*da0073e9SAndroid Build Coastguard Worker
5800*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_mean = torch.mean(x, dim=[0, 1])
5801*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])
5802*da0073e9SAndroid Build Coastguard Worker
5803*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)
5804*da0073e9SAndroid Build Coastguard Worker
5805*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
5806*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)
5807*da0073e9SAndroid Build Coastguard Worker
5808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)
5809*da0073e9SAndroid Build Coastguard Worker
5810*da0073e9SAndroid Build Coastguard Worker            two_three_dim_mean = torch.mean(x, dim=[2, 3])
5811*da0073e9SAndroid Build Coastguard Worker            two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])
5812*da0073e9SAndroid Build Coastguard Worker
5813*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)
5814*da0073e9SAndroid Build Coastguard Worker
5815*da0073e9SAndroid Build Coastguard Worker            two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
5816*da0073e9SAndroid Build Coastguard Worker            two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)
5817*da0073e9SAndroid Build Coastguard Worker
5818*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)
5819*da0073e9SAndroid Build Coastguard Worker
5820*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
5821*da0073e9SAndroid Build Coastguard Worker
5822*da0073e9SAndroid Build Coastguard Worker    # Test std
5823*da0073e9SAndroid Build Coastguard Worker    def test_std(self):
5824*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
5825*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5826*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
5827*da0073e9SAndroid Build Coastguard Worker
5828*da0073e9SAndroid Build Coastguard Worker            all_std = torch.std(x, unbiased=False)
5829*da0073e9SAndroid Build Coastguard Worker            all_std_cpu = torch.std(cpu_x, unbiased=False)
5830*da0073e9SAndroid Build Coastguard Worker
5831*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_std, all_std_cpu)
5832*da0073e9SAndroid Build Coastguard Worker
5833*da0073e9SAndroid Build Coastguard Worker            nil_dim_std = torch.std(x, dim=[], unbiased=False)
5834*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)
5835*da0073e9SAndroid Build Coastguard Worker
5836*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5837*da0073e9SAndroid Build Coastguard Worker
5838*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
5839*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)
5840*da0073e9SAndroid Build Coastguard Worker
5841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5842*da0073e9SAndroid Build Coastguard Worker
5843*da0073e9SAndroid Build Coastguard Worker            zero_dim_std = torch.std(x, dim=[0], unbiased=False)
5844*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)
5845*da0073e9SAndroid Build Coastguard Worker
5846*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5847*da0073e9SAndroid Build Coastguard Worker
5848*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
5849*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)
5850*da0073e9SAndroid Build Coastguard Worker
5851*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5852*da0073e9SAndroid Build Coastguard Worker
5853*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
5854*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)
5855*da0073e9SAndroid Build Coastguard Worker
5856*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5857*da0073e9SAndroid Build Coastguard Worker
5858*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
5859*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
5860*da0073e9SAndroid Build Coastguard Worker
5861*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5862*da0073e9SAndroid Build Coastguard Worker
5863*da0073e9SAndroid Build Coastguard Worker            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
5864*da0073e9SAndroid Build Coastguard Worker            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)
5865*da0073e9SAndroid Build Coastguard Worker
5866*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5867*da0073e9SAndroid Build Coastguard Worker
5868*da0073e9SAndroid Build Coastguard Worker            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
5869*da0073e9SAndroid Build Coastguard Worker            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
5870*da0073e9SAndroid Build Coastguard Worker
5871*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5872*da0073e9SAndroid Build Coastguard Worker
5873*da0073e9SAndroid Build Coastguard Worker            all_std = torch.std(x, unbiased=True)
5874*da0073e9SAndroid Build Coastguard Worker            all_std_cpu = torch.std(cpu_x, unbiased=True)
5875*da0073e9SAndroid Build Coastguard Worker
5876*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(all_std, all_std_cpu)
5877*da0073e9SAndroid Build Coastguard Worker
5878*da0073e9SAndroid Build Coastguard Worker            nil_dim_std = torch.std(x, dim=[], unbiased=True)
5879*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)
5880*da0073e9SAndroid Build Coastguard Worker
5881*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5882*da0073e9SAndroid Build Coastguard Worker
5883*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
5884*da0073e9SAndroid Build Coastguard Worker            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)
5885*da0073e9SAndroid Build Coastguard Worker
5886*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5887*da0073e9SAndroid Build Coastguard Worker
5888*da0073e9SAndroid Build Coastguard Worker            zero_dim_std = torch.std(x, dim=[0], unbiased=True)
5889*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)
5890*da0073e9SAndroid Build Coastguard Worker
5891*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5892*da0073e9SAndroid Build Coastguard Worker
5893*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
5894*da0073e9SAndroid Build Coastguard Worker            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)
5895*da0073e9SAndroid Build Coastguard Worker
5896*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5897*da0073e9SAndroid Build Coastguard Worker
5898*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
5899*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)
5900*da0073e9SAndroid Build Coastguard Worker
5901*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5902*da0073e9SAndroid Build Coastguard Worker
5903*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
5904*da0073e9SAndroid Build Coastguard Worker            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
5905*da0073e9SAndroid Build Coastguard Worker
5906*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5907*da0073e9SAndroid Build Coastguard Worker
5908*da0073e9SAndroid Build Coastguard Worker            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
5909*da0073e9SAndroid Build Coastguard Worker            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)
5910*da0073e9SAndroid Build Coastguard Worker
5911*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5912*da0073e9SAndroid Build Coastguard Worker
5913*da0073e9SAndroid Build Coastguard Worker            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
5914*da0073e9SAndroid Build Coastguard Worker            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
5915*da0073e9SAndroid Build Coastguard Worker
5916*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5917*da0073e9SAndroid Build Coastguard Worker
5918*da0073e9SAndroid Build Coastguard Worker        helper((4, 5, 6, 7))
5919*da0073e9SAndroid Build Coastguard Worker        # verify if a change in shape of input would cause problems with graph caching
5920*da0073e9SAndroid Build Coastguard Worker        helper((9, 5, 6, 7))
5921*da0073e9SAndroid Build Coastguard Worker
5922*da0073e9SAndroid Build Coastguard Worker    # Test var
5923*da0073e9SAndroid Build Coastguard Worker    def test_var_simple(self):
5924*da0073e9SAndroid Build Coastguard Worker        def helper():
5925*da0073e9SAndroid Build Coastguard Worker
5926*da0073e9SAndroid Build Coastguard Worker            shape = [2, 3, 4, 5]
5927*da0073e9SAndroid Build Coastguard Worker
5928*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5929*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
5930*da0073e9SAndroid Build Coastguard Worker
5931*da0073e9SAndroid Build Coastguard Worker            for unbiased in [False, True]:
5932*da0073e9SAndroid Build Coastguard Worker                for keepdim in [False, True]:
5933*da0073e9SAndroid Build Coastguard Worker
5934*da0073e9SAndroid Build Coastguard Worker                    zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
5935*da0073e9SAndroid Build Coastguard Worker                    zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
5936*da0073e9SAndroid Build Coastguard Worker
5937*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
5938*da0073e9SAndroid Build Coastguard Worker
5939*da0073e9SAndroid Build Coastguard Worker                    all_var = torch.var(x, unbiased=unbiased)
5940*da0073e9SAndroid Build Coastguard Worker                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
5941*da0073e9SAndroid Build Coastguard Worker
5942*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(all_var, all_var_cpu)
5943*da0073e9SAndroid Build Coastguard Worker
5944*da0073e9SAndroid Build Coastguard Worker                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
5945*da0073e9SAndroid Build Coastguard Worker                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
5946*da0073e9SAndroid Build Coastguard Worker
5947*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)
5948*da0073e9SAndroid Build Coastguard Worker
5949*da0073e9SAndroid Build Coastguard Worker                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
5950*da0073e9SAndroid Build Coastguard Worker                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
5951*da0073e9SAndroid Build Coastguard Worker
5952*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
5953*da0073e9SAndroid Build Coastguard Worker
5954*da0073e9SAndroid Build Coastguard Worker                    zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
5955*da0073e9SAndroid Build Coastguard Worker                    zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
5956*da0073e9SAndroid Build Coastguard Worker
5957*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
5958*da0073e9SAndroid Build Coastguard Worker
5959*da0073e9SAndroid Build Coastguard Worker                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
5960*da0073e9SAndroid Build Coastguard Worker                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
5961*da0073e9SAndroid Build Coastguard Worker
5962*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
5963*da0073e9SAndroid Build Coastguard Worker
5964*da0073e9SAndroid Build Coastguard Worker        helper()
5965*da0073e9SAndroid Build Coastguard Worker
5966*da0073e9SAndroid Build Coastguard Worker    # Test forward amax
5967*da0073e9SAndroid Build Coastguard Worker    def test_amax(self):
5968*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, keepdim):
5969*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5970*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
5971*da0073e9SAndroid Build Coastguard Worker
5972*da0073e9SAndroid Build Coastguard Worker            result = torch.amax(x, dim=dim, keepdim=keepdim)
5973*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
5974*da0073e9SAndroid Build Coastguard Worker
5975*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(result_cpu.shape)
5976*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
5977*da0073e9SAndroid Build Coastguard Worker
5978*da0073e9SAndroid Build Coastguard Worker            result_cpu.backward(gradient=cpu_grad)
5979*da0073e9SAndroid Build Coastguard Worker            result.backward(gradient=grad)
5980*da0073e9SAndroid Build Coastguard Worker
5981*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
5982*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
5983*da0073e9SAndroid Build Coastguard Worker
5984*da0073e9SAndroid Build Coastguard Worker        for dim in ([], [0], [0, 1], [2, 3]):
5985*da0073e9SAndroid Build Coastguard Worker            for keepdim in [False, True]:
5986*da0073e9SAndroid Build Coastguard Worker                helper((2, 8, 4, 5), dim, keepdim)
5987*da0073e9SAndroid Build Coastguard Worker
5988*da0073e9SAndroid Build Coastguard Worker    # Test forward amin
5989*da0073e9SAndroid Build Coastguard Worker    def test_amin(self):
5990*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, keepdim):
5991*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5992*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
5993*da0073e9SAndroid Build Coastguard Worker
5994*da0073e9SAndroid Build Coastguard Worker            result = torch.amin(x, dim=dim, keepdim=keepdim)
5995*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
5996*da0073e9SAndroid Build Coastguard Worker
5997*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(result_cpu.shape)
5998*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
5999*da0073e9SAndroid Build Coastguard Worker
6000*da0073e9SAndroid Build Coastguard Worker            result_cpu.backward(gradient=cpu_grad)
6001*da0073e9SAndroid Build Coastguard Worker            result.backward(gradient=grad)
6002*da0073e9SAndroid Build Coastguard Worker
6003*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
6004*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6005*da0073e9SAndroid Build Coastguard Worker
6006*da0073e9SAndroid Build Coastguard Worker        for dim in ([], [0], [0, 1], [2, 3]):
6007*da0073e9SAndroid Build Coastguard Worker            for keepdim in [False, True]:
6008*da0073e9SAndroid Build Coastguard Worker                helper((2, 8, 4, 5), dim, keepdim)
6009*da0073e9SAndroid Build Coastguard Worker
6010*da0073e9SAndroid Build Coastguard Worker    # Test minimum and maximum
6011*da0073e9SAndroid Build Coastguard Worker    def test_minimum_maximum(self):
6012*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
6013*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6014*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6015*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
6016*da0073e9SAndroid Build Coastguard Worker            mps_y = cpu_y.detach().clone().to('mps')
6017*da0073e9SAndroid Build Coastguard Worker
6018*da0073e9SAndroid Build Coastguard Worker            minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
6019*da0073e9SAndroid Build Coastguard Worker            minimum_result_mps = torch.minimum(mps_x, mps_y)
6020*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(minimum_result_cpu, minimum_result_mps)
6021*da0073e9SAndroid Build Coastguard Worker
6022*da0073e9SAndroid Build Coastguard Worker            maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
6023*da0073e9SAndroid Build Coastguard Worker            maximum_result_mps = torch.maximum(mps_x, mps_y)
6024*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(maximum_result_cpu, maximum_result_mps)
6025*da0073e9SAndroid Build Coastguard Worker
6026*da0073e9SAndroid Build Coastguard Worker        helper(1, 1, 4, 5)
6027*da0073e9SAndroid Build Coastguard Worker
6028*da0073e9SAndroid Build Coastguard Worker    def test_clamp_fp16_fp32(self):
6029*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
6030*da0073e9SAndroid Build Coastguard Worker        x = cpu_x.detach().clone().to('mps')
6031*da0073e9SAndroid Build Coastguard Worker
6032*da0073e9SAndroid Build Coastguard Worker        dtype = torch.float16
6033*da0073e9SAndroid Build Coastguard Worker
6034*da0073e9SAndroid Build Coastguard Worker        clamp_min_vals_mps = torch.ones(10, device="mps").to(torch.float16)
6035*da0073e9SAndroid Build Coastguard Worker        clamp_max_vals_mps = torch.ones(10, device="mps").to(torch.float16) * 10
6036*da0073e9SAndroid Build Coastguard Worker        clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)
6037*da0073e9SAndroid Build Coastguard Worker
6038*da0073e9SAndroid Build Coastguard Worker        clamp_min_vals_cpu = torch.ones(10, device="cpu").to(torch.float16)
6039*da0073e9SAndroid Build Coastguard Worker        clamp_max_vals_cpu = torch.ones(10, device="cpu").to(torch.float16) * 10
6040*da0073e9SAndroid Build Coastguard Worker        clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)
6041*da0073e9SAndroid Build Coastguard Worker
6042*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(clamp_result_mps, clamp_result_cpu)
6043*da0073e9SAndroid Build Coastguard Worker
6044*da0073e9SAndroid Build Coastguard Worker    def test_clamp_nan(self):
6045*da0073e9SAndroid Build Coastguard Worker        t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
6046*da0073e9SAndroid Build Coastguard Worker        t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")
6047*da0073e9SAndroid Build Coastguard Worker
6048*da0073e9SAndroid Build Coastguard Worker        clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
6049*da0073e9SAndroid Build Coastguard Worker        clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)
6050*da0073e9SAndroid Build Coastguard Worker
6051*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)
6052*da0073e9SAndroid Build Coastguard Worker
6053*da0073e9SAndroid Build Coastguard Worker        clamp_min_mps = torch.clamp(t_mps, min=-100)
6054*da0073e9SAndroid Build Coastguard Worker        clamp_min_cpu = torch.clamp(t_cpu, min=-100)
6055*da0073e9SAndroid Build Coastguard Worker
6056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(clamp_min_mps, clamp_min_cpu)
6057*da0073e9SAndroid Build Coastguard Worker
6058*da0073e9SAndroid Build Coastguard Worker        clamp_max_mps = torch.clamp(t_mps, max=100)
6059*da0073e9SAndroid Build Coastguard Worker        clamp_max_cpu = torch.clamp(t_cpu, max=100)
6060*da0073e9SAndroid Build Coastguard Worker
6061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(clamp_max_mps, clamp_max_cpu)
6062*da0073e9SAndroid Build Coastguard Worker
6063*da0073e9SAndroid Build Coastguard Worker    # Test clamp_min
6064*da0073e9SAndroid Build Coastguard Worker    def test_clamp_min(self):
6065*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
6066*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6067*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6068*da0073e9SAndroid Build Coastguard Worker
6069*da0073e9SAndroid Build Coastguard Worker            cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6070*da0073e9SAndroid Build Coastguard Worker            min_t = cpu_min_t.detach().clone().to('mps')
6071*da0073e9SAndroid Build Coastguard Worker
6072*da0073e9SAndroid Build Coastguard Worker            clamp_min_result = torch.clamp_min(x, min=5.0)
6073*da0073e9SAndroid Build Coastguard Worker            clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)
6074*da0073e9SAndroid Build Coastguard Worker
6075*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_min_result, clamp_min_result_cpu)
6076*da0073e9SAndroid Build Coastguard Worker
6077*da0073e9SAndroid Build Coastguard Worker            clamp_min_t_result = torch.clamp_min(x, min=min_t)
6078*da0073e9SAndroid Build Coastguard Worker            clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)
6079*da0073e9SAndroid Build Coastguard Worker
6080*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)
6081*da0073e9SAndroid Build Coastguard Worker
6082*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
6083*da0073e9SAndroid Build Coastguard Worker
6084*da0073e9SAndroid Build Coastguard Worker    # Test clamp_max
6085*da0073e9SAndroid Build Coastguard Worker
6086*da0073e9SAndroid Build Coastguard Worker    def test_clamp_max(self):
6087*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
6088*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6089*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6090*da0073e9SAndroid Build Coastguard Worker
6091*da0073e9SAndroid Build Coastguard Worker            cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6092*da0073e9SAndroid Build Coastguard Worker            max_t = cpu_max_t.detach().clone().to('mps')
6093*da0073e9SAndroid Build Coastguard Worker
6094*da0073e9SAndroid Build Coastguard Worker            clamp_max_result = torch.clamp_max(x, max=100.0)
6095*da0073e9SAndroid Build Coastguard Worker            clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)
6096*da0073e9SAndroid Build Coastguard Worker
6097*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_max_result, clamp_max_result_cpu)
6098*da0073e9SAndroid Build Coastguard Worker
6099*da0073e9SAndroid Build Coastguard Worker            clamp_max_t_result = torch.clamp_max(x, max=max_t)
6100*da0073e9SAndroid Build Coastguard Worker            clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)
6101*da0073e9SAndroid Build Coastguard Worker
6102*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)
6103*da0073e9SAndroid Build Coastguard Worker
6104*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
6105*da0073e9SAndroid Build Coastguard Worker
6106*da0073e9SAndroid Build Coastguard Worker    # Test clamp
6107*da0073e9SAndroid Build Coastguard Worker    def test_clamp(self):
6108*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w):
6109*da0073e9SAndroid Build Coastguard Worker            import numpy as np
6110*da0073e9SAndroid Build Coastguard Worker            upper_bound = 1000
6111*da0073e9SAndroid Build Coastguard Worker            half_upper_bound = upper_bound / 2
6112*da0073e9SAndroid Build Coastguard Worker
6113*da0073e9SAndroid Build Coastguard Worker            # x=[0..1000)
6114*da0073e9SAndroid Build Coastguard Worker            x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
6115*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
6116*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6117*da0073e9SAndroid Build Coastguard Worker
6118*da0073e9SAndroid Build Coastguard Worker            # x=[0..500)
6119*da0073e9SAndroid Build Coastguard Worker            min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
6120*da0073e9SAndroid Build Coastguard Worker            cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
6121*da0073e9SAndroid Build Coastguard Worker            min_t = cpu_min_t.detach().clone().to('mps')
6122*da0073e9SAndroid Build Coastguard Worker
6123*da0073e9SAndroid Build Coastguard Worker            # x=[500..1000), to ensure max's are greater than mins
6124*da0073e9SAndroid Build Coastguard Worker            max_arr = (half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
6125*da0073e9SAndroid Build Coastguard Worker            cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
6126*da0073e9SAndroid Build Coastguard Worker            max_t = cpu_max_t.detach().clone().to('mps')
6127*da0073e9SAndroid Build Coastguard Worker
6128*da0073e9SAndroid Build Coastguard Worker            # [200..600]: just an arbitrary range between [0..1000]
6129*da0073e9SAndroid Build Coastguard Worker            clamp_result = torch.clamp(x, min=200.0, max=600.0)
6130*da0073e9SAndroid Build Coastguard Worker            clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
6131*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_result, clamp_result_cpu)
6132*da0073e9SAndroid Build Coastguard Worker
6133*da0073e9SAndroid Build Coastguard Worker            # test optional scalar refs and cached graph keys by passing only max
6134*da0073e9SAndroid Build Coastguard Worker            clamp_opt_result = torch.clamp(x, max=600.0)
6135*da0073e9SAndroid Build Coastguard Worker            clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
6136*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)
6137*da0073e9SAndroid Build Coastguard Worker
6138*da0073e9SAndroid Build Coastguard Worker            clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
6139*da0073e9SAndroid Build Coastguard Worker            clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
6140*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_t_result, clamp_t_result_cpu)
6141*da0073e9SAndroid Build Coastguard Worker
6142*da0073e9SAndroid Build Coastguard Worker            # test optional tensor refs and cached graph keys by passing only max
6143*da0073e9SAndroid Build Coastguard Worker            clamp_topt_result = torch.clamp(x, max=max_t)
6144*da0073e9SAndroid Build Coastguard Worker            clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
6145*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
6146*da0073e9SAndroid Build Coastguard Worker
6147*da0073e9SAndroid Build Coastguard Worker            # test strided x
6148*da0073e9SAndroid Build Coastguard Worker            clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
6149*da0073e9SAndroid Build Coastguard Worker            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
6150*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_result, clamp_result_cpu)
6151*da0073e9SAndroid Build Coastguard Worker
6152*da0073e9SAndroid Build Coastguard Worker            # test strided x, min_t, max_t
6153*da0073e9SAndroid Build Coastguard Worker            clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
6154*da0073e9SAndroid Build Coastguard Worker            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
6155*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_result, clamp_result_cpu)
6156*da0073e9SAndroid Build Coastguard Worker
6157*da0073e9SAndroid Build Coastguard Worker            # test strided min_t, max_t
6158*da0073e9SAndroid Build Coastguard Worker            clamp_result = torch.clamp(
6159*da0073e9SAndroid Build Coastguard Worker                x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6160*da0073e9SAndroid Build Coastguard Worker                min=min_t.movedim(0, -1),
6161*da0073e9SAndroid Build Coastguard Worker                max=max_t.movedim(0, -1)
6162*da0073e9SAndroid Build Coastguard Worker            )
6163*da0073e9SAndroid Build Coastguard Worker            clamp_result_cpu = torch.clamp(
6164*da0073e9SAndroid Build Coastguard Worker                cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6165*da0073e9SAndroid Build Coastguard Worker                min=cpu_min_t.movedim(0, -1),
6166*da0073e9SAndroid Build Coastguard Worker                max=cpu_max_t.movedim(0, -1)
6167*da0073e9SAndroid Build Coastguard Worker            )
6168*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(clamp_result, clamp_result_cpu)
6169*da0073e9SAndroid Build Coastguard Worker
6170*da0073e9SAndroid Build Coastguard Worker            # test inplace clamping
6171*da0073e9SAndroid Build Coastguard Worker            x.clamp_(min=200.0, max=600.0)
6172*da0073e9SAndroid Build Coastguard Worker            cpu_x.clamp_(min=200.0, max=600.0)
6173*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_x, x)
6174*da0073e9SAndroid Build Coastguard Worker
6175*da0073e9SAndroid Build Coastguard Worker        helper(2, 8, 4, 5)
6176*da0073e9SAndroid Build Coastguard Worker
6177*da0073e9SAndroid Build Coastguard Worker    def test_divmode(self):
6178*da0073e9SAndroid Build Coastguard Worker        def helper(shape, rounding_mode):
6179*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
6180*da0073e9SAndroid Build Coastguard Worker                if ((rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) or
6181*da0073e9SAndroid Build Coastguard Worker                        (rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16)) is False:
6182*da0073e9SAndroid Build Coastguard Worker                    cpu_x = None
6183*da0073e9SAndroid Build Coastguard Worker                    cpu_y = None
6184*da0073e9SAndroid Build Coastguard Worker                    if (dtype in [torch.float32, torch.float16]):
6185*da0073e9SAndroid Build Coastguard Worker                        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6186*da0073e9SAndroid Build Coastguard Worker                        cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6187*da0073e9SAndroid Build Coastguard Worker                    else:
6188*da0073e9SAndroid Build Coastguard Worker                        cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
6189*da0073e9SAndroid Build Coastguard Worker                        cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
6190*da0073e9SAndroid Build Coastguard Worker
6191*da0073e9SAndroid Build Coastguard Worker                    mps_x = cpu_x.detach().clone().to('mps')
6192*da0073e9SAndroid Build Coastguard Worker                    # clamp to avoid division by 0
6193*da0073e9SAndroid Build Coastguard Worker                    mps_y = cpu_y.detach().clone().to('mps')
6194*da0073e9SAndroid Build Coastguard Worker
6195*da0073e9SAndroid Build Coastguard Worker                    if (rounding_mode == "floor_divide"):
6196*da0073e9SAndroid Build Coastguard Worker                        result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
6197*da0073e9SAndroid Build Coastguard Worker                        result_div_mps = torch.floor_divide(mps_x, mps_y)
6198*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result_div_mps, result_div_cpu)
6199*da0073e9SAndroid Build Coastguard Worker                    else:
6200*da0073e9SAndroid Build Coastguard Worker                        result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
6201*da0073e9SAndroid Build Coastguard Worker                        result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
6202*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result_div_mps, result_div_cpu)
6203*da0073e9SAndroid Build Coastguard Worker
6204*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), None)
6205*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), "floor")
6206*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), "trunc")
6207*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), "floor_divide")
6208*da0073e9SAndroid Build Coastguard Worker
6209*da0073e9SAndroid Build Coastguard Worker    def test_rounding(self):
6210*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6211*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6212*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
6213*da0073e9SAndroid Build Coastguard Worker
6214*da0073e9SAndroid Build Coastguard Worker            result_floor_cpu = torch.floor(cpu_x)
6215*da0073e9SAndroid Build Coastguard Worker            result_floor_mps = torch.floor(mps_x)
6216*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_floor_mps, result_floor_cpu)
6217*da0073e9SAndroid Build Coastguard Worker
6218*da0073e9SAndroid Build Coastguard Worker            result_ceil_cpu = torch.ceil(cpu_x)
6219*da0073e9SAndroid Build Coastguard Worker            result_ceil_mps = torch.ceil(mps_x)
6220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_ceil_mps, result_ceil_cpu)
6221*da0073e9SAndroid Build Coastguard Worker
6222*da0073e9SAndroid Build Coastguard Worker            result_trunc_cpu = torch.trunc(cpu_x)
6223*da0073e9SAndroid Build Coastguard Worker            result_trunc_mps = torch.trunc(mps_x)
6224*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_trunc_mps, result_trunc_cpu)
6225*da0073e9SAndroid Build Coastguard Worker
6226*da0073e9SAndroid Build Coastguard Worker            result_round_cpu = torch.round(cpu_x)
6227*da0073e9SAndroid Build Coastguard Worker            result_round_mps = torch.round(mps_x)
6228*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_round_mps, result_round_cpu)
6229*da0073e9SAndroid Build Coastguard Worker
6230*da0073e9SAndroid Build Coastguard Worker        helper((2, 6, 3, 5))
6231*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6232*da0073e9SAndroid Build Coastguard Worker
6233*da0073e9SAndroid Build Coastguard Worker    def test_remainder(self):
6234*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.remainder(
6235*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="cpu"), torch.tensor(2, device="cpu", dtype=torch.int32))
6236*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.remainder(
6237*da0073e9SAndroid Build Coastguard Worker            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dtype=torch.int32))
6238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_mps)
6239*da0073e9SAndroid Build Coastguard Worker
6240*da0073e9SAndroid Build Coastguard Worker        res_cpu = torch.remainder(
6241*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="cpu"), -1.5)
6242*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.remainder(
6243*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
6244*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res_cpu, res_mps)
6245*da0073e9SAndroid Build Coastguard Worker
6246*da0073e9SAndroid Build Coastguard Worker    def test_expand(self):
6247*da0073e9SAndroid Build Coastguard Worker        def helper(n, c):
6248*da0073e9SAndroid Build Coastguard Worker            values = [[1.0], [4.0], [7.0]]
6249*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(values, device='cpu')
6250*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6251*da0073e9SAndroid Build Coastguard Worker
6252*da0073e9SAndroid Build Coastguard Worker            strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0))
6253*da0073e9SAndroid Build Coastguard Worker            strided_mps = torch.as_strided(x, (3, 4), (1, 0))
6254*da0073e9SAndroid Build Coastguard Worker
6255*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strided_mps, strided_cpu)
6256*da0073e9SAndroid Build Coastguard Worker
6257*da0073e9SAndroid Build Coastguard Worker        helper(3, 1)
6258*da0073e9SAndroid Build Coastguard Worker
6259*da0073e9SAndroid Build Coastguard Worker    def test_im2col(self):
6260*da0073e9SAndroid Build Coastguard Worker        def helper(x):
6261*da0073e9SAndroid Build Coastguard Worker            return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)
6262*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.rand(1, 1, 200, 100)
6263*da0073e9SAndroid Build Coastguard Worker        x = x_cpu.detach().clone().to('mps')
6264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(helper(x_cpu), helper(x))
6265*da0073e9SAndroid Build Coastguard Worker
6266*da0073e9SAndroid Build Coastguard Worker    def test_select(self):
6267*da0073e9SAndroid Build Coastguard Worker        def helper(n, c):
6268*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
6269*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6270*da0073e9SAndroid Build Coastguard Worker
6271*da0073e9SAndroid Build Coastguard Worker            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1))
6272*da0073e9SAndroid Build Coastguard Worker            strided_mps = torch.as_strided(x, (3, 1), (3, 1))
6273*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strided_mps, strided_cpu)
6274*da0073e9SAndroid Build Coastguard Worker
6275*da0073e9SAndroid Build Coastguard Worker            strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1))
6276*da0073e9SAndroid Build Coastguard Worker            strided_mps = torch.as_strided(x, (1, 3), (3, 1))
6277*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strided_mps, strided_cpu)
6278*da0073e9SAndroid Build Coastguard Worker
6279*da0073e9SAndroid Build Coastguard Worker            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1)
6280*da0073e9SAndroid Build Coastguard Worker            strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1)
6281*da0073e9SAndroid Build Coastguard Worker
6282*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(strided_mps, strided_cpu)
6283*da0073e9SAndroid Build Coastguard Worker
6284*da0073e9SAndroid Build Coastguard Worker        helper(3, 3)
6285*da0073e9SAndroid Build Coastguard Worker
6286*da0073e9SAndroid Build Coastguard Worker    def test_sort(self):
6287*da0073e9SAndroid Build Coastguard Worker        for SIZE in (4, 2049):
6288*da0073e9SAndroid Build Coastguard Worker            device = 'mps'
6289*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(4, SIZE, device=device)
6290*da0073e9SAndroid Build Coastguard Worker            res1val, res1ind = torch.sort(x)
6291*da0073e9SAndroid Build Coastguard Worker
6292*da0073e9SAndroid Build Coastguard Worker            res2val = torch.tensor((), device=device)
6293*da0073e9SAndroid Build Coastguard Worker            res2ind = torch.tensor((), device=device, dtype=torch.long)
6294*da0073e9SAndroid Build Coastguard Worker            torch.sort(x, out=(res2val, res2ind))
6295*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1val, res2val, atol=0, rtol=0)
6296*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
6297*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.argsort(x), res1ind)
6298*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.argsort(), res1ind)
6299*da0073e9SAndroid Build Coastguard Worker
6300*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
6301*da0073e9SAndroid Build Coastguard Worker                torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
6302*da0073e9SAndroid Build Coastguard Worker                torch.tensor((10, 20, 30, 40, 50), device=device),
6303*da0073e9SAndroid Build Coastguard Worker                atol=0, rtol=0
6304*da0073e9SAndroid Build Coastguard Worker            )
6305*da0073e9SAndroid Build Coastguard Worker
6306*da0073e9SAndroid Build Coastguard Worker    def test_upsample_nearest2d(self):
6307*da0073e9SAndroid Build Coastguard Worker        def helper(N, C, H, W, memory_format):
6308*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
6309*da0073e9SAndroid Build Coastguard Worker                                    requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
6310*da0073e9SAndroid Build Coastguard Worker            inputCPU.retain_grad()
6311*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().to('mps').requires_grad_()
6312*da0073e9SAndroid Build Coastguard Worker
6313*da0073e9SAndroid Build Coastguard Worker            values = [1, 2, 5, 10, 40]
6314*da0073e9SAndroid Build Coastguard Worker
6315*da0073e9SAndroid Build Coastguard Worker            for i in values:
6316*da0073e9SAndroid Build Coastguard Worker                for j in values:
6317*da0073e9SAndroid Build Coastguard Worker                    upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j))
6318*da0073e9SAndroid Build Coastguard Worker
6319*da0073e9SAndroid Build Coastguard Worker                    outputCPU = upsample_nearest2d(inputCPU)
6320*da0073e9SAndroid Build Coastguard Worker                    outputMPS = upsample_nearest2d(inputMPS)
6321*da0073e9SAndroid Build Coastguard Worker
6322*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(outputCPU, outputMPS)
6323*da0073e9SAndroid Build Coastguard Worker                    upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W))
6324*da0073e9SAndroid Build Coastguard Worker
6325*da0073e9SAndroid Build Coastguard Worker                    outputCPU = upsample_nearest2d(inputCPU)
6326*da0073e9SAndroid Build Coastguard Worker                    outputMPS = upsample_nearest2d(inputMPS)
6327*da0073e9SAndroid Build Coastguard Worker
6328*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(outputCPU, outputMPS)
6329*da0073e9SAndroid Build Coastguard Worker
6330*da0073e9SAndroid Build Coastguard Worker                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6331*da0073e9SAndroid Build Coastguard Worker                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6332*da0073e9SAndroid Build Coastguard Worker
6333*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(inputCPU.grad, inputMPS.grad)
6334*da0073e9SAndroid Build Coastguard Worker
6335*da0073e9SAndroid Build Coastguard Worker        for memory_format in [torch.channels_last, torch.contiguous_format]:
6336*da0073e9SAndroid Build Coastguard Worker            helper(1, 1, 4, 4, memory_format=memory_format)
6337*da0073e9SAndroid Build Coastguard Worker            helper(7, 5, 3, 2, memory_format=memory_format)
6338*da0073e9SAndroid Build Coastguard Worker
6339*da0073e9SAndroid Build Coastguard Worker    def test_upsample_bilinear2d(self):
6340*da0073e9SAndroid Build Coastguard Worker        def helper(N, C, H, W):
6341*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
6342*da0073e9SAndroid Build Coastguard Worker                                    requires_grad=True).reshape(N, C, H, W)
6343*da0073e9SAndroid Build Coastguard Worker            inputCPU.retain_grad()
6344*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6345*da0073e9SAndroid Build Coastguard Worker
6346*da0073e9SAndroid Build Coastguard Worker            values = [1, 2, 5, 10, 40]
6347*da0073e9SAndroid Build Coastguard Worker
6348*da0073e9SAndroid Build Coastguard Worker            for i in values:
6349*da0073e9SAndroid Build Coastguard Worker                for j in values:
6350*da0073e9SAndroid Build Coastguard Worker                    upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j))
6351*da0073e9SAndroid Build Coastguard Worker
6352*da0073e9SAndroid Build Coastguard Worker                    outputCPU = upsample_bilinear2d(inputCPU)
6353*da0073e9SAndroid Build Coastguard Worker                    outputMPS = upsample_bilinear2d(inputMPS)
6354*da0073e9SAndroid Build Coastguard Worker
6355*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(outputCPU, outputMPS)
6356*da0073e9SAndroid Build Coastguard Worker
6357*da0073e9SAndroid Build Coastguard Worker                    upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W))
6358*da0073e9SAndroid Build Coastguard Worker
6359*da0073e9SAndroid Build Coastguard Worker                    outputCPU = upsample_bilinear2d(inputCPU)
6360*da0073e9SAndroid Build Coastguard Worker                    outputMPS = upsample_bilinear2d(inputMPS)
6361*da0073e9SAndroid Build Coastguard Worker
6362*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(outputCPU, outputMPS)
6363*da0073e9SAndroid Build Coastguard Worker
6364*da0073e9SAndroid Build Coastguard Worker                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6365*da0073e9SAndroid Build Coastguard Worker                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6366*da0073e9SAndroid Build Coastguard Worker
6367*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(inputCPU.grad, inputMPS.grad)
6368*da0073e9SAndroid Build Coastguard Worker
6369*da0073e9SAndroid Build Coastguard Worker        helper(1, 1, 4, 4)
6370*da0073e9SAndroid Build Coastguard Worker        helper(7, 5, 3, 2)
6371*da0073e9SAndroid Build Coastguard Worker
6372*da0073e9SAndroid Build Coastguard Worker    def test_interpolate(self):
6373*da0073e9SAndroid Build Coastguard Worker        def helper(shape, output_size, scales, mode, align_corners=False):
6374*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6375*da0073e9SAndroid Build Coastguard Worker            inputCPU.retain_grad()
6376*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6377*da0073e9SAndroid Build Coastguard Worker
6378*da0073e9SAndroid Build Coastguard Worker            # align_corners is used for 2D interpolation only
6379*da0073e9SAndroid Build Coastguard Worker            if (align_corners is True and len(shape) > 3 and mode == 'bilinear'):
6380*da0073e9SAndroid Build Coastguard Worker                if scales is not None:
6381*da0073e9SAndroid Build Coastguard Worker                    outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners)
6382*da0073e9SAndroid Build Coastguard Worker                    outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners)
6383*da0073e9SAndroid Build Coastguard Worker                else:
6384*da0073e9SAndroid Build Coastguard Worker                    outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners)
6385*da0073e9SAndroid Build Coastguard Worker                    outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners)
6386*da0073e9SAndroid Build Coastguard Worker            elif scales is not None:
6387*da0073e9SAndroid Build Coastguard Worker                outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode)
6388*da0073e9SAndroid Build Coastguard Worker                outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode)
6389*da0073e9SAndroid Build Coastguard Worker            else:
6390*da0073e9SAndroid Build Coastguard Worker                outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode)
6391*da0073e9SAndroid Build Coastguard Worker                outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode)
6392*da0073e9SAndroid Build Coastguard Worker
6393*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputCPU, outputMPS)
6394*da0073e9SAndroid Build Coastguard Worker
6395*da0073e9SAndroid Build Coastguard Worker            # backward pass (chose 0.6 just to have the grad_output != 1)
6396*da0073e9SAndroid Build Coastguard Worker            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
6397*da0073e9SAndroid Build Coastguard Worker            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
6398*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inputCPU.grad, inputMPS.grad)
6399*da0073e9SAndroid Build Coastguard Worker
6400*da0073e9SAndroid Build Coastguard Worker        # 1D interpolation
6401*da0073e9SAndroid Build Coastguard Worker        for mode in ['nearest', 'nearest-exact']:
6402*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4], [3], None, mode)  # downsample with size
6403*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4], [6], None, mode)  # upsample with size
6404*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4], None, [0.6], mode)  # downsample with scale factor
6405*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4], None, [1.7], mode)  # upsample with scale factor
6406*da0073e9SAndroid Build Coastguard Worker        # 2D interpolation
6407*da0073e9SAndroid Build Coastguard Worker        for mode in ['nearest', 'nearest-exact', 'bilinear']:
6408*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4, 5], [3, 4], None, mode)  # downsample_nearest with size
6409*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4, 5], [6, 7], None, mode)  # upsample_nearest with size
6410*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample_nearest with scale factor
6411*da0073e9SAndroid Build Coastguard Worker            helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample_nearest with scale factor
6412*da0073e9SAndroid Build Coastguard Worker        # align_corners=True
6413*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
6414*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
6415*da0073e9SAndroid Build Coastguard Worker
6416*da0073e9SAndroid Build Coastguard Worker    # Test concat forward
6417*da0073e9SAndroid Build Coastguard Worker    def test_cat1(self):
6418*da0073e9SAndroid Build Coastguard Worker        def helper(shape_x, shape_y, shape_z):
6419*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6420*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6421*da0073e9SAndroid Build Coastguard Worker
6422*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6423*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
6424*da0073e9SAndroid Build Coastguard Worker
6425*da0073e9SAndroid Build Coastguard Worker            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6426*da0073e9SAndroid Build Coastguard Worker            z = cpu_z.detach().clone().to('mps')
6427*da0073e9SAndroid Build Coastguard Worker
6428*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat([x, y, z], dim=1)
6429*da0073e9SAndroid Build Coastguard Worker            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
6430*da0073e9SAndroid Build Coastguard Worker
6431*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cat, cat_cpu)
6432*da0073e9SAndroid Build Coastguard Worker
6433*da0073e9SAndroid Build Coastguard Worker        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
6434*da0073e9SAndroid Build Coastguard Worker        helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6435*da0073e9SAndroid Build Coastguard Worker        helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5])
6436*da0073e9SAndroid Build Coastguard Worker        helper([2, 2, 6, 5], [0], [2, 5, 6, 5])
6437*da0073e9SAndroid Build Coastguard Worker        helper([0], [2, 3, 6, 5], [2, 5, 6, 5])
6438*da0073e9SAndroid Build Coastguard Worker        helper([2, 3, 4, 5], [2, 5, 4, 5], [0])
6439*da0073e9SAndroid Build Coastguard Worker        helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5])
6440*da0073e9SAndroid Build Coastguard Worker        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6441*da0073e9SAndroid Build Coastguard Worker        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5])
6442*da0073e9SAndroid Build Coastguard Worker
6443*da0073e9SAndroid Build Coastguard Worker    # Test stack forward
6444*da0073e9SAndroid Build Coastguard Worker    def test_stack(self):
6445*da0073e9SAndroid Build Coastguard Worker        # All shapes must be same
6446*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype=torch.float32):
6447*da0073e9SAndroid Build Coastguard Worker
6448*da0073e9SAndroid Build Coastguard Worker            x, cpu_x = None, None
6449*da0073e9SAndroid Build Coastguard Worker            y, cpu_y = None, None
6450*da0073e9SAndroid Build Coastguard Worker            z, cpu_z = None, None
6451*da0073e9SAndroid Build Coastguard Worker
6452*da0073e9SAndroid Build Coastguard Worker            if (dtype not in [torch.float32, torch.bool]):
6453*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6454*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
6455*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6456*da0073e9SAndroid Build Coastguard Worker                y = cpu_y.detach().clone().to('mps')
6457*da0073e9SAndroid Build Coastguard Worker                cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6458*da0073e9SAndroid Build Coastguard Worker                z = cpu_z.detach().clone().to('mps')
6459*da0073e9SAndroid Build Coastguard Worker            elif (dtype == torch.bool):
6460*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6461*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
6462*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6463*da0073e9SAndroid Build Coastguard Worker                y = cpu_y.detach().clone().to('mps')
6464*da0073e9SAndroid Build Coastguard Worker                cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6465*da0073e9SAndroid Build Coastguard Worker                z = cpu_z.detach().clone().to('mps')
6466*da0073e9SAndroid Build Coastguard Worker            else:
6467*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6468*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
6469*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6470*da0073e9SAndroid Build Coastguard Worker                y = cpu_y.detach().clone().to('mps').requires_grad_()
6471*da0073e9SAndroid Build Coastguard Worker                cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6472*da0073e9SAndroid Build Coastguard Worker                z = cpu_z.detach().clone().to('mps').requires_grad_()
6473*da0073e9SAndroid Build Coastguard Worker
6474*da0073e9SAndroid Build Coastguard Worker            stack = torch.stack([x, y, z], dim=1)
6475*da0073e9SAndroid Build Coastguard Worker            stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
6476*da0073e9SAndroid Build Coastguard Worker
6477*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stack, stack_cpu)
6478*da0073e9SAndroid Build Coastguard Worker
6479*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5])
6480*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5], dtype=torch.float16)
6481*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5], dtype=torch.int32)
6482*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5], dtype=torch.int64)
6483*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5], dtype=torch.bool)
6484*da0073e9SAndroid Build Coastguard Worker        # Empty test - Currently failing! Empty tensor not handled!
6485*da0073e9SAndroid Build Coastguard Worker        # helper([0, 2, 4, 5])
6486*da0073e9SAndroid Build Coastguard Worker
6487*da0073e9SAndroid Build Coastguard Worker    # Test abs
6488*da0073e9SAndroid Build Coastguard Worker    def test_abs(self):
6489*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6490*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6491*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6492*da0073e9SAndroid Build Coastguard Worker
6493*da0073e9SAndroid Build Coastguard Worker            abs_result = torch.abs(x)
6494*da0073e9SAndroid Build Coastguard Worker            abs_result_cpu = torch.abs(cpu_x)
6495*da0073e9SAndroid Build Coastguard Worker
6496*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(abs_result, abs_result_cpu)
6497*da0073e9SAndroid Build Coastguard Worker
6498*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6499*da0073e9SAndroid Build Coastguard Worker
6500*da0073e9SAndroid Build Coastguard Worker    def test_log(self):
6501*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6502*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6503*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6504*da0073e9SAndroid Build Coastguard Worker
6505*da0073e9SAndroid Build Coastguard Worker            log_result = torch.log(x)
6506*da0073e9SAndroid Build Coastguard Worker            log_result_cpu = torch.log(cpu_x)
6507*da0073e9SAndroid Build Coastguard Worker
6508*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_result, log_result_cpu)
6509*da0073e9SAndroid Build Coastguard Worker
6510*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6511*da0073e9SAndroid Build Coastguard Worker
6512*da0073e9SAndroid Build Coastguard Worker    def test_log_ten(self):
6513*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6514*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6515*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6516*da0073e9SAndroid Build Coastguard Worker
6517*da0073e9SAndroid Build Coastguard Worker            log_ten_result = torch.log10(x)
6518*da0073e9SAndroid Build Coastguard Worker            log_ten_result_cpu = torch.log10(cpu_x)
6519*da0073e9SAndroid Build Coastguard Worker
6520*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_ten_result, log_ten_result_cpu)
6521*da0073e9SAndroid Build Coastguard Worker
6522*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6523*da0073e9SAndroid Build Coastguard Worker
6524*da0073e9SAndroid Build Coastguard Worker    def test_log_two(self):
6525*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6526*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6527*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6528*da0073e9SAndroid Build Coastguard Worker
6529*da0073e9SAndroid Build Coastguard Worker            log_two_result = torch.log2(x)
6530*da0073e9SAndroid Build Coastguard Worker            log_two_result_cpu = torch.log2(cpu_x)
6531*da0073e9SAndroid Build Coastguard Worker
6532*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_two_result, log_two_result_cpu)
6533*da0073e9SAndroid Build Coastguard Worker
6534*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6535*da0073e9SAndroid Build Coastguard Worker
6536*da0073e9SAndroid Build Coastguard Worker    def test_log1p(self):
6537*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6538*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6539*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6540*da0073e9SAndroid Build Coastguard Worker
6541*da0073e9SAndroid Build Coastguard Worker            log_result = torch.log1p(x)
6542*da0073e9SAndroid Build Coastguard Worker            log_result_cpu = torch.log1p(cpu_x)
6543*da0073e9SAndroid Build Coastguard Worker
6544*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_result, log_result_cpu)
6545*da0073e9SAndroid Build Coastguard Worker
6546*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6547*da0073e9SAndroid Build Coastguard Worker
6548*da0073e9SAndroid Build Coastguard Worker    def test_logaddexp(self):
6549*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6550*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6551*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6552*da0073e9SAndroid Build Coastguard Worker
6553*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6554*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
6555*da0073e9SAndroid Build Coastguard Worker
6556*da0073e9SAndroid Build Coastguard Worker            log_result = torch.logaddexp(x, y)
6557*da0073e9SAndroid Build Coastguard Worker            log_result_cpu = torch.logaddexp(cpu_x, cpu_y)
6558*da0073e9SAndroid Build Coastguard Worker
6559*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_result, log_result_cpu)
6560*da0073e9SAndroid Build Coastguard Worker
6561*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6562*da0073e9SAndroid Build Coastguard Worker
6563*da0073e9SAndroid Build Coastguard Worker    def test_logaddexp2(self):
6564*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6565*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6566*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6567*da0073e9SAndroid Build Coastguard Worker
6568*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6569*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
6570*da0073e9SAndroid Build Coastguard Worker
6571*da0073e9SAndroid Build Coastguard Worker            log_result = torch.logaddexp2(x, y)
6572*da0073e9SAndroid Build Coastguard Worker            log_result_cpu = torch.logaddexp2(cpu_x, cpu_y)
6573*da0073e9SAndroid Build Coastguard Worker
6574*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_result, log_result_cpu)
6575*da0073e9SAndroid Build Coastguard Worker
6576*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6577*da0073e9SAndroid Build Coastguard Worker
6578*da0073e9SAndroid Build Coastguard Worker    def test_logsumexp(self):
6579*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6580*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6581*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6582*da0073e9SAndroid Build Coastguard Worker
6583*da0073e9SAndroid Build Coastguard Worker            log_result = torch.logsumexp(x, -1)
6584*da0073e9SAndroid Build Coastguard Worker            log_result_cpu = torch.logsumexp(cpu_x, -1)
6585*da0073e9SAndroid Build Coastguard Worker
6586*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(log_result, log_result_cpu)
6587*da0073e9SAndroid Build Coastguard Worker
6588*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6589*da0073e9SAndroid Build Coastguard Worker
6590*da0073e9SAndroid Build Coastguard Worker    # Test concat forward
6591*da0073e9SAndroid Build Coastguard Worker    def test_cat2(self):
6592*da0073e9SAndroid Build Coastguard Worker
6593*da0073e9SAndroid Build Coastguard Worker        def helper1(shape_x, shape_y, shape_z, shape_w):
6594*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6595*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6596*da0073e9SAndroid Build Coastguard Worker
6597*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6598*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
6599*da0073e9SAndroid Build Coastguard Worker
6600*da0073e9SAndroid Build Coastguard Worker            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6601*da0073e9SAndroid Build Coastguard Worker            z = cpu_z.detach().clone().to('mps')
6602*da0073e9SAndroid Build Coastguard Worker
6603*da0073e9SAndroid Build Coastguard Worker            cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False)
6604*da0073e9SAndroid Build Coastguard Worker            w = cpu_w.detach().clone().to('mps')
6605*da0073e9SAndroid Build Coastguard Worker
6606*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat([x, y, z, w], dim=1)
6607*da0073e9SAndroid Build Coastguard Worker            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1)
6608*da0073e9SAndroid Build Coastguard Worker
6609*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cat, cat_cpu)
6610*da0073e9SAndroid Build Coastguard Worker
6611*da0073e9SAndroid Build Coastguard Worker        def helper(shape_x, shape_y, shape_z):
6612*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6613*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6614*da0073e9SAndroid Build Coastguard Worker
6615*da0073e9SAndroid Build Coastguard Worker            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6616*da0073e9SAndroid Build Coastguard Worker            y = cpu_y.detach().clone().to('mps')
6617*da0073e9SAndroid Build Coastguard Worker
6618*da0073e9SAndroid Build Coastguard Worker            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6619*da0073e9SAndroid Build Coastguard Worker            z = cpu_z.detach().clone().to('mps')
6620*da0073e9SAndroid Build Coastguard Worker
6621*da0073e9SAndroid Build Coastguard Worker            cat = torch.cat([x, y, z], dim=1)
6622*da0073e9SAndroid Build Coastguard Worker            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
6623*da0073e9SAndroid Build Coastguard Worker
6624*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cat, cat_cpu)
6625*da0073e9SAndroid Build Coastguard Worker
6626*da0073e9SAndroid Build Coastguard Worker        helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
6627*da0073e9SAndroid Build Coastguard Worker        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
6628*da0073e9SAndroid Build Coastguard Worker        # Empty test - Currently failing! Empty tensor not handled!
6629*da0073e9SAndroid Build Coastguard Worker        # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
6630*da0073e9SAndroid Build Coastguard Worker
6631*da0073e9SAndroid Build Coastguard Worker    # Test isnan
6632*da0073e9SAndroid Build Coastguard Worker    def test_isnan(self):
6633*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6634*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6635*da0073e9SAndroid Build Coastguard Worker            nan_index = [random.randrange(0, shape[0])]
6636*da0073e9SAndroid Build Coastguard Worker            # make a selected row inf
6637*da0073e9SAndroid Build Coastguard Worker            cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan')))
6638*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6639*da0073e9SAndroid Build Coastguard Worker
6640*da0073e9SAndroid Build Coastguard Worker            isnan_result = torch.isnan(x)
6641*da0073e9SAndroid Build Coastguard Worker            isnan_result_cpu = torch.isnan(cpu_x)
6642*da0073e9SAndroid Build Coastguard Worker
6643*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(isnan_result, isnan_result_cpu)
6644*da0073e9SAndroid Build Coastguard Worker
6645*da0073e9SAndroid Build Coastguard Worker        helper((8, 2, 4, 5))
6646*da0073e9SAndroid Build Coastguard Worker
6647*da0073e9SAndroid Build Coastguard Worker    # Test reciprocal
6648*da0073e9SAndroid Build Coastguard Worker    def test_reciprocal(self):
6649*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6650*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6651*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6652*da0073e9SAndroid Build Coastguard Worker
6653*da0073e9SAndroid Build Coastguard Worker            reciprocal_result = torch.reciprocal(x)
6654*da0073e9SAndroid Build Coastguard Worker            reciprocal_result_cpu = torch.reciprocal(cpu_x)
6655*da0073e9SAndroid Build Coastguard Worker
6656*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(reciprocal_result_cpu)
6657*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6658*da0073e9SAndroid Build Coastguard Worker
6659*da0073e9SAndroid Build Coastguard Worker            reciprocal_result.backward(gradient=grad)
6660*da0073e9SAndroid Build Coastguard Worker            reciprocal_result_cpu.backward(gradient=cpu_grad)
6661*da0073e9SAndroid Build Coastguard Worker
6662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(reciprocal_result, reciprocal_result_cpu)
6663*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6664*da0073e9SAndroid Build Coastguard Worker
6665*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6666*da0073e9SAndroid Build Coastguard Worker
6667*da0073e9SAndroid Build Coastguard Worker    # Test sqrt
6668*da0073e9SAndroid Build Coastguard Worker    def test_sqrt(self):
6669*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6670*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6671*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6672*da0073e9SAndroid Build Coastguard Worker
6673*da0073e9SAndroid Build Coastguard Worker            sqrt_result = torch.sqrt(x)
6674*da0073e9SAndroid Build Coastguard Worker            sqrt_result_cpu = torch.sqrt(cpu_x)
6675*da0073e9SAndroid Build Coastguard Worker
6676*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(sqrt_result_cpu)
6677*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6678*da0073e9SAndroid Build Coastguard Worker
6679*da0073e9SAndroid Build Coastguard Worker            sqrt_result.backward(gradient=grad)
6680*da0073e9SAndroid Build Coastguard Worker            sqrt_result_cpu.backward(gradient=cpu_grad)
6681*da0073e9SAndroid Build Coastguard Worker
6682*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sqrt_result, sqrt_result_cpu)
6683*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6684*da0073e9SAndroid Build Coastguard Worker
6685*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
6686*da0073e9SAndroid Build Coastguard Worker
6687*da0073e9SAndroid Build Coastguard Worker    # Test selu, elu, celu
6688*da0073e9SAndroid Build Coastguard Worker    def test_elu(self):
6689*da0073e9SAndroid Build Coastguard Worker        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
6690*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
6691*da0073e9SAndroid Build Coastguard Worker            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
6692*da0073e9SAndroid Build Coastguard Worker
6693*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_(True)
6694*da0073e9SAndroid Build Coastguard Worker            for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
6695*da0073e9SAndroid Build Coastguard Worker                elu_result = activation_func(x)
6696*da0073e9SAndroid Build Coastguard Worker                elu_result_cpu = activation_func(cpu_x)
6697*da0073e9SAndroid Build Coastguard Worker
6698*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.randn(elu_result_cpu.shape)
6699*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
6700*da0073e9SAndroid Build Coastguard Worker
6701*da0073e9SAndroid Build Coastguard Worker                elu_result.backward(gradient=grad)
6702*da0073e9SAndroid Build Coastguard Worker                elu_result_cpu.backward(gradient=cpu_grad)
6703*da0073e9SAndroid Build Coastguard Worker
6704*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(elu_result, elu_result_cpu)
6705*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad, cpu_x.grad)
6706*da0073e9SAndroid Build Coastguard Worker
6707*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
6708*da0073e9SAndroid Build Coastguard Worker        for memory_fromat in [torch.channels_last, torch.contiguous_format]:
6709*da0073e9SAndroid Build Coastguard Worker            for shape in [(2, 8, 4, 5)]:
6710*da0073e9SAndroid Build Coastguard Worker                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
6711*da0073e9SAndroid Build Coastguard Worker                    helper(shape, alpha, memory_fromat)
6712*da0073e9SAndroid Build Coastguard Worker
6713*da0073e9SAndroid Build Coastguard Worker    def test_elu_strided_output(self):
6714*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/124834
6715*da0073e9SAndroid Build Coastguard Worker        elu_input = torch.randn(1, 1024, 500)
6716*da0073e9SAndroid Build Coastguard Worker        alpha = float(1)
6717*da0073e9SAndroid Build Coastguard Worker        inplace = False
6718*da0073e9SAndroid Build Coastguard Worker
6719*da0073e9SAndroid Build Coastguard Worker        elu_input_noncontiguous = elu_input.transpose(1, 2)
6720*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6721*da0073e9SAndroid Build Coastguard Worker            F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace),
6722*da0073e9SAndroid Build Coastguard Worker            F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
6723*da0073e9SAndroid Build Coastguard Worker        )
6724*da0073e9SAndroid Build Coastguard Worker
6725*da0073e9SAndroid Build Coastguard Worker    # Test glu
6726*da0073e9SAndroid Build Coastguard Worker    def test_glu(self):
6727*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim=0):
6728*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6729*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6730*da0073e9SAndroid Build Coastguard Worker
6731*da0073e9SAndroid Build Coastguard Worker            for activation_func in [torch.nn.GLU(dim=dim)]:
6732*da0073e9SAndroid Build Coastguard Worker                glu_result = activation_func(x)
6733*da0073e9SAndroid Build Coastguard Worker                glu_result_cpu = activation_func(cpu_x)
6734*da0073e9SAndroid Build Coastguard Worker
6735*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.randn(glu_result_cpu.shape)
6736*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
6737*da0073e9SAndroid Build Coastguard Worker
6738*da0073e9SAndroid Build Coastguard Worker                glu_result.backward(gradient=grad)
6739*da0073e9SAndroid Build Coastguard Worker                glu_result_cpu.backward(gradient=cpu_grad)
6740*da0073e9SAndroid Build Coastguard Worker
6741*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(glu_result, glu_result_cpu)
6742*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad, cpu_x.grad)
6743*da0073e9SAndroid Build Coastguard Worker
6744*da0073e9SAndroid Build Coastguard Worker        for shape in [[4], (2, 4), (2, 8, 4, 6)]:
6745*da0073e9SAndroid Build Coastguard Worker            for dim in range(len(shape)):
6746*da0073e9SAndroid Build Coastguard Worker                helper(shape, dim)
6747*da0073e9SAndroid Build Coastguard Worker
6748*da0073e9SAndroid Build Coastguard Worker    # Test softplus
6749*da0073e9SAndroid Build Coastguard Worker    def test_softplus(self):
6750*da0073e9SAndroid Build Coastguard Worker        def helper(shape, beta, threshold, dtype):
6751*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6752*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6753*da0073e9SAndroid Build Coastguard Worker
6754*da0073e9SAndroid Build Coastguard Worker            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
6755*da0073e9SAndroid Build Coastguard Worker            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)
6756*da0073e9SAndroid Build Coastguard Worker
6757*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(softplus_result.shape)
6758*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6759*da0073e9SAndroid Build Coastguard Worker
6760*da0073e9SAndroid Build Coastguard Worker            softplus_result.backward(gradient=grad)
6761*da0073e9SAndroid Build Coastguard Worker            softplus_result_cpu.backward(gradient=cpu_grad)
6762*da0073e9SAndroid Build Coastguard Worker
6763*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(softplus_result, softplus_result_cpu)
6764*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6765*da0073e9SAndroid Build Coastguard Worker
6766*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
6767*da0073e9SAndroid Build Coastguard Worker        for shape, beta, threshold, dtype in product(
6768*da0073e9SAndroid Build Coastguard Worker            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
6769*da0073e9SAndroid Build Coastguard Worker            [0.5, 1, 2, 3, 4],
6770*da0073e9SAndroid Build Coastguard Worker            [0.5, 20, 30, 40, 50],
6771*da0073e9SAndroid Build Coastguard Worker            [torch.float16, torch.float32]
6772*da0073e9SAndroid Build Coastguard Worker        ):
6773*da0073e9SAndroid Build Coastguard Worker            helper(shape, beta, threshold, dtype)
6774*da0073e9SAndroid Build Coastguard Worker
6775*da0073e9SAndroid Build Coastguard Worker    # Test silu
6776*da0073e9SAndroid Build Coastguard Worker
6777*da0073e9SAndroid Build Coastguard Worker    def test_silu(self):
6778*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
6779*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6780*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6781*da0073e9SAndroid Build Coastguard Worker
6782*da0073e9SAndroid Build Coastguard Worker            silu_result = torch.nn.SiLU()(x)
6783*da0073e9SAndroid Build Coastguard Worker            silu_result_cpu = torch.nn.SiLU()(cpu_x)
6784*da0073e9SAndroid Build Coastguard Worker
6785*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(silu_result_cpu.shape)
6786*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6787*da0073e9SAndroid Build Coastguard Worker
6788*da0073e9SAndroid Build Coastguard Worker            silu_result.backward(gradient=grad)
6789*da0073e9SAndroid Build Coastguard Worker            silu_result_cpu.backward(gradient=cpu_grad)
6790*da0073e9SAndroid Build Coastguard Worker
6791*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(silu_result, silu_result_cpu)
6792*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6793*da0073e9SAndroid Build Coastguard Worker
6794*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
6795*da0073e9SAndroid Build Coastguard Worker        for shape in [[], (2, 3), (2, 8, 4, 5)]:
6796*da0073e9SAndroid Build Coastguard Worker            helper(shape)
6797*da0073e9SAndroid Build Coastguard Worker
6798*da0073e9SAndroid Build Coastguard Worker    def test_cast_mps_to_cpu(self):
6799*da0073e9SAndroid Build Coastguard Worker        def helper(src_dtype, dst_dtype):
6800*da0073e9SAndroid Build Coastguard Worker            input = torch.rand((1, 3, 128, 128), dtype=src_dtype)
6801*da0073e9SAndroid Build Coastguard Worker            input_cast_mps = input.to('mps')
6802*da0073e9SAndroid Build Coastguard Worker            input_cast_cpu = input_cast_mps.to('cpu', dtype=dst_dtype)
6803*da0073e9SAndroid Build Coastguard Worker
6804*da0073e9SAndroid Build Coastguard Worker            # needs to match the initial Tensor
6805*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_cast_cpu, input.to(dtype=dst_dtype))
6806*da0073e9SAndroid Build Coastguard Worker        helper(torch.half, torch.float)
6807*da0073e9SAndroid Build Coastguard Worker        helper(torch.float, torch.half)
6808*da0073e9SAndroid Build Coastguard Worker
6809*da0073e9SAndroid Build Coastguard Worker    def test_cast_mps_to_mps(self):
6810*da0073e9SAndroid Build Coastguard Worker        def helper(src_dtype, dst_dtype):
6811*da0073e9SAndroid Build Coastguard Worker            input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
6812*da0073e9SAndroid Build Coastguard Worker            input_mps = input_cpu.to('mps')
6813*da0073e9SAndroid Build Coastguard Worker            output_mps = input_mps.to(dtype=dst_dtype)
6814*da0073e9SAndroid Build Coastguard Worker            output_cpu = input_cpu.to(dtype=dst_dtype)
6815*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_mps.cpu(), output_cpu)
6816*da0073e9SAndroid Build Coastguard Worker        helper(torch.half, torch.float)
6817*da0073e9SAndroid Build Coastguard Worker        helper(torch.float, torch.half)
6818*da0073e9SAndroid Build Coastguard Worker        helper(torch.half, torch.long)
6819*da0073e9SAndroid Build Coastguard Worker        helper(torch.float, torch.int)
6820*da0073e9SAndroid Build Coastguard Worker
6821*da0073e9SAndroid Build Coastguard Worker    def test_avg_pool2d_count_include_pad(self):
6822*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
6823*da0073e9SAndroid Build Coastguard Worker        x = cpu_x.detach().clone().to('mps').requires_grad_()
6824*da0073e9SAndroid Build Coastguard Worker        pool = torch.nn.AvgPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), ceil_mode=True, count_include_pad=True)
6825*da0073e9SAndroid Build Coastguard Worker        ref_y = pool(cpu_x)
6826*da0073e9SAndroid Build Coastguard Worker        y = pool(x)
6827*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, ref_y)
6828*da0073e9SAndroid Build Coastguard Worker        cpu_grad = torch.randn(ref_y.shape)
6829*da0073e9SAndroid Build Coastguard Worker        grad = cpu_grad.to('mps')
6830*da0073e9SAndroid Build Coastguard Worker        ref_y.backward(gradient=cpu_grad)
6831*da0073e9SAndroid Build Coastguard Worker        y.backward(gradient=grad)
6832*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, cpu_x.grad)
6833*da0073e9SAndroid Build Coastguard Worker
6834*da0073e9SAndroid Build Coastguard Worker    # Test adaptive avg pool2d - when the input size is a multiple of output size
6835*da0073e9SAndroid Build Coastguard Worker    # Not testing for channels last right now
6836*da0073e9SAndroid Build Coastguard Worker    def test_adaptive_avg_pool2d_simple(self):
6837*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, out_shape, channels_last):
6838*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
6839*da0073e9SAndroid Build Coastguard Worker            if (channels_last):
6840*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.to(memory_format=torch.channels_last)
6841*da0073e9SAndroid Build Coastguard Worker                cpu_x.retain_grad()
6842*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6843*da0073e9SAndroid Build Coastguard Worker
6844*da0073e9SAndroid Build Coastguard Worker            avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
6845*da0073e9SAndroid Build Coastguard Worker            avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)
6846*da0073e9SAndroid Build Coastguard Worker
6847*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(avg_result_cpu.shape)
6848*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6849*da0073e9SAndroid Build Coastguard Worker
6850*da0073e9SAndroid Build Coastguard Worker            avg_result.backward(gradient=grad)
6851*da0073e9SAndroid Build Coastguard Worker            avg_result_cpu.backward(gradient=cpu_grad)
6852*da0073e9SAndroid Build Coastguard Worker
6853*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(avg_result, avg_result_cpu)
6854*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6855*da0073e9SAndroid Build Coastguard Worker
6856*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 4, 4), (2, 2), False)
6857*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 9, 9), (3, 3), False)
6858*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 9, 9), (9, 9), False)
6859*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 16, 16), (2, 2), False)
6860*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 16, 16), (2, 16), False)
6861*da0073e9SAndroid Build Coastguard Worker
6862*da0073e9SAndroid Build Coastguard Worker        helper((2, 16, 16), (4, 4), False)
6863*da0073e9SAndroid Build Coastguard Worker
6864*da0073e9SAndroid Build Coastguard Worker        # Output shape larger than input shape
6865*da0073e9SAndroid Build Coastguard Worker
6866*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 4, 4), (8, 8), False)
6867*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 2, 2), (4, 4), False)
6868*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 3, 3), (9, 9), False)
6869*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 2, 2), (16, 16), False)
6870*da0073e9SAndroid Build Coastguard Worker        helper((2, 2, 2, 16), (16, 16), False)
6871*da0073e9SAndroid Build Coastguard Worker
6872*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 4), (16, 16), False)
6873*da0073e9SAndroid Build Coastguard Worker
6874*da0073e9SAndroid Build Coastguard Worker        try:
6875*da0073e9SAndroid Build Coastguard Worker            helper((2, 2, 3, 3), (7, 7), False)
6876*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
6877*da0073e9SAndroid Build Coastguard Worker            pass
6878*da0073e9SAndroid Build Coastguard Worker
6879*da0073e9SAndroid Build Coastguard Worker    # Test max avg pool2d - when the input size is a multiple of output size
6880*da0073e9SAndroid Build Coastguard Worker    # Not testing for channels last right now
6881*da0073e9SAndroid Build Coastguard Worker    def test_adaptive_max_pool2d_simple(self):
6882*da0073e9SAndroid Build Coastguard Worker        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
6883*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
6884*da0073e9SAndroid Build Coastguard Worker            if (dtype in [torch.float16, torch.float32]):
6885*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
6886*da0073e9SAndroid Build Coastguard Worker            else:
6887*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
6888*da0073e9SAndroid Build Coastguard Worker            if (channels_last):
6889*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.to(memory_format=torch.channels_last)
6890*da0073e9SAndroid Build Coastguard Worker                cpu_x.retain_grad()
6891*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
6892*da0073e9SAndroid Build Coastguard Worker
6893*da0073e9SAndroid Build Coastguard Worker            max_result, max_indices = None, None
6894*da0073e9SAndroid Build Coastguard Worker            max_result_cpu, max_indices_cpu = None, None
6895*da0073e9SAndroid Build Coastguard Worker
6896*da0073e9SAndroid Build Coastguard Worker            if (return_indices):
6897*da0073e9SAndroid Build Coastguard Worker                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
6898*da0073e9SAndroid Build Coastguard Worker                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
6899*da0073e9SAndroid Build Coastguard Worker            else:
6900*da0073e9SAndroid Build Coastguard Worker                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
6901*da0073e9SAndroid Build Coastguard Worker                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
6902*da0073e9SAndroid Build Coastguard Worker
6903*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(max_result_cpu.shape)
6904*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6905*da0073e9SAndroid Build Coastguard Worker
6906*da0073e9SAndroid Build Coastguard Worker            max_result.backward(gradient=grad)
6907*da0073e9SAndroid Build Coastguard Worker            max_result_cpu.backward(gradient=cpu_grad)
6908*da0073e9SAndroid Build Coastguard Worker
6909*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(max_result, max_result_cpu)
6910*da0073e9SAndroid Build Coastguard Worker            if (return_indices):
6911*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(max_indices, max_indices_cpu)
6912*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
6913*da0073e9SAndroid Build Coastguard Worker
6914*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float32]:
6915*da0073e9SAndroid Build Coastguard Worker            for return_indices in [False, True]:
6916*da0073e9SAndroid Build Coastguard Worker                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
6917*da0073e9SAndroid Build Coastguard Worker                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
6918*da0073e9SAndroid Build Coastguard Worker                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
6919*da0073e9SAndroid Build Coastguard Worker                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
6920*da0073e9SAndroid Build Coastguard Worker                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
6921*da0073e9SAndroid Build Coastguard Worker                helper((2, 16, 16), (4, 4), return_indices, dtype)
6922*da0073e9SAndroid Build Coastguard Worker
6923*da0073e9SAndroid Build Coastguard Worker    def test_gelu_simple(self):
6924*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype=torch.float, contiguous=True):
6925*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
6926*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6927*da0073e9SAndroid Build Coastguard Worker
6928*da0073e9SAndroid Build Coastguard Worker            if not contiguous and (0 not in shape and len(shape) >= 2):
6929*da0073e9SAndroid Build Coastguard Worker                # Tranposing will make the tensor non-contiguous
6930*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.transpose(0, 1)
6931*da0073e9SAndroid Build Coastguard Worker                x = x.transpose(0, 1)
6932*da0073e9SAndroid Build Coastguard Worker                assert not x.is_contiguous()
6933*da0073e9SAndroid Build Coastguard Worker
6934*da0073e9SAndroid Build Coastguard Worker            cpu_x.requires_grad_()
6935*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_()
6936*da0073e9SAndroid Build Coastguard Worker
6937*da0073e9SAndroid Build Coastguard Worker            gelu_result = torch.nn.GELU()(x)
6938*da0073e9SAndroid Build Coastguard Worker            # GELU is not supported on CPU, so cast it to float
6939*da0073e9SAndroid Build Coastguard Worker            gelu_result_cpu = torch.nn.GELU()(cpu_x.to(torch.float))
6940*da0073e9SAndroid Build Coastguard Worker
6941*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(gelu_result_cpu)
6942*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6943*da0073e9SAndroid Build Coastguard Worker
6944*da0073e9SAndroid Build Coastguard Worker            gelu_result.backward(gradient=grad)
6945*da0073e9SAndroid Build Coastguard Worker            gelu_result_cpu.backward(gradient=cpu_grad)
6946*da0073e9SAndroid Build Coastguard Worker
6947*da0073e9SAndroid Build Coastguard Worker            atol = 1e-5 if dtype == torch.float else 1e-2
6948*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-3 if dtype == torch.float else 1e-2
6949*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)
6950*da0073e9SAndroid Build Coastguard Worker
6951*da0073e9SAndroid Build Coastguard Worker            assert x.grad is not None  # Check that the grad is well-populated
6952*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
6953*da0073e9SAndroid Build Coastguard Worker
6954*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
6955*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.half]:
6956*da0073e9SAndroid Build Coastguard Worker            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
6957*da0073e9SAndroid Build Coastguard Worker                for contiguous in [True, False]:
6958*da0073e9SAndroid Build Coastguard Worker                    helper(shape, dtype, contiguous)
6959*da0073e9SAndroid Build Coastguard Worker        # Test that gelu would raise an assert for integral types
6960*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
6961*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
6962*da0073e9SAndroid Build Coastguard Worker
6963*da0073e9SAndroid Build Coastguard Worker    def test_mish_simple(self):
6964*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype=torch.float, contiguous=True):
6965*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
6966*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
6967*da0073e9SAndroid Build Coastguard Worker
6968*da0073e9SAndroid Build Coastguard Worker            if not contiguous and (0 not in shape and len(shape) >= 2):
6969*da0073e9SAndroid Build Coastguard Worker                # Tranposing will make the tensor non-contiguous
6970*da0073e9SAndroid Build Coastguard Worker                cpu_x = cpu_x.transpose(0, 1)
6971*da0073e9SAndroid Build Coastguard Worker                x = x.transpose(0, 1)
6972*da0073e9SAndroid Build Coastguard Worker                assert not x.is_contiguous()
6973*da0073e9SAndroid Build Coastguard Worker
6974*da0073e9SAndroid Build Coastguard Worker            cpu_x.requires_grad_()
6975*da0073e9SAndroid Build Coastguard Worker            x.requires_grad_()
6976*da0073e9SAndroid Build Coastguard Worker
6977*da0073e9SAndroid Build Coastguard Worker            mish_result = torch.nn.Mish()(x)
6978*da0073e9SAndroid Build Coastguard Worker            mish_result_cpu = torch.nn.Mish()(cpu_x)
6979*da0073e9SAndroid Build Coastguard Worker
6980*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(mish_result_cpu)
6981*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
6982*da0073e9SAndroid Build Coastguard Worker
6983*da0073e9SAndroid Build Coastguard Worker            mish_result.backward(gradient=grad)
6984*da0073e9SAndroid Build Coastguard Worker            mish_result_cpu.backward(gradient=cpu_grad)
6985*da0073e9SAndroid Build Coastguard Worker
6986*da0073e9SAndroid Build Coastguard Worker            atol = 1e-5 if dtype == torch.float else 1e-2
6987*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-3 if dtype == torch.float else 1e-2
6988*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)
6989*da0073e9SAndroid Build Coastguard Worker
6990*da0073e9SAndroid Build Coastguard Worker            assert x.grad is not None  # Check that the grad is well-populated
6991*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
6992*da0073e9SAndroid Build Coastguard Worker
6993*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
6994*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.half]:
6995*da0073e9SAndroid Build Coastguard Worker            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
6996*da0073e9SAndroid Build Coastguard Worker                for contiguous in [True, False]:
6997*da0073e9SAndroid Build Coastguard Worker                    helper(shape, dtype, contiguous)
6998*da0073e9SAndroid Build Coastguard Worker
6999*da0073e9SAndroid Build Coastguard Worker    def test_gelu(self):
7000*da0073e9SAndroid Build Coastguard Worker        def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
7001*da0073e9SAndroid Build Coastguard Worker            numpy_dtype = {
7002*da0073e9SAndroid Build Coastguard Worker                torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
7003*da0073e9SAndroid Build Coastguard Worker            }[dtype]
7004*da0073e9SAndroid Build Coastguard Worker            devices = ['cpu']
7005*da0073e9SAndroid Build Coastguard Worker            devices += ['mps']
7006*da0073e9SAndroid Build Coastguard Worker
7007*da0073e9SAndroid Build Coastguard Worker            def _gelu_ref(X):
7008*da0073e9SAndroid Build Coastguard Worker                return X * stats.norm.cdf(X)  # noqa: F821
7009*da0073e9SAndroid Build Coastguard Worker
7010*da0073e9SAndroid Build Coastguard Worker            for d in devices:
7011*da0073e9SAndroid Build Coastguard Worker                X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
7012*da0073e9SAndroid Build Coastguard Worker                res = X
7013*da0073e9SAndroid Build Coastguard Worker                ref = (X.to(numpy_dtype).cpu().detach().numpy())
7014*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
7015*da0073e9SAndroid Build Coastguard Worker
7016*da0073e9SAndroid Build Coastguard Worker        for n in [1, 5, 10]:
7017*da0073e9SAndroid Build Coastguard Worker            for m in [1, 5, 10]:
7018*da0073e9SAndroid Build Coastguard Worker                _test_gelu(n, m, torch.float32, True)
7019*da0073e9SAndroid Build Coastguard Worker                _test_gelu(n, m, torch.float32, False)
7020*da0073e9SAndroid Build Coastguard Worker
7021*da0073e9SAndroid Build Coastguard Worker        # Test multi threaded
7022*da0073e9SAndroid Build Coastguard Worker        num_threads = torch.get_num_threads()
7023*da0073e9SAndroid Build Coastguard Worker        torch.set_num_threads(4)
7024*da0073e9SAndroid Build Coastguard Worker        try:
7025*da0073e9SAndroid Build Coastguard Worker            _test_gelu(32, 32, torch.float32, False)
7026*da0073e9SAndroid Build Coastguard Worker        finally:
7027*da0073e9SAndroid Build Coastguard Worker            torch.set_num_threads(num_threads)
7028*da0073e9SAndroid Build Coastguard Worker
7029*da0073e9SAndroid Build Coastguard Worker    def test_gelu_tanh(self):
7030*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
7031*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
7032*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
7033*da0073e9SAndroid Build Coastguard Worker
7034*da0073e9SAndroid Build Coastguard Worker            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
7035*da0073e9SAndroid Build Coastguard Worker            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
7036*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
7037*da0073e9SAndroid Build Coastguard Worker
7038*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
7039*da0073e9SAndroid Build Coastguard Worker
7040*da0073e9SAndroid Build Coastguard Worker    # Test hardtanh
7041*da0073e9SAndroid Build Coastguard Worker    def test_hardtanh(self):
7042*da0073e9SAndroid Build Coastguard Worker        def helper(shape, min_val, max_val, inplace=False):
7043*da0073e9SAndroid Build Coastguard Worker            cpu_x = None
7044*da0073e9SAndroid Build Coastguard Worker            x = None
7045*da0073e9SAndroid Build Coastguard Worker
7046*da0073e9SAndroid Build Coastguard Worker            if (not inplace):
7047*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7048*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps').requires_grad_()
7049*da0073e9SAndroid Build Coastguard Worker            else:
7050*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7051*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
7052*da0073e9SAndroid Build Coastguard Worker
7053*da0073e9SAndroid Build Coastguard Worker            hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
7054*da0073e9SAndroid Build Coastguard Worker            hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)
7055*da0073e9SAndroid Build Coastguard Worker
7056*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hardtanh_result, hardtanh_result_cpu)
7057*da0073e9SAndroid Build Coastguard Worker
7058*da0073e9SAndroid Build Coastguard Worker            if (not inplace):
7059*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.randn(hardtanh_result_cpu.shape)
7060*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
7061*da0073e9SAndroid Build Coastguard Worker                hardtanh_result.backward(gradient=grad)
7062*da0073e9SAndroid Build Coastguard Worker                hardtanh_result_cpu.backward(gradient=cpu_grad)
7063*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad, cpu_x.grad)
7064*da0073e9SAndroid Build Coastguard Worker
7065*da0073e9SAndroid Build Coastguard Worker        # Test empty shape too
7066*da0073e9SAndroid Build Coastguard Worker        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
7067*da0073e9SAndroid Build Coastguard Worker            for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]):
7068*da0073e9SAndroid Build Coastguard Worker                helper(shape, min_val, max_val)
7069*da0073e9SAndroid Build Coastguard Worker                helper(shape, min_val, max_val, inplace=True)
7070*da0073e9SAndroid Build Coastguard Worker
7071*da0073e9SAndroid Build Coastguard Worker    def test_hardswish(self):
7072*da0073e9SAndroid Build Coastguard Worker        def helper(shape, inplace=False, requires_grad=True):
7073*da0073e9SAndroid Build Coastguard Worker            m = nn.Hardswish(inplace=inplace)
7074*da0073e9SAndroid Build Coastguard Worker
7075*da0073e9SAndroid Build Coastguard Worker            input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
7076*da0073e9SAndroid Build Coastguard Worker            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
7077*da0073e9SAndroid Build Coastguard Worker
7078*da0073e9SAndroid Build Coastguard Worker            if inplace and requires_grad:  # check that both raise runtime error
7079*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: m(input_cpu))
7080*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: m(input_mps))
7081*da0073e9SAndroid Build Coastguard Worker                return
7082*da0073e9SAndroid Build Coastguard Worker
7083*da0073e9SAndroid Build Coastguard Worker            output_cpu = m(input_cpu)
7084*da0073e9SAndroid Build Coastguard Worker            output_mps = m(input_mps)
7085*da0073e9SAndroid Build Coastguard Worker
7086*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(output_cpu)
7087*da0073e9SAndroid Build Coastguard Worker            mps_grad = cpu_grad.to('mps')
7088*da0073e9SAndroid Build Coastguard Worker
7089*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cpu, output_mps)
7090*da0073e9SAndroid Build Coastguard Worker
7091*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
7092*da0073e9SAndroid Build Coastguard Worker                output_cpu.backward(gradient=cpu_grad)
7093*da0073e9SAndroid Build Coastguard Worker                output_mps.backward(gradient=mps_grad)
7094*da0073e9SAndroid Build Coastguard Worker
7095*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(input_cpu.grad, input_mps.grad)
7096*da0073e9SAndroid Build Coastguard Worker
7097*da0073e9SAndroid Build Coastguard Worker        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
7098*da0073e9SAndroid Build Coastguard Worker            helper(shape, inplace=False, requires_grad=False)
7099*da0073e9SAndroid Build Coastguard Worker            helper(shape, inplace=True, requires_grad=False)
7100*da0073e9SAndroid Build Coastguard Worker            helper(shape, inplace=False, requires_grad=True)
7101*da0073e9SAndroid Build Coastguard Worker            helper(shape, inplace=True, requires_grad=True)
7102*da0073e9SAndroid Build Coastguard Worker
7103*da0073e9SAndroid Build Coastguard Worker    def test_transpose_2D(self):
7104*da0073e9SAndroid Build Coastguard Worker        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
7105*da0073e9SAndroid Build Coastguard Worker        values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
7106*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
7107*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
7108*da0073e9SAndroid Build Coastguard Worker        mps_x1 = torch.tensor(values1, device='mps')
7109*da0073e9SAndroid Build Coastguard Worker
7110*da0073e9SAndroid Build Coastguard Worker        cpu_transpose = torch.transpose(cpu_x, 0, 1)
7111*da0073e9SAndroid Build Coastguard Worker        mps_transpose = torch.transpose(mps_x, 0, 1)
7112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
7113*da0073e9SAndroid Build Coastguard Worker
7114*da0073e9SAndroid Build Coastguard Worker    def test_transpose_3D(self):
7115*da0073e9SAndroid Build Coastguard Worker        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
7116*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
7117*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
7118*da0073e9SAndroid Build Coastguard Worker
7119*da0073e9SAndroid Build Coastguard Worker        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
7120*da0073e9SAndroid Build Coastguard Worker        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
7121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose1, mps_transpose1)
7122*da0073e9SAndroid Build Coastguard Worker
7123*da0073e9SAndroid Build Coastguard Worker        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
7124*da0073e9SAndroid Build Coastguard Worker        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
7125*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose2, mps_transpose2)
7126*da0073e9SAndroid Build Coastguard Worker
7127*da0073e9SAndroid Build Coastguard Worker        cpu_transpose3 = torch.transpose(cpu_x, 1, 2)
7128*da0073e9SAndroid Build Coastguard Worker        mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu')
7129*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose3, mps_transpose3)
7130*da0073e9SAndroid Build Coastguard Worker
7131*da0073e9SAndroid Build Coastguard Worker
7132*da0073e9SAndroid Build Coastguard Worker    def test_transpose_4D(self):
7133*da0073e9SAndroid Build Coastguard Worker        values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
7134*da0073e9SAndroid Build Coastguard Worker                  [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
7135*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(values, device='cpu')
7136*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.tensor(values, device='mps')
7137*da0073e9SAndroid Build Coastguard Worker
7138*da0073e9SAndroid Build Coastguard Worker        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
7139*da0073e9SAndroid Build Coastguard Worker        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
7140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose1, mps_transpose1)
7141*da0073e9SAndroid Build Coastguard Worker
7142*da0073e9SAndroid Build Coastguard Worker        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
7143*da0073e9SAndroid Build Coastguard Worker        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
7144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose2, mps_transpose2)
7145*da0073e9SAndroid Build Coastguard Worker
7146*da0073e9SAndroid Build Coastguard Worker        cpu_transpose3 = torch.transpose(cpu_x, 0, 3)
7147*da0073e9SAndroid Build Coastguard Worker        mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu')
7148*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose3, mps_transpose3)
7149*da0073e9SAndroid Build Coastguard Worker
7150*da0073e9SAndroid Build Coastguard Worker        cpu_transpose4 = torch.transpose(cpu_x, 3, 1)
7151*da0073e9SAndroid Build Coastguard Worker        mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu')
7152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose4, mps_transpose4)
7153*da0073e9SAndroid Build Coastguard Worker
7154*da0073e9SAndroid Build Coastguard Worker        cpu_transpose5 = torch.transpose(cpu_x, 3, 2)
7155*da0073e9SAndroid Build Coastguard Worker        mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu')
7156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose5, mps_transpose5)
7157*da0073e9SAndroid Build Coastguard Worker
7158*da0073e9SAndroid Build Coastguard Worker        cpu_transpose6 = torch.transpose(cpu_x, 1, 2)
7159*da0073e9SAndroid Build Coastguard Worker        mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu')
7160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_transpose6, mps_transpose6)
7161*da0073e9SAndroid Build Coastguard Worker
7162*da0073e9SAndroid Build Coastguard Worker    # Test sign
7163*da0073e9SAndroid Build Coastguard Worker    def test_sign(self):
7164*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
7165*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7166*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7167*da0073e9SAndroid Build Coastguard Worker
7168*da0073e9SAndroid Build Coastguard Worker            sign_result = torch.sign(x)
7169*da0073e9SAndroid Build Coastguard Worker            sign_result_cpu = torch.sign(cpu_x)
7170*da0073e9SAndroid Build Coastguard Worker
7171*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(sign_result_cpu)
7172*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7173*da0073e9SAndroid Build Coastguard Worker
7174*da0073e9SAndroid Build Coastguard Worker            sign_result.backward(gradient=grad)
7175*da0073e9SAndroid Build Coastguard Worker            sign_result_cpu.backward(gradient=cpu_grad)
7176*da0073e9SAndroid Build Coastguard Worker
7177*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sign_result, sign_result_cpu)
7178*da0073e9SAndroid Build Coastguard Worker
7179*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
7180*da0073e9SAndroid Build Coastguard Worker
7181*da0073e9SAndroid Build Coastguard Worker    def test_signbit(self):
7182*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dtype):
7183*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu').to(dtype)
7184*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.clone().to('mps')
7185*da0073e9SAndroid Build Coastguard Worker
7186*da0073e9SAndroid Build Coastguard Worker            signbit_result = torch.signbit(x)
7187*da0073e9SAndroid Build Coastguard Worker            signbit_result_cpu = torch.signbit(cpu_x)
7188*da0073e9SAndroid Build Coastguard Worker
7189*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(signbit_result, signbit_result_cpu)
7190*da0073e9SAndroid Build Coastguard Worker
7191*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), torch.int)
7192*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), torch.float)
7193*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), torch.int64)
7194*da0073e9SAndroid Build Coastguard Worker
7195*da0073e9SAndroid Build Coastguard Worker    # Test neg
7196*da0073e9SAndroid Build Coastguard Worker    def test_neg(self):
7197*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
7198*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7199*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7200*da0073e9SAndroid Build Coastguard Worker
7201*da0073e9SAndroid Build Coastguard Worker            neg_result = torch.neg(x)
7202*da0073e9SAndroid Build Coastguard Worker            neg_result_cpu = torch.neg(cpu_x)
7203*da0073e9SAndroid Build Coastguard Worker
7204*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.ones_like(neg_result_cpu)
7205*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7206*da0073e9SAndroid Build Coastguard Worker
7207*da0073e9SAndroid Build Coastguard Worker            neg_result.backward(gradient=grad)
7208*da0073e9SAndroid Build Coastguard Worker            neg_result_cpu.backward(gradient=cpu_grad)
7209*da0073e9SAndroid Build Coastguard Worker
7210*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(neg_result, neg_result_cpu)
7211*da0073e9SAndroid Build Coastguard Worker
7212*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
7213*da0073e9SAndroid Build Coastguard Worker
7214*da0073e9SAndroid Build Coastguard Worker    def test_neg_strided_input(self):
7215*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
7216*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
7217*da0073e9SAndroid Build Coastguard Worker        y = x.permute(1, 0, 2)[..., 1]
7218*da0073e9SAndroid Build Coastguard Worker        z = y + y.neg()
7219*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.abs().max().item(), 0.0)
7220*da0073e9SAndroid Build Coastguard Worker
7221*da0073e9SAndroid Build Coastguard Worker    # Test index add
7222*da0073e9SAndroid Build Coastguard Worker    def test_index_add(self):
7223*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
7224*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
7225*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
7226*da0073e9SAndroid Build Coastguard Worker
7227*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7228*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7229*da0073e9SAndroid Build Coastguard Worker
7230*da0073e9SAndroid Build Coastguard Worker            cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
7231*da0073e9SAndroid Build Coastguard Worker            source = cpu_source.detach().clone().to('mps')
7232*da0073e9SAndroid Build Coastguard Worker
7233*da0073e9SAndroid Build Coastguard Worker            idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
7234*da0073e9SAndroid Build Coastguard Worker            idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
7235*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_result, idx_result_cpu)
7236*da0073e9SAndroid Build Coastguard Worker
7237*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5)
7238*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0)
7239*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5)
7240*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0)
7241*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4)
7242*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0)
7243*da0073e9SAndroid Build Coastguard Worker        # test result dim=1
7244*da0073e9SAndroid Build Coastguard Worker        helper((2,), 0, [1], (1,), 6.0)
7245*da0073e9SAndroid Build Coastguard Worker        helper(2, 0, 1, 1, 6)
7246*da0073e9SAndroid Build Coastguard Worker        # test float16
7247*da0073e9SAndroid Build Coastguard Worker        helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
7248*da0073e9SAndroid Build Coastguard Worker
7249*da0073e9SAndroid Build Coastguard Worker    def test_index_64bit(self):
7250*da0073e9SAndroid Build Coastguard Worker        """ Test that index operations work for 4Gb+ tensors """
7251*da0073e9SAndroid Build Coastguard Worker        if product_version < 14.0:
7252*da0073e9SAndroid Build Coastguard Worker            raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
7253*da0073e9SAndroid Build Coastguard Worker        # Cleanup memory
7254*da0073e9SAndroid Build Coastguard Worker        gc.collect()
7255*da0073e9SAndroid Build Coastguard Worker        torch.mps.empty_cache()
7256*da0073e9SAndroid Build Coastguard Worker        # Check that index operations work for 4+GB tensors
7257*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(16000, 67120, device="mps")
7258*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(x.element_size() * x.numel(), 2**32)
7259*da0073e9SAndroid Build Coastguard Worker        idx = torch.arange(0, 2, device="mps")
7260*da0073e9SAndroid Build Coastguard Worker        x_sampled = x[:, idx]
7261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[:, 0], x_sampled[:, 0])
7262*da0073e9SAndroid Build Coastguard Worker        # Reclaim memory after running the tests
7263*da0073e9SAndroid Build Coastguard Worker        del x
7264*da0073e9SAndroid Build Coastguard Worker        gc.collect()
7265*da0073e9SAndroid Build Coastguard Worker        torch.mps.empty_cache()
7266*da0073e9SAndroid Build Coastguard Worker
7267*da0073e9SAndroid Build Coastguard Worker    def test_mm_large(self):
7268*da0073e9SAndroid Build Coastguard Worker        """ Test that MM works for matrices with index larger than 32K """
7269*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, 1, device="mps")
7270*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(1, 32769, device="mps")
7271*da0073e9SAndroid Build Coastguard Worker        # This used to crash with:
7272*da0073e9SAndroid Build Coastguard Worker        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
7273*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
7274*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
7275*da0073e9SAndroid Build Coastguard Worker
7276*da0073e9SAndroid Build Coastguard Worker        def compare_mm(m, n, k, dtype=torch.float):
7277*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(m, n, device="mps", dtype=dtype)
7278*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(n, k, device="mps", dtype=dtype)
7279*da0073e9SAndroid Build Coastguard Worker            z = torch.mm(x, y).cpu()
7280*da0073e9SAndroid Build Coastguard Worker            z_cpu = torch.mm(x.cpu(), y.cpu())
7281*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z, z_cpu)
7282*da0073e9SAndroid Build Coastguard Worker
7283*da0073e9SAndroid Build Coastguard Worker        # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
7284*da0073e9SAndroid Build Coastguard Worker        compare_mm(1024, 1, 32769)
7285*da0073e9SAndroid Build Coastguard Worker        # one more time, but with dimensions inverted
7286*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
7287*da0073e9SAndroid Build Coastguard Worker        compare_mm(32769, 1, 1025)
7288*da0073e9SAndroid Build Coastguard Worker
7289*da0073e9SAndroid Build Coastguard Worker        if product_version >= 14.0:
7290*da0073e9SAndroid Build Coastguard Worker            # Test bfloat16 mm
7291*da0073e9SAndroid Build Coastguard Worker            compare_mm(1024, 1, 32769, torch.bfloat16)
7292*da0073e9SAndroid Build Coastguard Worker
7293*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
7294*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
7295*da0073e9SAndroid Build Coastguard Worker    def test_copy_large(self):
7296*da0073e9SAndroid Build Coastguard Worker        """ Test that copy of 4Gb+ tensors works """
7297*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((2**30 + 11,), dtype=torch.float32)
7298*da0073e9SAndroid Build Coastguard Worker        y = x.to(device="mps")
7299*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps")))
7300*da0073e9SAndroid Build Coastguard Worker        del y
7301*da0073e9SAndroid Build Coastguard Worker        del x
7302*da0073e9SAndroid Build Coastguard Worker
7303*da0073e9SAndroid Build Coastguard Worker    # Test flip
7304*da0073e9SAndroid Build Coastguard Worker    def test_flip(self):
7305*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dims):
7306*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7307*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
7308*da0073e9SAndroid Build Coastguard Worker
7309*da0073e9SAndroid Build Coastguard Worker            flip_result = torch.flip(x, dims=dims)
7310*da0073e9SAndroid Build Coastguard Worker            flip_result_cpu = torch.flip(cpu_x, dims=dims)
7311*da0073e9SAndroid Build Coastguard Worker
7312*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(flip_result, flip_result_cpu)
7313*da0073e9SAndroid Build Coastguard Worker
7314*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), [0])
7315*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), [0, 1])
7316*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), (0, 1, 2, 3))
7317*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 3), (-1,))
7318*da0073e9SAndroid Build Coastguard Worker        # empty dims
7319*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), [])
7320*da0073e9SAndroid Build Coastguard Worker        # input.numel() == 1
7321*da0073e9SAndroid Build Coastguard Worker        helper((1,), (0,))
7322*da0073e9SAndroid Build Coastguard Worker        # input.numel() == 0
7323*da0073e9SAndroid Build Coastguard Worker        helper((0,), (0,))
7324*da0073e9SAndroid Build Coastguard Worker        # none of dims that needs to be flipped
7325*da0073e9SAndroid Build Coastguard Worker        helper((1, 3), [0])
7326*da0073e9SAndroid Build Coastguard Worker
7327*da0073e9SAndroid Build Coastguard Worker    # Test index select
7328*da0073e9SAndroid Build Coastguard Worker    def test_index_select(self):
7329*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, index, idx_dtype=torch.int32):
7330*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7331*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
7332*da0073e9SAndroid Build Coastguard Worker
7333*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7334*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7335*da0073e9SAndroid Build Coastguard Worker
7336*da0073e9SAndroid Build Coastguard Worker            idx_result = torch.index_select(x, dim=dim, index=idx)
7337*da0073e9SAndroid Build Coastguard Worker            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
7338*da0073e9SAndroid Build Coastguard Worker
7339*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_result, idx_result_cpu)
7340*da0073e9SAndroid Build Coastguard Worker
7341*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0, [1])
7342*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6])
7343*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6])
7344*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 2, [3, 0, 1])
7345*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 3, [2, 3, 0])
7346*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 3), -1, [1, 2])
7347*da0073e9SAndroid Build Coastguard Worker        helper((), 0, [0])
7348*da0073e9SAndroid Build Coastguard Worker        helper((5), 0, [])
7349*da0073e9SAndroid Build Coastguard Worker
7350*da0073e9SAndroid Build Coastguard Worker    def test_index_select_scalar(self):
7351*da0073e9SAndroid Build Coastguard Worker        def helper(value, dim, index, idx_dtype=torch.int32):
7352*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
7353*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
7354*da0073e9SAndroid Build Coastguard Worker
7355*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7356*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7357*da0073e9SAndroid Build Coastguard Worker
7358*da0073e9SAndroid Build Coastguard Worker            idx_result = torch.index_select(x, dim=dim, index=idx)
7359*da0073e9SAndroid Build Coastguard Worker            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
7360*da0073e9SAndroid Build Coastguard Worker
7361*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(idx_result, idx_result_cpu)
7362*da0073e9SAndroid Build Coastguard Worker
7363*da0073e9SAndroid Build Coastguard Worker        helper(22, 0, [0])
7364*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
7365*da0073e9SAndroid Build Coastguard Worker            helper(22, 0, [])
7366*da0073e9SAndroid Build Coastguard Worker
7367*da0073e9SAndroid Build Coastguard Worker    def test_embedding_dense_backward(self):
7368*da0073e9SAndroid Build Coastguard Worker        def helper(n, d, m, idx):
7369*da0073e9SAndroid Build Coastguard Worker            embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
7370*da0073e9SAndroid Build Coastguard Worker            emedding_weight = embeddingMPS.weight.detach().cpu()
7371*da0073e9SAndroid Build Coastguard Worker            W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
7372*da0073e9SAndroid Build Coastguard Worker            idx_MPS = torch.tensor(idx, device='mps')
7373*da0073e9SAndroid Build Coastguard Worker            a_MPS = embeddingMPS.weight.clone() @ W_MPS.t()  # weight must be cloned for this to be differentiable
7374*da0073e9SAndroid Build Coastguard Worker            a_MPS.retain_grad()
7375*da0073e9SAndroid Build Coastguard Worker            b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t()  # modifies weight in-place
7376*da0073e9SAndroid Build Coastguard Worker            b_MPS.retain_grad()
7377*da0073e9SAndroid Build Coastguard Worker            out_MPS = (a_MPS.unsqueeze(0) + b_MPS)
7378*da0073e9SAndroid Build Coastguard Worker            loss_MPS = out_MPS.sigmoid().prod()
7379*da0073e9SAndroid Build Coastguard Worker            loss_MPS.backward()
7380*da0073e9SAndroid Build Coastguard Worker
7381*da0073e9SAndroid Build Coastguard Worker            embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=emedding_weight)
7382*da0073e9SAndroid Build Coastguard Worker            W_CPU = W_MPS.to('cpu')
7383*da0073e9SAndroid Build Coastguard Worker            idx_CPU = torch.tensor(idx)
7384*da0073e9SAndroid Build Coastguard Worker            a_CPU = embeddingCPU.weight.clone() @ W_CPU.t()  # weight must be cloned for this to be differentiable
7385*da0073e9SAndroid Build Coastguard Worker            a_CPU.retain_grad()
7386*da0073e9SAndroid Build Coastguard Worker            b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t()  # modifies weight in-place
7387*da0073e9SAndroid Build Coastguard Worker            b_CPU.retain_grad()
7388*da0073e9SAndroid Build Coastguard Worker            out_CPU = (a_CPU.unsqueeze(0) + b_CPU)
7389*da0073e9SAndroid Build Coastguard Worker            loss_CPU = out_CPU.sigmoid().prod()
7390*da0073e9SAndroid Build Coastguard Worker            loss_CPU.backward()
7391*da0073e9SAndroid Build Coastguard Worker
7392*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b_CPU.grad, b_MPS.grad)
7393*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_CPU.grad, a_MPS.grad)
7394*da0073e9SAndroid Build Coastguard Worker
7395*da0073e9SAndroid Build Coastguard Worker        helper(3, 5, 7, [0, 1, 2])
7396*da0073e9SAndroid Build Coastguard Worker        helper(3, 6, 7, [0, 1, 2])  # verify if changes in shape would cause cached graph lookup problems
7397*da0073e9SAndroid Build Coastguard Worker        helper(3, 5, 7, 2)  # test scalar index
7398*da0073e9SAndroid Build Coastguard Worker
7399*da0073e9SAndroid Build Coastguard Worker    # Test pytorch gather
7400*da0073e9SAndroid Build Coastguard Worker    def test_gather(self):
7401*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
7402*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7403*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7404*da0073e9SAndroid Build Coastguard Worker
7405*da0073e9SAndroid Build Coastguard Worker            # Indices should be taken from range of axis along which gathering is done
7406*da0073e9SAndroid Build Coastguard Worker            idx_np = np.random.randint(0, shape[dim], idx_shape)
7407*da0073e9SAndroid Build Coastguard Worker
7408*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7409*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7410*da0073e9SAndroid Build Coastguard Worker
7411*da0073e9SAndroid Build Coastguard Worker            gather_result = torch.gather(x, dim=dim, index=idx)
7412*da0073e9SAndroid Build Coastguard Worker            gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)
7413*da0073e9SAndroid Build Coastguard Worker
7414*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
7415*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7416*da0073e9SAndroid Build Coastguard Worker            gather_result.backward(gradient=grad)
7417*da0073e9SAndroid Build Coastguard Worker            gather_result_cpu.backward(gradient=cpu_grad)
7418*da0073e9SAndroid Build Coastguard Worker
7419*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gather_result, gather_result_cpu)
7420*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_x.grad, x.grad)
7421*da0073e9SAndroid Build Coastguard Worker
7422*da0073e9SAndroid Build Coastguard Worker        helper((6, 3, 3), 0, (3, 3, 3))
7423*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 3, 3), 0, (10, 3, 3, 3))
7424*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0, (10, 8, 4, 5))
7425*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0, (10, 6, 3, 2))
7426*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (6, 8, 4, 5))
7427*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (6, 7, 2, 3))
7428*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 1, (2, 5, 3, 4))
7429*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
7430*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 3, (2, 5, 3, 12))
7431*da0073e9SAndroid Build Coastguard Worker
7432*da0073e9SAndroid Build Coastguard Worker    # Test pytorch gather
7433*da0073e9SAndroid Build Coastguard Worker    def test_gather_scalar(self):
7434*da0073e9SAndroid Build Coastguard Worker        idx_dtype = torch.int64
7435*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
7436*da0073e9SAndroid Build Coastguard Worker        x = cpu_x.detach().clone().to('mps').requires_grad_()
7437*da0073e9SAndroid Build Coastguard Worker
7438*da0073e9SAndroid Build Coastguard Worker        idx_np = [0]
7439*da0073e9SAndroid Build Coastguard Worker
7440*da0073e9SAndroid Build Coastguard Worker        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7441*da0073e9SAndroid Build Coastguard Worker        idx = cpu_idx.detach().clone().to('mps')
7442*da0073e9SAndroid Build Coastguard Worker
7443*da0073e9SAndroid Build Coastguard Worker        gather_result = torch.gather(x, dim=0, index=idx)
7444*da0073e9SAndroid Build Coastguard Worker        gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)
7445*da0073e9SAndroid Build Coastguard Worker
7446*da0073e9SAndroid Build Coastguard Worker        cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
7447*da0073e9SAndroid Build Coastguard Worker        grad = cpu_grad.to('mps')
7448*da0073e9SAndroid Build Coastguard Worker        gather_result.backward(gradient=grad)
7449*da0073e9SAndroid Build Coastguard Worker        gather_result_cpu.backward(gradient=cpu_grad)
7450*da0073e9SAndroid Build Coastguard Worker
7451*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gather_result, gather_result_cpu)
7452*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x.grad, x.grad)
7453*da0073e9SAndroid Build Coastguard Worker
7454*da0073e9SAndroid Build Coastguard Worker    # Test pytorch scatter_add and scatter
7455*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add(self):
7456*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
7457*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7458*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7459*da0073e9SAndroid Build Coastguard Worker
7460*da0073e9SAndroid Build Coastguard Worker            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
7461*da0073e9SAndroid Build Coastguard Worker            src = cpu_src.detach().clone().to('mps').requires_grad_()
7462*da0073e9SAndroid Build Coastguard Worker
7463*da0073e9SAndroid Build Coastguard Worker            # Indices should be taken from range of axis along which gathering is done
7464*da0073e9SAndroid Build Coastguard Worker            idx_np = None
7465*da0073e9SAndroid Build Coastguard Worker            if (do_add):
7466*da0073e9SAndroid Build Coastguard Worker                idx_np = np.random.randint(0, shape[dim], idx_shape)
7467*da0073e9SAndroid Build Coastguard Worker            else:
7468*da0073e9SAndroid Build Coastguard Worker                idx_np = np.array([[0, 1, 2],
7469*da0073e9SAndroid Build Coastguard Worker                                   [1, 2, 3],
7470*da0073e9SAndroid Build Coastguard Worker                                   [2, 3, 4],
7471*da0073e9SAndroid Build Coastguard Worker                                   [3, 4, 5],
7472*da0073e9SAndroid Build Coastguard Worker                                   [4, 5, 6]])
7473*da0073e9SAndroid Build Coastguard Worker
7474*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7475*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7476*da0073e9SAndroid Build Coastguard Worker
7477*da0073e9SAndroid Build Coastguard Worker            scatter_result = None
7478*da0073e9SAndroid Build Coastguard Worker            scatter_result_cpu = None
7479*da0073e9SAndroid Build Coastguard Worker
7480*da0073e9SAndroid Build Coastguard Worker            if (do_add):
7481*da0073e9SAndroid Build Coastguard Worker                scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src)
7482*da0073e9SAndroid Build Coastguard Worker                scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
7483*da0073e9SAndroid Build Coastguard Worker            else:
7484*da0073e9SAndroid Build Coastguard Worker                scatter_result = torch.scatter(x, dim=dim, index=idx, src=src)
7485*da0073e9SAndroid Build Coastguard Worker                scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
7486*da0073e9SAndroid Build Coastguard Worker
7487*da0073e9SAndroid Build Coastguard Worker            cpu_grad = None
7488*da0073e9SAndroid Build Coastguard Worker            grad = None
7489*da0073e9SAndroid Build Coastguard Worker
7490*da0073e9SAndroid Build Coastguard Worker            if (idx_shape == src_shape):
7491*da0073e9SAndroid Build Coastguard Worker                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
7492*da0073e9SAndroid Build Coastguard Worker                grad = cpu_grad.to('mps')
7493*da0073e9SAndroid Build Coastguard Worker                scatter_result.backward(gradient=grad)
7494*da0073e9SAndroid Build Coastguard Worker                scatter_result_cpu.backward(gradient=cpu_grad)
7495*da0073e9SAndroid Build Coastguard Worker
7496*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scatter_result, scatter_result_cpu)
7497*da0073e9SAndroid Build Coastguard Worker            if (idx_shape == src_shape):
7498*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_x.grad, x.grad)
7499*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_src.grad, src.grad)
7500*da0073e9SAndroid Build Coastguard Worker
7501*da0073e9SAndroid Build Coastguard Worker        helper((2, 3), 0, (5, 3), (5, 3))
7502*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
7503*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
7504*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2))
7505*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2))
7506*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5))
7507*da0073e9SAndroid Build Coastguard Worker
7508*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5))
7509*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2))
7510*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3))
7511*da0073e9SAndroid Build Coastguard Worker        helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3))
7512*da0073e9SAndroid Build Coastguard Worker
7513*da0073e9SAndroid Build Coastguard Worker        helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8))
7514*da0073e9SAndroid Build Coastguard Worker        helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6))
7515*da0073e9SAndroid Build Coastguard Worker        helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6))
7516*da0073e9SAndroid Build Coastguard Worker
7517*da0073e9SAndroid Build Coastguard Worker        # Test scatter src
7518*da0073e9SAndroid Build Coastguard Worker        helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
7519*da0073e9SAndroid Build Coastguard Worker        helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
7520*da0073e9SAndroid Build Coastguard Worker
7521*da0073e9SAndroid Build Coastguard Worker    # Test pytorch scatter_add and scatter for scalar input
7522*da0073e9SAndroid Build Coastguard Worker    def test_scatter_add_scalar(self):
7523*da0073e9SAndroid Build Coastguard Worker        def helper(idx_dtype=torch.int64, do_add=True):
7524*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
7525*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7526*da0073e9SAndroid Build Coastguard Worker
7527*da0073e9SAndroid Build Coastguard Worker            cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
7528*da0073e9SAndroid Build Coastguard Worker            src = cpu_src.detach().clone().to('mps').requires_grad_()
7529*da0073e9SAndroid Build Coastguard Worker
7530*da0073e9SAndroid Build Coastguard Worker            # Indices should be taken from range of axis along which gathering is done
7531*da0073e9SAndroid Build Coastguard Worker            idx_np = [0]
7532*da0073e9SAndroid Build Coastguard Worker
7533*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7534*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7535*da0073e9SAndroid Build Coastguard Worker
7536*da0073e9SAndroid Build Coastguard Worker            scatter_result = None
7537*da0073e9SAndroid Build Coastguard Worker            scatter_result_cpu = None
7538*da0073e9SAndroid Build Coastguard Worker
7539*da0073e9SAndroid Build Coastguard Worker            if (do_add):
7540*da0073e9SAndroid Build Coastguard Worker                scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
7541*da0073e9SAndroid Build Coastguard Worker                scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
7542*da0073e9SAndroid Build Coastguard Worker            else:
7543*da0073e9SAndroid Build Coastguard Worker                scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
7544*da0073e9SAndroid Build Coastguard Worker                scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
7545*da0073e9SAndroid Build Coastguard Worker
7546*da0073e9SAndroid Build Coastguard Worker            cpu_grad = None
7547*da0073e9SAndroid Build Coastguard Worker            grad = None
7548*da0073e9SAndroid Build Coastguard Worker
7549*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
7550*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7551*da0073e9SAndroid Build Coastguard Worker            scatter_result.backward(gradient=grad)
7552*da0073e9SAndroid Build Coastguard Worker            scatter_result_cpu.backward(gradient=cpu_grad)
7553*da0073e9SAndroid Build Coastguard Worker
7554*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scatter_result, scatter_result_cpu)
7555*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_x.grad, x.grad)
7556*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_src.grad, src.grad)
7557*da0073e9SAndroid Build Coastguard Worker
7558*da0073e9SAndroid Build Coastguard Worker        helper()
7559*da0073e9SAndroid Build Coastguard Worker        helper(do_add=False)
7560*da0073e9SAndroid Build Coastguard Worker
7561*da0073e9SAndroid Build Coastguard Worker    # Test pytorch scatter_reduce
7562*da0073e9SAndroid Build Coastguard Worker    def test_scatter_reduce(self):
7563*da0073e9SAndroid Build Coastguard Worker        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
7564*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7565*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7566*da0073e9SAndroid Build Coastguard Worker
7567*da0073e9SAndroid Build Coastguard Worker            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
7568*da0073e9SAndroid Build Coastguard Worker            src = cpu_src.detach().clone().to('mps').requires_grad_()
7569*da0073e9SAndroid Build Coastguard Worker
7570*da0073e9SAndroid Build Coastguard Worker            # Indices should be taken from range of axis along which gathering is done
7571*da0073e9SAndroid Build Coastguard Worker            idx_np = np.random.randint(0, shape[dim], idx_shape)
7572*da0073e9SAndroid Build Coastguard Worker
7573*da0073e9SAndroid Build Coastguard Worker            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7574*da0073e9SAndroid Build Coastguard Worker            idx = cpu_idx.detach().clone().to('mps')
7575*da0073e9SAndroid Build Coastguard Worker
7576*da0073e9SAndroid Build Coastguard Worker            scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
7577*da0073e9SAndroid Build Coastguard Worker            scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)
7578*da0073e9SAndroid Build Coastguard Worker
7579*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(scatter_result, scatter_result_cpu)
7580*da0073e9SAndroid Build Coastguard Worker
7581*da0073e9SAndroid Build Coastguard Worker        # for reduce in ["sum", "prod", "amax", "amin"]:
7582*da0073e9SAndroid Build Coastguard Worker        for reduce_type in ["add", "multiply"]:
7583*da0073e9SAndroid Build Coastguard Worker            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
7584*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
7585*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
7586*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
7587*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
7588*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)
7589*da0073e9SAndroid Build Coastguard Worker
7590*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
7591*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
7592*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
7593*da0073e9SAndroid Build Coastguard Worker            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)
7594*da0073e9SAndroid Build Coastguard Worker
7595*da0073e9SAndroid Build Coastguard Worker            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
7596*da0073e9SAndroid Build Coastguard Worker            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
7597*da0073e9SAndroid Build Coastguard Worker            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)
7598*da0073e9SAndroid Build Coastguard Worker
7599*da0073e9SAndroid Build Coastguard Worker    def test_is_nonzero(self):
7600*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
7601*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
7602*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
7603*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))
7604*da0073e9SAndroid Build Coastguard Worker
7605*da0073e9SAndroid Build Coastguard Worker    # Test triu
7606*da0073e9SAndroid Build Coastguard Worker    def test_triu(self):
7607*da0073e9SAndroid Build Coastguard Worker        def helper(shape, diag=0):
7608*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7609*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7610*da0073e9SAndroid Build Coastguard Worker
7611*da0073e9SAndroid Build Coastguard Worker            triu_result = torch.triu(x, diag)
7612*da0073e9SAndroid Build Coastguard Worker            triu_result_cpu = torch.triu(cpu_x, diag)
7613*da0073e9SAndroid Build Coastguard Worker
7614*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(triu_result_cpu.shape)
7615*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7616*da0073e9SAndroid Build Coastguard Worker
7617*da0073e9SAndroid Build Coastguard Worker            triu_result.backward(gradient=grad)
7618*da0073e9SAndroid Build Coastguard Worker            triu_result_cpu.backward(gradient=cpu_grad)
7619*da0073e9SAndroid Build Coastguard Worker
7620*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(triu_result, triu_result_cpu)
7621*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
7622*da0073e9SAndroid Build Coastguard Worker
7623*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
7624*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=1)
7625*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=2)
7626*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=3)
7627*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-1)
7628*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-2)
7629*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-3)
7630*da0073e9SAndroid Build Coastguard Worker
7631*da0073e9SAndroid Build Coastguard Worker    # Test inverse
7632*da0073e9SAndroid Build Coastguard Worker    def test_inverse(self):
7633*da0073e9SAndroid Build Coastguard Worker        def helper(n):
7634*da0073e9SAndroid Build Coastguard Worker            cpu_input = torch.randn(n, n, device='cpu')
7635*da0073e9SAndroid Build Coastguard Worker            mps_input = cpu_input.to('mps')
7636*da0073e9SAndroid Build Coastguard Worker
7637*da0073e9SAndroid Build Coastguard Worker            cpu_result = torch.linalg.inv(cpu_input)
7638*da0073e9SAndroid Build Coastguard Worker            mps_result = torch.linalg.inv(mps_input)
7639*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_result, mps_result)
7640*da0073e9SAndroid Build Coastguard Worker
7641*da0073e9SAndroid Build Coastguard Worker        helper(2)
7642*da0073e9SAndroid Build Coastguard Worker        helper(6)
7643*da0073e9SAndroid Build Coastguard Worker        helper(3)
7644*da0073e9SAndroid Build Coastguard Worker        helper(8)
7645*da0073e9SAndroid Build Coastguard Worker
7646*da0073e9SAndroid Build Coastguard Worker    # Test tril
7647*da0073e9SAndroid Build Coastguard Worker    def test_tril(self):
7648*da0073e9SAndroid Build Coastguard Worker        def helper(shape, diag=0):
7649*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7650*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7651*da0073e9SAndroid Build Coastguard Worker
7652*da0073e9SAndroid Build Coastguard Worker            tril_result = torch.tril(x, diag)
7653*da0073e9SAndroid Build Coastguard Worker            tril_result_cpu = torch.tril(cpu_x, diag)
7654*da0073e9SAndroid Build Coastguard Worker
7655*da0073e9SAndroid Build Coastguard Worker            cpu_grad = torch.randn(tril_result_cpu.shape)
7656*da0073e9SAndroid Build Coastguard Worker            grad = cpu_grad.to('mps')
7657*da0073e9SAndroid Build Coastguard Worker
7658*da0073e9SAndroid Build Coastguard Worker            tril_result.backward(gradient=grad)
7659*da0073e9SAndroid Build Coastguard Worker            tril_result_cpu.backward(gradient=cpu_grad)
7660*da0073e9SAndroid Build Coastguard Worker
7661*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tril_result, tril_result_cpu)
7662*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, cpu_x.grad)
7663*da0073e9SAndroid Build Coastguard Worker
7664*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5))
7665*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=1)
7666*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=2)
7667*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=3)
7668*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-1)
7669*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-2)
7670*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), diag=-3)
7671*da0073e9SAndroid Build Coastguard Worker
7672*da0073e9SAndroid Build Coastguard Worker    # test eye
7673*da0073e9SAndroid Build Coastguard Worker    def test_eye(self):
7674*da0073e9SAndroid Build Coastguard Worker        def helper(n, m, dtype):
7675*da0073e9SAndroid Build Coastguard Worker            cpu_result = None
7676*da0073e9SAndroid Build Coastguard Worker            result = None
7677*da0073e9SAndroid Build Coastguard Worker
7678*da0073e9SAndroid Build Coastguard Worker            if (n == m):
7679*da0073e9SAndroid Build Coastguard Worker                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
7680*da0073e9SAndroid Build Coastguard Worker                result = torch.eye(n, dtype=dtype, device='mps')
7681*da0073e9SAndroid Build Coastguard Worker            else:
7682*da0073e9SAndroid Build Coastguard Worker                cpu_result = torch.eye(n, m, device='cpu')
7683*da0073e9SAndroid Build Coastguard Worker                result = torch.eye(n, m, device='mps')
7684*da0073e9SAndroid Build Coastguard Worker
7685*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, cpu_result)
7686*da0073e9SAndroid Build Coastguard Worker
7687*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
7688*da0073e9SAndroid Build Coastguard Worker            helper(2, 2, dtype)
7689*da0073e9SAndroid Build Coastguard Worker            helper(2, 3, dtype)
7690*da0073e9SAndroid Build Coastguard Worker            helper(0, 2, dtype)
7691*da0073e9SAndroid Build Coastguard Worker            helper(0, 0, dtype)
7692*da0073e9SAndroid Build Coastguard Worker            helper(3, 8, dtype)
7693*da0073e9SAndroid Build Coastguard Worker            helper(8, 3, dtype)
7694*da0073e9SAndroid Build Coastguard Worker
7695*da0073e9SAndroid Build Coastguard Worker    # Test diag
7696*da0073e9SAndroid Build Coastguard Worker    def test_diag(self):
7697*da0073e9SAndroid Build Coastguard Worker        def helper(shape, diag=0):
7698*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7699*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps').requires_grad_()
7700*da0073e9SAndroid Build Coastguard Worker
7701*da0073e9SAndroid Build Coastguard Worker            diag_result = torch.diag(x, diag)
7702*da0073e9SAndroid Build Coastguard Worker            diag_result_cpu = torch.diag(cpu_x, diag)
7703*da0073e9SAndroid Build Coastguard Worker
7704*da0073e9SAndroid Build Coastguard Worker            # cpu_grad = torch.randn(diag_result_cpu.shape)
7705*da0073e9SAndroid Build Coastguard Worker            # grad = cpu_grad.to('mps')
7706*da0073e9SAndroid Build Coastguard Worker
7707*da0073e9SAndroid Build Coastguard Worker            # diag_result.backward(gradient=grad)
7708*da0073e9SAndroid Build Coastguard Worker            # diag_result_cpu.backward(gradient=cpu_grad)
7709*da0073e9SAndroid Build Coastguard Worker
7710*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(diag_result, diag_result_cpu)
7711*da0073e9SAndroid Build Coastguard Worker            # self.assertEqual(x.grad, cpu_x.grad)
7712*da0073e9SAndroid Build Coastguard Worker
7713*da0073e9SAndroid Build Coastguard Worker        for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]:
7714*da0073e9SAndroid Build Coastguard Worker            for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
7715*da0073e9SAndroid Build Coastguard Worker                helper(shape, diag=diag)
7716*da0073e9SAndroid Build Coastguard Worker
7717*da0073e9SAndroid Build Coastguard Worker    # Test linspace
7718*da0073e9SAndroid Build Coastguard Worker    def test_linspace(self):
7719*da0073e9SAndroid Build Coastguard Worker        def helper(start, end, steps, dtype=torch.float32):
7720*da0073e9SAndroid Build Coastguard Worker            cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
7721*da0073e9SAndroid Build Coastguard Worker            result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
7722*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_result, result)
7723*da0073e9SAndroid Build Coastguard Worker
7724*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
7725*da0073e9SAndroid Build Coastguard Worker            helper(2, 5, 10, dtype)
7726*da0073e9SAndroid Build Coastguard Worker            helper(2, 2, 10, dtype)
7727*da0073e9SAndroid Build Coastguard Worker            helper(5, 2, 10, dtype)
7728*da0073e9SAndroid Build Coastguard Worker            helper(2, 2, 0, dtype)
7729*da0073e9SAndroid Build Coastguard Worker
    # Test arange
7731*da0073e9SAndroid Build Coastguard Worker    def test_arange(self):
7732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
7733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
7734*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
7735*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
7736*da0073e9SAndroid Build Coastguard Worker
7737*da0073e9SAndroid Build Coastguard Worker    def test_arange_empty(self):
7738*da0073e9SAndroid Build Coastguard Worker        out_mps = torch.tensor([], device="mps")
7739*da0073e9SAndroid Build Coastguard Worker        out_cpu = torch.tensor([], device="cpu")
7740*da0073e9SAndroid Build Coastguard Worker
7741*da0073e9SAndroid Build Coastguard Worker        y_mps = torch.arange(0, 0, 1, out=out_mps)
7742*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.arange(0, 0, 1, out=out_cpu)
7743*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_mps, y_cpu)
7744*da0073e9SAndroid Build Coastguard Worker
    # Test range
7746*da0073e9SAndroid Build Coastguard Worker    def test_range(self):
7747*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
7748*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
7749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
7750*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))
7751*da0073e9SAndroid Build Coastguard Worker
7752*da0073e9SAndroid Build Coastguard Worker    # Test softmax
    def test_softmax(self):
        """Compare F.softmax forward (and backward, where enabled) on MPS
        against CPU across several dims and shapes, plus a 0-dim scalar."""
        def helper(shape, dim, channels_last=False):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            if (channels_last):
                cpu_x = cpu_x.to(memory_format=torch.channels_last)
                # .to() produced a non-leaf tensor; retain_grad so that
                # cpu_x.grad is populated for the comparison below.
                cpu_x.retain_grad()
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            # Currently NOT testing backward for channels last backward
            cpu_grad = None
            grad = None

            if (not channels_last):
                # Same gradient on both devices so x.grad can be compared.
                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
                grad = cpu_grad.to('mps')

                softmax_result.backward(gradient=grad)
                softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            if (not channels_last):
                self.assertEqual(x.grad, cpu_x.grad)

        # 0-dim (scalar) input: forward and backward.
        def helper2(dim):
            cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            softmax_result = torch.nn.functional.softmax(x, dim=dim)
            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)

            cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
            grad = cpu_grad.to('mps')

            softmax_result.backward(gradient=grad)
            softmax_result_cpu.backward(gradient=cpu_grad)

            self.assertEqual(softmax_result, softmax_result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper2(0)

        # channels_last is currently only looped over [False]; the 5-D shape
        # is skipped for channels_last since that format needs 4-D input.
        for channels_last in [False]:
            for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
                if (len(shape) != 4 and channels_last):
                    continue
                for dim in [0, 1, 2, 3, -1, -2, -3]:
                    helper(shape, dim, channels_last)
7803*da0073e9SAndroid Build Coastguard Worker
7804*da0073e9SAndroid Build Coastguard Worker    def test_nan_to_num(self):
7805*da0073e9SAndroid Build Coastguard Worker        inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
7806*da0073e9SAndroid Build Coastguard Worker        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
7807*da0073e9SAndroid Build Coastguard Worker        outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0)
7808*da0073e9SAndroid Build Coastguard Worker        outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0)
7809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputMPS, outputCPU)
7810*da0073e9SAndroid Build Coastguard Worker
7811*da0073e9SAndroid Build Coastguard Worker    # Test where
    def test_where(self):
        """Compare torch.where on MPS against CPU for matching, broadcast,
        empty, and scalar shapes, including gradients through both branches,
        and check `out=` resizing from a 0-dim tensor."""
        def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):

            # Random 0/1 condition tensor; no gradient flows through cond.
            cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
            cond = cpu_cond.detach().clone().to('mps')

            cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
            y = cpu_y.detach().clone().to('mps').requires_grad_()

            cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
            out = torch.where(cond, x, y)

            # Backprop the same gradient on both devices so x.grad / y.grad
            # can be compared element-wise below.
            cpu_grad = torch.randn(cpu_out.shape)
            grad = cpu_grad.to('mps')

            cpu_out.backward(gradient=cpu_grad)
            out.backward(gradient=grad)

            self.assertEqual(out, cpu_out)
            self.assertEqual(x.grad, cpu_x.grad)
            self.assertEqual(y.grad, cpu_y.grad)

        # Identical shapes for cond/x/y, including empty and scalar.
        for shape in ([(0, 3), [], (2, 3), (9,)]):
            helper(shape, shape, shape)

        # Broadcasting combinations across cond, x and y.
        helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
        helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
        helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
        helper([], (1, 1, 4), (2, 3, 1))
        helper([], (2, 3, 4), [])
        helper((5, 2, 3), (2, 3), (2, 3))
        helper((2, 3), (5, 2, 3), (2, 3))
        helper((2, 3), (2, 3), (5, 2, 3))
        helper((2, 3), (5, 2, 3), (6, 5, 2, 3))
        # Test that the output is correctly resized (0-dim out -> (3, 3))
        # TODO: Remove me when out OpInfo testing is enabled on MPS
        output = torch.tensor(0.0, device="mps")
        cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
        inp = torch.rand(3, 3, device="mps")
        other = torch.rand(3, 3, device="mps")
        out = torch.where(cond, inp, other, out=output)
        self.assertEqual(id(out), id(output))
        self.assertEqual(out.shape, (3, 3))
7858*da0073e9SAndroid Build Coastguard Worker
7859*da0073e9SAndroid Build Coastguard Worker    # Test normal
    def test_normal(self):
        """Smoke-test the torch.normal overloads on MPS: tensor/scalar mean
        and std combinations, with and without an `out=` tensor. Only output
        shapes are validated, since the values are random."""
        def helper(shape, mean=0.0, std=1.0):
            # NOTE(review): this result is overwritten below, so the
            # scalar-mean/scalar-std overload is only exercised, not checked.
            mps_out = torch.normal(mean, std, shape, device='mps')

            # Build constant mean/std tensors of the requested shape.
            mean_array = np.ones(shape)
            mean_array *= mean
            cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
            mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

            std_array = np.ones(shape)
            std_array *= std
            cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
            std_tensor = cpu_std_tensor.detach().clone().to('mps')

            # test out
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean, std_tensor, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std_tensor, out=mps_out)

            # test without out
            mps_out = torch.normal(mean_tensor, std)
            self.assertEqual(mps_out.size(), mean_tensor.size())

            mps_out = torch.normal(mean, std_tensor)
            self.assertEqual(mps_out.size(), std_tensor.size())

            # The tensor/tensor overload should broadcast mean against std.
            inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
            mps_out = torch.normal(mean_tensor, std_tensor)
            self.assertEqual(mps_out.size(), inferred_shape)

        helper((2, 3, 4, 5, 6))
        helper((100, 100), 2.5, 1.2)
7897*da0073e9SAndroid Build Coastguard Worker
    def test_bernoulli(self):
        """torch.bernoulli on MPS: degenerate probabilities (0 and 1) must be
        exact, p=0.5 must not be constant, and half/integral dtypes must only
        produce values in {0, 1}."""
        shape = (10, 10)
        all_ones = torch.ones(shape, device='mps')
        all_zeros = torch.zeros(shape, device='mps')

        prob_tensor = all_ones * 0.5
        # probability of drawing "1" is 0.5
        mps_out = torch.bernoulli(prob_tensor)
        # We can't check reliably the mean and std.
        # Just make sure we don't return constant values
        self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
        self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)

        # probability of drawing "1" is 0
        mps_out = torch.bernoulli(all_zeros)
        self.assertEqual(mps_out, all_zeros)

        # probability of drawing "1" is 1
        mps_out = torch.bernoulli(all_ones)
        self.assertEqual(mps_out, all_ones)

        # Check it works for different dtypes
        for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
            mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
            # Check that output is not all zeros or ones
            # NOTE(review): product_version is a module-level constant not
            # visible in this chunk — presumably the macOS version, gating
            # unique() support for these dtypes; confirm at file top.
            if product_version > 13.0:
                uniq = mps_out.unique()
                self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
            else:
                self.assertEqual(mps_out.min().item(), 0.)
                self.assertEqual(mps_out.max().item(), 1.)
7929*da0073e9SAndroid Build Coastguard Worker
    def test_mps_generator(self):
        """An explicitly-seeded torch.Generator on the MPS device must be
        reproducible, and get_state()/set_state() must restore the exact
        generator position. The assertions depend on the precise number of
        randn calls between seeding and state capture."""
        # explicit manual seeding by creating an MPS Generator
        g_mps = torch.Generator(device='mps')
        g_mps.manual_seed(999)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        g_mps.manual_seed(999)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps', generator=g_mps)
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)
        # save generator's state (offset = 1) to restore it later
        g_state = g_mps.get_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)

        # mps_x was produced by g_state, we use it as our reference mps_y.
        mps_y = mps_x

        # restore the previously saved state, and the results should match again
        g_mps.set_state(g_state)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        self.assertEqual(mps_x, mps_y)
7955*da0073e9SAndroid Build Coastguard Worker
    @serialTest()
    def test_default_mps_generator(self):
        """torch.manual_seed and torch.mps.manual_seed must drive the same
        default MPS generator, and get/set_rng_state must round-trip.
        Runs under @serialTest() — presumably because any concurrent randn
        call would shift the generator offset asserted below."""
        # manual seeding on the "default" MPS generator using
        # the global torch.manual_seed()
        torch.manual_seed(230)
        mps_x = torch.randn(5, device='mps')
        # manual seeding using torch.mps.manual_seed()
        # which should set the "default" MPS generator
        # like the global torch.manual_seed()
        torch.mps.manual_seed(230)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps')
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)

        # save the default generator's state (offset = 1) to restore it later
        g_state = torch.mps.get_rng_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps')
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)
        # since we called randn twice after seeding, the offset should be 2
        self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)

        # mps_x was produced by g_state, we use it as our reference mps_y.
        mps_y = mps_x

        # restore the previously saved state to the "default" MPS generator, and the results should match again
        torch.mps.set_rng_state(g_state)
        mps_x = torch.randn(5, device='mps')
        self.assertEqual(mps_x, mps_y)
7988*da0073e9SAndroid Build Coastguard Worker
7989*da0073e9SAndroid Build Coastguard Worker    def test_device_synchronize(self):
7990*da0073e9SAndroid Build Coastguard Worker        # just running some ops each followed by a synchronize to wait for
7991*da0073e9SAndroid Build Coastguard Worker        # MPS stream to finish running each of them
7992*da0073e9SAndroid Build Coastguard Worker        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
7993*da0073e9SAndroid Build Coastguard Worker            .to(device='mps', dtype=torch.float)
7994*da0073e9SAndroid Build Coastguard Worker
7995*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7996*da0073e9SAndroid Build Coastguard Worker        torch.mps.synchronize()
7997*da0073e9SAndroid Build Coastguard Worker        x = net1(x)
7998*da0073e9SAndroid Build Coastguard Worker        torch.mps.synchronize()
7999*da0073e9SAndroid Build Coastguard Worker        x.backward(torch.randn_like(x))
8000*da0073e9SAndroid Build Coastguard Worker        torch.mps.synchronize()
8001*da0073e9SAndroid Build Coastguard Worker
    @serialTest()
    def test_mps_allocator_module(self):
        """After gc + empty_cache the allocator must report zero current
        allocation, and both current and driver allocations must grow after
        a large tensor allocation. Marked @serialTest() — presumably so other
        tests' allocations don't skew the counters."""
        # first garbage collect and empty the cached blocks
        gc.collect()
        torch.mps.empty_cache()
        # measure memory allocations from MPSAllocator
        current_alloc_before = torch.mps.current_allocated_memory()
        # after garbage collection and emptying the cache the
        # current_allocated_memory must be zero
        self.assertEqual(current_alloc_before, 0)
        # measure total memory allocations from Metal driver
        driver_alloc_before = torch.mps.driver_allocated_memory()
        # allocate a new 8 MB tensor to force allocation of a new Metal Heap
        x = torch.ones(1024 * 1024 * 8, device="mps")
        # get memory allocations after allocating tensor x
        current_alloc_after = torch.mps.current_allocated_memory()
        driver_alloc_after = torch.mps.driver_allocated_memory()
        # current and driver memory allocations must have
        # grown at this point
        self.assertGreater(current_alloc_after, current_alloc_before)
        self.assertGreater(driver_alloc_after, driver_alloc_before)
8023*da0073e9SAndroid Build Coastguard Worker
    def test_mps_allocator_stats(self):
        """recommended_max_memory() must report a positive byte count
        (printed here in GB for CI-log visibility)."""
        max_memory = torch.mps.recommended_max_memory()
        print(f"Recommended Max Memory : {max_memory/ 1024 ** 3} GB")
        self.assertGreater(max_memory, 0)
8028*da0073e9SAndroid Build Coastguard Worker
8029*da0073e9SAndroid Build Coastguard Worker    # to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool,
8030*da0073e9SAndroid Build Coastguard Worker    # press record, then run this python test, and press stop. Next expand
8031*da0073e9SAndroid Build Coastguard Worker    # the os_signposts->PyTorchMPS and check if events or intervals are logged
8032*da0073e9SAndroid Build Coastguard Worker    # like this example:
8033*da0073e9SAndroid Build Coastguard Worker    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
    def test_mps_profiler_module(self):
        """Smoke-test both profiler APIs: the context manager ("event" mode)
        and the explicit start()/stop() pair ("interval" mode). See the
        comment above for how to inspect the captured OS Signposts.
        (The returned profiler object `p` is unused.)"""
        with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
            # just running some ops to capture the OS Signposts traces for profiling
            net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
                .to(device='mps', dtype=torch.float)
            x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
            x = net1(x)

        torch.mps.profiler.start(mode="interval", wait_until_completed=True)
        # just running some ops to capture the OS Signposts traces for profiling
        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        x = net1(x)
        torch.mps.profiler.stop()
8047*da0073e9SAndroid Build Coastguard Worker
8048*da0073e9SAndroid Build Coastguard Worker    def test_mps_event_module(self):
8049*da0073e9SAndroid Build Coastguard Worker        startEvent = torch.mps.Event(enable_timing=True)
8050*da0073e9SAndroid Build Coastguard Worker        startEvent.record()
8051*da0073e9SAndroid Build Coastguard Worker        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
8052*da0073e9SAndroid Build Coastguard Worker            .to(device='mps', dtype=torch.float)
8053*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8054*da0073e9SAndroid Build Coastguard Worker        x = net1(x)
8055*da0073e9SAndroid Build Coastguard Worker        endEvent = torch.mps.Event(enable_timing=True)
8056*da0073e9SAndroid Build Coastguard Worker        endEvent.record()
8057*da0073e9SAndroid Build Coastguard Worker        elapsedTime = startEvent.elapsed_time(endEvent)
8058*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(elapsedTime, 0.0)
8059*da0073e9SAndroid Build Coastguard Worker
8060*da0073e9SAndroid Build Coastguard Worker    def test_jit_save_load(self):
8061*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.Module()
8062*da0073e9SAndroid Build Coastguard Worker        m.x = torch.rand(3, 3, device='mps')
8063*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
8064*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(torch.jit.script(m), buffer)
8065*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
8066*da0073e9SAndroid Build Coastguard Worker        n = torch.jit.load(buffer)
8067*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n.x, m.x)
8068*da0073e9SAndroid Build Coastguard Worker
8069*da0073e9SAndroid Build Coastguard Worker    # Test random_, random_.to and random_.from
8070*da0073e9SAndroid Build Coastguard Worker    def test_random(self):
8071*da0073e9SAndroid Build Coastguard Worker        def helper(shape, low, high, dtype=torch.int32):
8072*da0073e9SAndroid Build Coastguard Worker
8073*da0073e9SAndroid Build Coastguard Worker            mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
8074*da0073e9SAndroid Build Coastguard Worker
8075*da0073e9SAndroid Build Coastguard Worker            # We can't check reliably the mean and std.
8076*da0073e9SAndroid Build Coastguard Worker            # Just make sure we don't return constant values
8077*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(mps_out.float().mean().item(), 0.)
8078*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(mps_out.float().std().item(), 0.)
8079*da0073e9SAndroid Build Coastguard Worker
8080*da0073e9SAndroid Build Coastguard Worker        helper([100, 100], 0, 10)
8081*da0073e9SAndroid Build Coastguard Worker        helper([100, 100], 23, 89)
8082*da0073e9SAndroid Build Coastguard Worker        helper([100, 100], 23, 89, dtype=torch.float32)
8083*da0073e9SAndroid Build Coastguard Worker        helper([100, 100], 23, 89, dtype=torch.int64)
8084*da0073e9SAndroid Build Coastguard Worker        helper([100, 100], 0, 2, dtype=torch.bool)
8085*da0073e9SAndroid Build Coastguard Worker
8086*da0073e9SAndroid Build Coastguard Worker        # Test random_
8087*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]:
8088*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(10, 10, dtype=dtype, device='mps')
8089*da0073e9SAndroid Build Coastguard Worker            x.random_()
8090*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(x.max().item(), 0)
8091*da0073e9SAndroid Build Coastguard Worker
8092*da0073e9SAndroid Build Coastguard Worker    # Test exponential
8093*da0073e9SAndroid Build Coastguard Worker    def test_exponential(self):
8094*da0073e9SAndroid Build Coastguard Worker        def helper(shape, lamda, dtype=torch.float32):
8095*da0073e9SAndroid Build Coastguard Worker
8096*da0073e9SAndroid Build Coastguard Worker            mps_out = torch.zeros(shape, device='mps', dtype=dtype)
8097*da0073e9SAndroid Build Coastguard Worker            mps_out.exponential_(lamda)
8098*da0073e9SAndroid Build Coastguard Worker
8099*da0073e9SAndroid Build Coastguard Worker            print(mps_out.to('cpu').float().mean(), 1 / lamda)
8100*da0073e9SAndroid Build Coastguard Worker            print(mps_out.to('cpu').float().std() ** 2, 1 / (lamda**2))
8101*da0073e9SAndroid Build Coastguard Worker
8102*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float32, torch.float16]:
8103*da0073e9SAndroid Build Coastguard Worker            helper([100, 100], 2, dtype)
8104*da0073e9SAndroid Build Coastguard Worker            helper([100, 100], 1, dtype)
8105*da0073e9SAndroid Build Coastguard Worker            helper([100, 100], 3, dtype)
8106*da0073e9SAndroid Build Coastguard Worker            helper([100, 100], 0.5, dtype)
8107*da0073e9SAndroid Build Coastguard Worker
8108*da0073e9SAndroid Build Coastguard Worker    def test_exponential_1(self):
8109*da0073e9SAndroid Build Coastguard Worker        rate = torch.randn(5, 5).abs().requires_grad_()
8110*da0073e9SAndroid Build Coastguard Worker        rate_1d = torch.randn(1).abs().requires_grad_()
8111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
8112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
8113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
8114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
8115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
8116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
8117*da0073e9SAndroid Build Coastguard Worker
8118*da0073e9SAndroid Build Coastguard Worker    # Test add
8119*da0073e9SAndroid Build Coastguard Worker    def test_add_sub(self):
8120*da0073e9SAndroid Build Coastguard Worker        def helper(shape, alpha, op_name, inplace):
8121*da0073e9SAndroid Build Coastguard Worker            if op_name == "add":
8122*da0073e9SAndroid Build Coastguard Worker                op = torch.Tensor.add_ if inplace else torch.add
8123*da0073e9SAndroid Build Coastguard Worker            elif op_name == "sub":
8124*da0073e9SAndroid Build Coastguard Worker                op = torch.Tensor.sub_ if inplace else torch.sub
8125*da0073e9SAndroid Build Coastguard Worker
8126*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float16, torch.float32]:
8127*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
8128*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8129*da0073e9SAndroid Build Coastguard Worker
8130*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
8131*da0073e9SAndroid Build Coastguard Worker                mps_y = cpu_y.detach().clone().to('mps')
8132*da0073e9SAndroid Build Coastguard Worker
8133*da0073e9SAndroid Build Coastguard Worker                cpu_out = op(cpu_x, cpu_y, alpha=alpha)
8134*da0073e9SAndroid Build Coastguard Worker                mps_out = op(mps_x, mps_y, alpha=alpha)
8135*da0073e9SAndroid Build Coastguard Worker                # fp16 isn't accurate when alpha is passed
8136*da0073e9SAndroid Build Coastguard Worker                # TODO: remove or fix 'tol' when we fix problems with fp16
8137*da0073e9SAndroid Build Coastguard Worker                tol = 2e-3 if dtype is torch.float16 else None
8138*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)
8139*da0073e9SAndroid Build Coastguard Worker                if not (cpu_y.shape != () and inplace):  # in-place output cannot be broadcasted.
8140*da0073e9SAndroid Build Coastguard Worker                    # create a scalar tensor
8141*da0073e9SAndroid Build Coastguard Worker                    cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8142*da0073e9SAndroid Build Coastguard Worker                    mps_s = cpu_s.detach().clone().to('mps')
8143*da0073e9SAndroid Build Coastguard Worker                    # primary tensor is scalar
8144*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))
8145*da0073e9SAndroid Build Coastguard Worker                # create a scalar tensor
8146*da0073e9SAndroid Build Coastguard Worker                cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8147*da0073e9SAndroid Build Coastguard Worker                mps_s = cpu_s.detach().clone().to('mps')
8148*da0073e9SAndroid Build Coastguard Worker                # secondary tensor is scalar
8149*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)
8150*da0073e9SAndroid Build Coastguard Worker
8151*da0073e9SAndroid Build Coastguard Worker
8152*da0073e9SAndroid Build Coastguard Worker        for op_name, inplace in product(["add", "sub"], [True, False]):
8153*da0073e9SAndroid Build Coastguard Worker            helper((), 0.0, op_name, inplace)
8154*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 0.0, op_name, inplace)
8155*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 0.1, op_name, inplace)
8156*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 4, 5), 1.0, op_name, inplace)
8157*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 3, 5), 0.1, op_name, inplace)
8158*da0073e9SAndroid Build Coastguard Worker            helper((2, 8, 3, 5), 0.2, op_name, inplace)
8159*da0073e9SAndroid Build Coastguard Worker
8160*da0073e9SAndroid Build Coastguard Worker    # Test add
8161*da0073e9SAndroid Build Coastguard Worker    def test_add_scalars(self):
8162*da0073e9SAndroid Build Coastguard Worker        def helper(alpha):
8163*da0073e9SAndroid Build Coastguard Worker            for dtype in [torch.float16, torch.float32]:
8164*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8165*da0073e9SAndroid Build Coastguard Worker                x = cpu_x.detach().clone().to('mps')
8166*da0073e9SAndroid Build Coastguard Worker
8167*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
8168*da0073e9SAndroid Build Coastguard Worker                y = cpu_y.detach().clone().to('mps')
8169*da0073e9SAndroid Build Coastguard Worker
8170*da0073e9SAndroid Build Coastguard Worker                cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
8171*da0073e9SAndroid Build Coastguard Worker                out = torch.add(x, y, alpha=alpha)
8172*da0073e9SAndroid Build Coastguard Worker                # fp16 isn't accurate when alpha is passed
8173*da0073e9SAndroid Build Coastguard Worker                tol = 1e-3 if dtype is torch.float16 else None
8174*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out, cpu_out, rtol=tol, atol=tol)
8175*da0073e9SAndroid Build Coastguard Worker
8176*da0073e9SAndroid Build Coastguard Worker        helper(1.0)
8177*da0073e9SAndroid Build Coastguard Worker        helper(0.0)
8178*da0073e9SAndroid Build Coastguard Worker        helper(0.1)
8179*da0073e9SAndroid Build Coastguard Worker        helper(0.2)
8180*da0073e9SAndroid Build Coastguard Worker
8181*da0073e9SAndroid Build Coastguard Worker        # Test int32 tensor + int64 scalar add
8182*da0073e9SAndroid Build Coastguard Worker        # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
8183*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4, dtype=torch.int32, device='mps')
8184*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
8185*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
8186*da0073e9SAndroid Build Coastguard Worker
8187*da0073e9SAndroid Build Coastguard Worker    def test_types_binary_op(self):
8188*da0073e9SAndroid Build Coastguard Worker        # Float * Bool
8189*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu")
8190*da0073e9SAndroid Build Coastguard Worker        mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps")
8191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_x, mps_x)
8192*da0073e9SAndroid Build Coastguard Worker        # Float * Int64
8193*da0073e9SAndroid Build Coastguard Worker        cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu")
8194*da0073e9SAndroid Build Coastguard Worker        mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps")
8195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_y, mps_y)
8196*da0073e9SAndroid Build Coastguard Worker
8197*da0073e9SAndroid Build Coastguard Worker    def test_unary_ops(self):
8198*da0073e9SAndroid Build Coastguard Worker        def helper(shape, op):
8199*da0073e9SAndroid Build Coastguard Worker            for dtypef in [torch.float32]:
8200*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8201*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8202*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_x), op(mps_x))
8203*da0073e9SAndroid Build Coastguard Worker
8204*da0073e9SAndroid Build Coastguard Worker            for dtypei in [torch.int32, torch.int16]:
8205*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False)
8206*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.to('mps')
8207*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)
8208*da0073e9SAndroid Build Coastguard Worker            # test slice
8209*da0073e9SAndroid Build Coastguard Worker            for dtypef in [torch.float32]:
8210*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8211*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8212*da0073e9SAndroid Build Coastguard Worker                cpu_slice = cpu_x[:, ::2, :, :]
8213*da0073e9SAndroid Build Coastguard Worker                mps_slice = mps_x[:, ::2, :, :]
8214*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_slice), op(mps_slice))
8215*da0073e9SAndroid Build Coastguard Worker            # test view
8216*da0073e9SAndroid Build Coastguard Worker            for dtypef in [torch.float32]:
8217*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8218*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8219*da0073e9SAndroid Build Coastguard Worker                # create view of tensor by reducing the 3rd and 4th dimension
8220*da0073e9SAndroid Build Coastguard Worker                combined_dim = shape[-1] * shape[-2]
8221*da0073e9SAndroid Build Coastguard Worker                reshaped_dims = list(shape[:-2]) + [combined_dim]
8222*da0073e9SAndroid Build Coastguard Worker                cpu_view = cpu_x.view(*reshaped_dims)
8223*da0073e9SAndroid Build Coastguard Worker                mps_view = mps_x.view(*reshaped_dims)
8224*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_view), op(mps_view))
8225*da0073e9SAndroid Build Coastguard Worker
8226*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 4, 5), torch.exp)
8227*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 3, 5), torch.exp2)
8228*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 3, 5), torch.expm1)
8229*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 3, 5), torch.log)
8230*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 3, 5), torch.cos)
8231*da0073e9SAndroid Build Coastguard Worker        helper((2, 8, 3, 5), torch.erfinv)
8232*da0073e9SAndroid Build Coastguard Worker
8233*da0073e9SAndroid Build Coastguard Worker
8234*da0073e9SAndroid Build Coastguard Worker    def test_non_dense_in_storage_unary_ops(self):
8235*da0073e9SAndroid Build Coastguard Worker        def helper(op):
8236*da0073e9SAndroid Build Coastguard Worker            for dtypef in [torch.float32]:
8237*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False)
8238*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8239*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))
8240*da0073e9SAndroid Build Coastguard Worker
8241*da0073e9SAndroid Build Coastguard Worker            for dtypei in [torch.int32, torch.int16, torch.int8]:
8242*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False)
8243*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.to('mps')
8244*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)
8245*da0073e9SAndroid Build Coastguard Worker
8246*da0073e9SAndroid Build Coastguard Worker        helper(torch.exp)
8247*da0073e9SAndroid Build Coastguard Worker        helper(torch.exp2)
8248*da0073e9SAndroid Build Coastguard Worker        helper(torch.expm1)
8249*da0073e9SAndroid Build Coastguard Worker        helper(torch.log)
8250*da0073e9SAndroid Build Coastguard Worker        helper(torch.cos)
8251*da0073e9SAndroid Build Coastguard Worker
8252*da0073e9SAndroid Build Coastguard Worker    def test_unary_ops_storage_offset_strided(self):
8253*da0073e9SAndroid Build Coastguard Worker        def helper(shape, op, inplace, dtype=torch.float32):
8254*da0073e9SAndroid Build Coastguard Worker            # test in-place with storage_offset
8255*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8256*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
8257*da0073e9SAndroid Build Coastguard Worker            y = op(mps_x[1])
8258*da0073e9SAndroid Build Coastguard Worker            cpu_y = op(cpu_x[1])
8259*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, cpu_y)
8260*da0073e9SAndroid Build Coastguard Worker
8261*da0073e9SAndroid Build Coastguard Worker
8262*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/100764
8263*da0073e9SAndroid Build Coastguard Worker            if not inplace:
8264*da0073e9SAndroid Build Coastguard Worker                cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8265*da0073e9SAndroid Build Coastguard Worker                mps_x = cpu_x.detach().clone().to('mps')
8266*da0073e9SAndroid Build Coastguard Worker                cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t()
8267*da0073e9SAndroid Build Coastguard Worker                mps_y = cpu_y.detach().clone().to('mps')
8268*da0073e9SAndroid Build Coastguard Worker                op(cpu_x, out=cpu_y)
8269*da0073e9SAndroid Build Coastguard Worker                op(mps_x, out=mps_y)
8270*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(mps_y, cpu_y)
8271*da0073e9SAndroid Build Coastguard Worker
8272*da0073e9SAndroid Build Coastguard Worker
8273*da0073e9SAndroid Build Coastguard Worker        helper((5, 5), torch.exp, False)
8274*da0073e9SAndroid Build Coastguard Worker        helper((5, 5), torch.cos, False)
8275*da0073e9SAndroid Build Coastguard Worker        helper((5, 5), torch.neg, False)
8276*da0073e9SAndroid Build Coastguard Worker        helper((5, 5), torch.tanh, False)
8277*da0073e9SAndroid Build Coastguard Worker        helper((5, 5), torch.tanh_, True)
8278*da0073e9SAndroid Build Coastguard Worker
8279*da0073e9SAndroid Build Coastguard Worker    def test_atan2(self):
8280*da0073e9SAndroid Build Coastguard Worker        def helper(shape):
8281*da0073e9SAndroid Build Coastguard Worker            input_cpu = torch.randn(shape)
8282*da0073e9SAndroid Build Coastguard Worker            input_mps = input_cpu.detach().clone().to("mps")
8283*da0073e9SAndroid Build Coastguard Worker
8284*da0073e9SAndroid Build Coastguard Worker            other_cpu = torch.randn(shape)
8285*da0073e9SAndroid Build Coastguard Worker            other_mps = other_cpu.detach().clone().to("mps")
8286*da0073e9SAndroid Build Coastguard Worker
8287*da0073e9SAndroid Build Coastguard Worker            atan2_cpu = torch.atan2(input_cpu, other_cpu)
8288*da0073e9SAndroid Build Coastguard Worker            atan2_mps = torch.atan2(input_mps, other_mps)
8289*da0073e9SAndroid Build Coastguard Worker
8290*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(atan2_cpu, atan2_mps.to("cpu"))
8291*da0073e9SAndroid Build Coastguard Worker
8292*da0073e9SAndroid Build Coastguard Worker        helper(4)
8293*da0073e9SAndroid Build Coastguard Worker        helper(10000)
8294*da0073e9SAndroid Build Coastguard Worker        helper((10000, 40))
8295*da0073e9SAndroid Build Coastguard Worker
8296*da0073e9SAndroid Build Coastguard Worker    def test_multinomial(self):
8297*da0073e9SAndroid Build Coastguard Worker        # Test with num_dist = 1
8298*da0073e9SAndroid Build Coastguard Worker        def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
8299*da0073e9SAndroid Build Coastguard Worker            cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
8300*da0073e9SAndroid Build Coastguard Worker            prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
8301*da0073e9SAndroid Build Coastguard Worker
8302*da0073e9SAndroid Build Coastguard Worker            mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
8303*da0073e9SAndroid Build Coastguard Worker            if (not replacement):
8304*da0073e9SAndroid Build Coastguard Worker                print(mps_out.to('cpu'))
8305*da0073e9SAndroid Build Coastguard Worker            else:
8306*da0073e9SAndroid Build Coastguard Worker                # Compare "real" with theoretical values
8307*da0073e9SAndroid Build Coastguard Worker                print(mps_out.to('cpu').float().mean(), compare_mean)
8308*da0073e9SAndroid Build Coastguard Worker                print(mps_out.to('cpu').float().std() ** 2, compare_var)
8309*da0073e9SAndroid Build Coastguard Worker
8310*da0073e9SAndroid Build Coastguard Worker        # TODO: Add tests for data types
8311*da0073e9SAndroid Build Coastguard Worker        helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
8312*da0073e9SAndroid Build Coastguard Worker        helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8313*da0073e9SAndroid Build Coastguard Worker        helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8314*da0073e9SAndroid Build Coastguard Worker        helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8315*da0073e9SAndroid Build Coastguard Worker        helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
8316*da0073e9SAndroid Build Coastguard Worker
8317*da0073e9SAndroid Build Coastguard Worker    def test_cumsum_dim_check(self):
8318*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), device="mps")
8319*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.cumsum(1), x.cumsum(-1))
8320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.cumsum(0), x.cumsum(-2))
8321*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: x.cumsum(2))
8322*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: x.cumsum(-3))
8323*da0073e9SAndroid Build Coastguard Worker
8324*da0073e9SAndroid Build Coastguard Worker    def test_cumprod_dim_check(self):
8325*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((3, 3), device="mps")
8326*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.cumprod(1), x.cumprod(-1))
8327*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.cumprod(0), x.cumprod(-2))
8328*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: x.cumprod(2))
8329*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: x.cumprod(-3))
8330*da0073e9SAndroid Build Coastguard Worker
8331*da0073e9SAndroid Build Coastguard Workerclass TestLogical(TestCaseMPS):
8332*da0073e9SAndroid Build Coastguard Worker    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
8333*da0073e9SAndroid Build Coastguard Worker        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
8334*da0073e9SAndroid Build Coastguard Worker
8335*da0073e9SAndroid Build Coastguard Worker    def test_logical_not(self):
8336*da0073e9SAndroid Build Coastguard Worker        def helper(x):
8337*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
8338*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
8339*da0073e9SAndroid Build Coastguard Worker
8340*da0073e9SAndroid Build Coastguard Worker            result = torch.logical_not(x)
8341*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.logical_not(cpu_x)
8342*da0073e9SAndroid Build Coastguard Worker
8343*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
8344*da0073e9SAndroid Build Coastguard Worker
8345*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([1, 1, 0, 0]))
8346*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
8347*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([True, True, False, False]))
8348*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor(1))
8349*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor(0))
8350*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor(True))
8351*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor(False))
8352*da0073e9SAndroid Build Coastguard Worker
8353*da0073e9SAndroid Build Coastguard Worker    def test_logical_and(self):
8354*da0073e9SAndroid Build Coastguard Worker        def helper(x, other):
8355*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
8356*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
8357*da0073e9SAndroid Build Coastguard Worker
8358*da0073e9SAndroid Build Coastguard Worker            cpu_other = other
8359*da0073e9SAndroid Build Coastguard Worker            other = cpu_other.detach().clone().to('mps')
8360*da0073e9SAndroid Build Coastguard Worker
8361*da0073e9SAndroid Build Coastguard Worker            result = torch.logical_and(x, other)
8362*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.logical_and(cpu_x, cpu_other)
8363*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
8364*da0073e9SAndroid Build Coastguard Worker
8365*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8366*da0073e9SAndroid Build Coastguard Worker        helper(
8367*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8368*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8369*da0073e9SAndroid Build Coastguard Worker        )
8370*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8371*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8372*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8373*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8374*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8375*da0073e9SAndroid Build Coastguard Worker
8376*da0073e9SAndroid Build Coastguard Worker    def test_logical_or(self):
8377*da0073e9SAndroid Build Coastguard Worker        def helper(x, other):
8378*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
8379*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
8380*da0073e9SAndroid Build Coastguard Worker
8381*da0073e9SAndroid Build Coastguard Worker            cpu_other = other
8382*da0073e9SAndroid Build Coastguard Worker            other = cpu_other.detach().clone().to('mps')
8383*da0073e9SAndroid Build Coastguard Worker
8384*da0073e9SAndroid Build Coastguard Worker            result = torch.logical_or(x, other)
8385*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.logical_or(cpu_x, cpu_other)
8386*da0073e9SAndroid Build Coastguard Worker
8387*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
8388*da0073e9SAndroid Build Coastguard Worker
8389*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8390*da0073e9SAndroid Build Coastguard Worker        helper(
8391*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8392*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8393*da0073e9SAndroid Build Coastguard Worker        )
8394*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8395*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8396*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8397*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8398*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8399*da0073e9SAndroid Build Coastguard Worker
8400*da0073e9SAndroid Build Coastguard Worker    def test_logical_xor(self):
8401*da0073e9SAndroid Build Coastguard Worker        def helper(x, other):
8402*da0073e9SAndroid Build Coastguard Worker            cpu_x = x
8403*da0073e9SAndroid Build Coastguard Worker            x = cpu_x.detach().clone().to('mps')
8404*da0073e9SAndroid Build Coastguard Worker
8405*da0073e9SAndroid Build Coastguard Worker            cpu_other = other
8406*da0073e9SAndroid Build Coastguard Worker            other = cpu_other.detach().clone().to('mps')
8407*da0073e9SAndroid Build Coastguard Worker
8408*da0073e9SAndroid Build Coastguard Worker            result = torch.logical_xor(x, other)
8409*da0073e9SAndroid Build Coastguard Worker            result_cpu = torch.logical_xor(cpu_x, cpu_other)
8410*da0073e9SAndroid Build Coastguard Worker
8411*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, result_cpu)
8412*da0073e9SAndroid Build Coastguard Worker
8413*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8414*da0073e9SAndroid Build Coastguard Worker        helper(
8415*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8416*da0073e9SAndroid Build Coastguard Worker            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8417*da0073e9SAndroid Build Coastguard Worker        )
8418*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8419*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8420*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8421*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8422*da0073e9SAndroid Build Coastguard Worker        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8423*da0073e9SAndroid Build Coastguard Worker
8424*da0073e9SAndroid Build Coastguard Worker    def test_min_max(self):
8425*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
8426*da0073e9SAndroid Build Coastguard Worker            for _ in range(10):
8427*da0073e9SAndroid Build Coastguard Worker                if dtype == torch.float32 or dtype == torch.float16:
8428*da0073e9SAndroid Build Coastguard Worker                    x = torch.randn((30, 15), device='mps', dtype=dtype)
8429*da0073e9SAndroid Build Coastguard Worker                else:
8430*da0073e9SAndroid Build Coastguard Worker                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
8431*da0073e9SAndroid Build Coastguard Worker                x_cpu = x.to("cpu")
8432*da0073e9SAndroid Build Coastguard Worker
8433*da0073e9SAndroid Build Coastguard Worker                y = x.max()
8434*da0073e9SAndroid Build Coastguard Worker                y_cpu = x_cpu.max()
8435*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, y_cpu)
8436*da0073e9SAndroid Build Coastguard Worker
8437*da0073e9SAndroid Build Coastguard Worker                z = x.min()
8438*da0073e9SAndroid Build Coastguard Worker                z_cpu = x_cpu.min()
8439*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(z, z_cpu)
8440*da0073e9SAndroid Build Coastguard Worker
8441*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
8442*da0073e9SAndroid Build Coastguard Worker
8443*da0073e9SAndroid Build Coastguard Worker    def test_min_max_nan_propagation(self):
8444*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
8445*da0073e9SAndroid Build Coastguard Worker            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu")
8446*da0073e9SAndroid Build Coastguard Worker            mps_x = cpu_x.detach().clone().to('mps')
8447*da0073e9SAndroid Build Coastguard Worker
8448*da0073e9SAndroid Build Coastguard Worker            cpu_max = torch.max(cpu_x)
8449*da0073e9SAndroid Build Coastguard Worker            mps_max = torch.max(mps_x).to('cpu')
8450*da0073e9SAndroid Build Coastguard Worker
8451*da0073e9SAndroid Build Coastguard Worker            cpu_amax = torch.amax(cpu_x)
8452*da0073e9SAndroid Build Coastguard Worker            mps_amax = torch.amax(mps_x).to('cpu')
8453*da0073e9SAndroid Build Coastguard Worker
8454*da0073e9SAndroid Build Coastguard Worker            cpu_min = torch.min(cpu_x)
8455*da0073e9SAndroid Build Coastguard Worker            mps_min = torch.min(mps_x).to('cpu')
8456*da0073e9SAndroid Build Coastguard Worker
8457*da0073e9SAndroid Build Coastguard Worker            cpu_amin = torch.amin(cpu_x)
8458*da0073e9SAndroid Build Coastguard Worker            mps_amin = torch.amin(mps_x).to('cpu')
8459*da0073e9SAndroid Build Coastguard Worker
8460*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_max, mps_max)
8461*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_amax, mps_amax)
8462*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_min, mps_min)
8463*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_amin, mps_amin)
8464*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
8465*da0073e9SAndroid Build Coastguard Worker
8466*da0073e9SAndroid Build Coastguard Worker    def test_isin(self):
8467*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
8468*da0073e9SAndroid Build Coastguard Worker            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
8469*da0073e9SAndroid Build Coastguard Worker                      ([5], [10]), ([0], [5]), ([5], [0])]
8470*da0073e9SAndroid Build Coastguard Worker            for shape_tuple in shapes:
8471*da0073e9SAndroid Build Coastguard Worker                for inverted in [True, False]:
8472*da0073e9SAndroid Build Coastguard Worker                    if dtype.is_floating_point:
8473*da0073e9SAndroid Build Coastguard Worker                        # Half is not supported for CPU isin. Compute reference in FP32
8474*da0073e9SAndroid Build Coastguard Worker                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
8475*da0073e9SAndroid Build Coastguard Worker                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
8476*da0073e9SAndroid Build Coastguard Worker                    else:
8477*da0073e9SAndroid Build Coastguard Worker                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
8478*da0073e9SAndroid Build Coastguard Worker                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)
8479*da0073e9SAndroid Build Coastguard Worker
8480*da0073e9SAndroid Build Coastguard Worker                    A_mps = A.clone().detach().to('mps')
8481*da0073e9SAndroid Build Coastguard Worker                    B_mps = B.clone().detach().to('mps')
8482*da0073e9SAndroid Build Coastguard Worker
8483*da0073e9SAndroid Build Coastguard Worker                    cpu_ref = torch.isin(A, B, invert=inverted)
8484*da0073e9SAndroid Build Coastguard Worker                    if dtype in [torch.float16, torch.bfloat16]:
8485*da0073e9SAndroid Build Coastguard Worker                        cpu_ref.type(dtype)
8486*da0073e9SAndroid Build Coastguard Worker
8487*da0073e9SAndroid Build Coastguard Worker                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
8488*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(mps_out, cpu_ref)
8489*da0073e9SAndroid Build Coastguard Worker
8490*da0073e9SAndroid Build Coastguard Worker        dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
8491*da0073e9SAndroid Build Coastguard Worker        if product_version < 14.0:
8492*da0073e9SAndroid Build Coastguard Worker            # Int types expected to fail on MacOS < 14.0
8493*da0073e9SAndroid Build Coastguard Worker            dtypes = [torch.float32, torch.float16, torch.bfloat16]
8494*da0073e9SAndroid Build Coastguard Worker
8495*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in dtypes]
8496*da0073e9SAndroid Build Coastguard Worker
8497*da0073e9SAndroid Build Coastguard Worker    def test_isin_asserts(self):
8498*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8499*da0073e9SAndroid Build Coastguard Worker        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
8500*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
8501*da0073e9SAndroid Build Coastguard Worker            out = torch.isin(A, B)
8502*da0073e9SAndroid Build Coastguard Worker
8503*da0073e9SAndroid Build Coastguard Worker
8504*da0073e9SAndroid Build Coastguard Worker        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8505*da0073e9SAndroid Build Coastguard Worker        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
8506*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
8507*da0073e9SAndroid Build Coastguard Worker            out = torch.isin(C, D)
8508*da0073e9SAndroid Build Coastguard Worker
8509*da0073e9SAndroid Build Coastguard Workerclass TestSmoothL1Loss(TestCaseMPS):
8510*da0073e9SAndroid Build Coastguard Worker
8511*da0073e9SAndroid Build Coastguard Worker    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
8512*da0073e9SAndroid Build Coastguard Worker        # CPU
8513*da0073e9SAndroid Build Coastguard Worker        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
8514*da0073e9SAndroid Build Coastguard Worker        target_cpu = torch.randn(4, 7)
8515*da0073e9SAndroid Build Coastguard Worker
8516*da0073e9SAndroid Build Coastguard Worker        # MPS
8517*da0073e9SAndroid Build Coastguard Worker        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
8518*da0073e9SAndroid Build Coastguard Worker        target_mps = target_cpu.detach().clone().to('mps')
8519*da0073e9SAndroid Build Coastguard Worker
8520*da0073e9SAndroid Build Coastguard Worker        smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
8521*da0073e9SAndroid Build Coastguard Worker        smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)
8522*da0073e9SAndroid Build Coastguard Worker
8523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)
8524*da0073e9SAndroid Build Coastguard Worker
8525*da0073e9SAndroid Build Coastguard Worker        if requires_grad:
8526*da0073e9SAndroid Build Coastguard Worker            smooth_l1_loss_cpu.backward()
8527*da0073e9SAndroid Build Coastguard Worker            smooth_l1_loss_mps.backward()
8528*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))
8529*da0073e9SAndroid Build Coastguard Worker
8530*da0073e9SAndroid Build Coastguard Worker        return smooth_l1_loss_cpu, smooth_l1_loss_mps
8531*da0073e9SAndroid Build Coastguard Worker
8532*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_reduction_none(self):
8533*da0073e9SAndroid Build Coastguard Worker        self._smooth_l1_loss_helper(reduction="none")
8534*da0073e9SAndroid Build Coastguard Worker
8535*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_reduction_mean(self):
8536*da0073e9SAndroid Build Coastguard Worker        self._smooth_l1_loss_helper(reduction="mean")
8537*da0073e9SAndroid Build Coastguard Worker
8538*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_reduction_sum(self):
8539*da0073e9SAndroid Build Coastguard Worker        self._smooth_l1_loss_helper(reduction="sum")
8540*da0073e9SAndroid Build Coastguard Worker
8541*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_reduction_mean_backward(self):
8542*da0073e9SAndroid Build Coastguard Worker        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)
8543*da0073e9SAndroid Build Coastguard Worker
8544*da0073e9SAndroid Build Coastguard Worker    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
8545*da0073e9SAndroid Build Coastguard Worker        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
8546*da0073e9SAndroid Build Coastguard Worker
8547*da0073e9SAndroid Build Coastguard Workerclass TestNLLLoss(TestCaseMPS):
8548*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_mismatched_batch(self, device='mps'):
8549*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10, 3), requires_grad=True, device=device)
8550*da0073e9SAndroid Build Coastguard Worker        # t should have size (10,)
8551*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((3,), dtype=torch.int64, device=device)
8552*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
8553*da0073e9SAndroid Build Coastguard Worker            F.nll_loss(x, t)
8554*da0073e9SAndroid Build Coastguard Worker
8555*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_out_of_bounds_ignore_index(self):
8556*da0073e9SAndroid Build Coastguard Worker
8557*da0073e9SAndroid Build Coastguard Worker        def test_nll_loss_out_of_bounds_ignore_index_helper(device):
8558*da0073e9SAndroid Build Coastguard Worker            output = []
8559*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8560*da0073e9SAndroid Build Coastguard Worker                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8561*da0073e9SAndroid Build Coastguard Worker            t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
8562*da0073e9SAndroid Build Coastguard Worker            t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)
8563*da0073e9SAndroid Build Coastguard Worker            for reduction in ['mean', 'none']:
8564*da0073e9SAndroid Build Coastguard Worker                # out of bound ignore_index
8565*da0073e9SAndroid Build Coastguard Worker                output.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
8566*da0073e9SAndroid Build Coastguard Worker                # default ignore_index
8567*da0073e9SAndroid Build Coastguard Worker                output.append(F.nll_loss(x, t2, reduction=reduction))
8568*da0073e9SAndroid Build Coastguard Worker            return output
8569*da0073e9SAndroid Build Coastguard Worker
8570*da0073e9SAndroid Build Coastguard Worker        output_cpu = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu')
8571*da0073e9SAndroid Build Coastguard Worker        output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')
8572*da0073e9SAndroid Build Coastguard Worker
8573*da0073e9SAndroid Build Coastguard Worker        for cpu, mps in zip(output_cpu, output_mps):
8574*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu, mps)
8575*da0073e9SAndroid Build Coastguard Worker
8576*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_invalid_target_dim(self):
8577*da0073e9SAndroid Build Coastguard Worker
8578*da0073e9SAndroid Build Coastguard Worker        def _test_nll_loss_invalid_target_dim(device):
8579*da0073e9SAndroid Build Coastguard Worker            output = []
8580*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8581*da0073e9SAndroid Build Coastguard Worker                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8582*da0073e9SAndroid Build Coastguard Worker            t = torch.zeros((6, 2), dtype=torch.int64, device=device)
8583*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
8584*da0073e9SAndroid Build Coastguard Worker                F.nll_loss(x, t)
8585*da0073e9SAndroid Build Coastguard Worker
8586*da0073e9SAndroid Build Coastguard Worker        _test_nll_loss_invalid_target_dim(device='cpu')
8587*da0073e9SAndroid Build Coastguard Worker        _test_nll_loss_invalid_target_dim(device='mps')
8588*da0073e9SAndroid Build Coastguard Worker
8589*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_invalid_weights(self):
8590*da0073e9SAndroid Build Coastguard Worker
8591*da0073e9SAndroid Build Coastguard Worker        def _test_nll_loss_invalid_weights(device):
8592*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8593*da0073e9SAndroid Build Coastguard Worker                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8594*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
8595*da0073e9SAndroid Build Coastguard Worker            invalid_weights = [
8596*da0073e9SAndroid Build Coastguard Worker                torch.zeros(4, device=device),
8597*da0073e9SAndroid Build Coastguard Worker                torch.zeros((1, 3), device=device),
8598*da0073e9SAndroid Build Coastguard Worker            ]
8599*da0073e9SAndroid Build Coastguard Worker            msg = "weight tensor should be defined either for all 3 classes or no classes"
8600*da0073e9SAndroid Build Coastguard Worker            for weight in invalid_weights:
8601*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, msg):
8602*da0073e9SAndroid Build Coastguard Worker                    F.nll_loss(x, t, weight=weight)
8603*da0073e9SAndroid Build Coastguard Worker
8604*da0073e9SAndroid Build Coastguard Worker        _test_nll_loss_invalid_weights(device='cpu')
8605*da0073e9SAndroid Build Coastguard Worker        _test_nll_loss_invalid_weights(device='mps')
8606*da0073e9SAndroid Build Coastguard Worker
8607*da0073e9SAndroid Build Coastguard Worker    def _nll_loss_helper(self, input_size, reduction, expected):
8608*da0073e9SAndroid Build Coastguard Worker
8609*da0073e9SAndroid Build Coastguard Worker        # CPU
8610*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(input_size, requires_grad=True, device='cpu')
8611*da0073e9SAndroid Build Coastguard Worker        num_channels = input_size[1]
8612*da0073e9SAndroid Build Coastguard Worker        target_size = (input_size[0], ) + tuple(input_size[2:])
8613*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(num_channels, target_size, device='cpu')
8614*da0073e9SAndroid Build Coastguard Worker        weights = torch.randn(num_channels)
8615*da0073e9SAndroid Build Coastguard Worker
8616*da0073e9SAndroid Build Coastguard Worker        # MPS
8617*da0073e9SAndroid Build Coastguard Worker        input_mps = input.detach().clone().to('mps').requires_grad_()
8618*da0073e9SAndroid Build Coastguard Worker        target_mps = target.detach().clone().to('mps')
8619*da0073e9SAndroid Build Coastguard Worker        weights_mps = weights.to("mps")
8620*da0073e9SAndroid Build Coastguard Worker
8621*da0073e9SAndroid Build Coastguard Worker        output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
8622*da0073e9SAndroid Build Coastguard Worker        output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
8623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps.to('cpu'))
8624*da0073e9SAndroid Build Coastguard Worker
8625*da0073e9SAndroid Build Coastguard Worker        output_cpu.sum().backward()
8626*da0073e9SAndroid Build Coastguard Worker        output_mps.sum().backward()
8627*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8628*da0073e9SAndroid Build Coastguard Worker
8629*da0073e9SAndroid Build Coastguard Worker    def _nll_loss_1d_helper(self, input_size, reduction):
8630*da0073e9SAndroid Build Coastguard Worker
8631*da0073e9SAndroid Build Coastguard Worker        # CPU
8632*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(input_size, requires_grad=True, device='cpu')
8633*da0073e9SAndroid Build Coastguard Worker        num_channels = input_size[0]
8634*da0073e9SAndroid Build Coastguard Worker        target = torch.randint(num_channels, [], device='cpu')
8635*da0073e9SAndroid Build Coastguard Worker
8636*da0073e9SAndroid Build Coastguard Worker        # MPS
8637*da0073e9SAndroid Build Coastguard Worker        input_mps = input.detach().clone().to('mps').requires_grad_()
8638*da0073e9SAndroid Build Coastguard Worker        target_mps = target.detach().clone().to('mps')
8639*da0073e9SAndroid Build Coastguard Worker
8640*da0073e9SAndroid Build Coastguard Worker        output_cpu = F.nll_loss(input, target, reduction=reduction)
8641*da0073e9SAndroid Build Coastguard Worker        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
8642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps.to('cpu'))
8643*da0073e9SAndroid Build Coastguard Worker
8644*da0073e9SAndroid Build Coastguard Worker        output_cpu.sum().backward()
8645*da0073e9SAndroid Build Coastguard Worker        output_mps.sum().backward()
8646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8647*da0073e9SAndroid Build Coastguard Worker
8648*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_1d(self, device='cpu'):
8649*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_1d_helper([10], "none")
8650*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_1d_helper([10], "mean")
8651*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_1d_helper([10], "sum")
8652*da0073e9SAndroid Build Coastguard Worker
8653*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
8654*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
8655*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
8656*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
8657*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
8658*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))
8659*da0073e9SAndroid Build Coastguard Worker
8660*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
8661*da0073e9SAndroid Build Coastguard Worker        nan = torch.tensor(float('nan'), device=device)
8662*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([1, 3], "mean", nan)
8663*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
8664*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
8665*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
8666*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)
8667*da0073e9SAndroid Build Coastguard Worker
8668*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
8669*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, device=device)
8670*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([1, 3], "sum", zero)
8671*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
8672*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
8673*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
8674*da0073e9SAndroid Build Coastguard Worker        self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
8675*da0073e9SAndroid Build Coastguard Worker
8676*da0073e9SAndroid Build Coastguard Worker    def test_nll_loss_byte_target_matches_long(self, device='cpu'):
8677*da0073e9SAndroid Build Coastguard Worker        N, C = 10, 4
8678*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(N, C, device=device, requires_grad=True)
8679*da0073e9SAndroid Build Coastguard Worker        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
8680*da0073e9SAndroid Build Coastguard Worker
8681*da0073e9SAndroid Build Coastguard Worker        def compute_result_and_gradient(reduction, target_dtype):
8682*da0073e9SAndroid Build Coastguard Worker            result, grad = {}, {}
8683*da0073e9SAndroid Build Coastguard Worker            for dev in ['cpu', 'mps']:
8684*da0073e9SAndroid Build Coastguard Worker                input_dev = input.to(dev)
8685*da0073e9SAndroid Build Coastguard Worker                input_ = input_dev.detach()
8686*da0073e9SAndroid Build Coastguard Worker                input_.requires_grad_()
8687*da0073e9SAndroid Build Coastguard Worker
8688*da0073e9SAndroid Build Coastguard Worker                target_dev = target.to(dev)
8689*da0073e9SAndroid Build Coastguard Worker
8690*da0073e9SAndroid Build Coastguard Worker                prob = F.log_softmax(input_, dim=-1)
8691*da0073e9SAndroid Build Coastguard Worker                loss = nn.NLLLoss(reduction=reduction)
8692*da0073e9SAndroid Build Coastguard Worker                result[dev] = loss(prob, target_dev.to(target_dtype))
8693*da0073e9SAndroid Build Coastguard Worker                result[dev].sum().backward()
8694*da0073e9SAndroid Build Coastguard Worker                grad[dev] = input_.grad
8695*da0073e9SAndroid Build Coastguard Worker
8696*da0073e9SAndroid Build Coastguard Worker            return result, grad
8697*da0073e9SAndroid Build Coastguard Worker
8698*da0073e9SAndroid Build Coastguard Worker        for reduction in ["none", "mean", "sum"]:
8699*da0073e9SAndroid Build Coastguard Worker            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
8700*da0073e9SAndroid Build Coastguard Worker            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
8701*da0073e9SAndroid Build Coastguard Worker
8702*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
8703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
8704*da0073e9SAndroid Build Coastguard Worker
8705*da0073e9SAndroid Build Coastguard Workerclass TestTopK(TestCase):
8706*da0073e9SAndroid Build Coastguard Worker    def _test_topk(self, shape, largest):
8707*da0073e9SAndroid Build Coastguard Worker        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
8708*da0073e9SAndroid Build Coastguard Worker        x = cpu_x.detach().clone().to('mps')
8709*da0073e9SAndroid Build Coastguard Worker        if isinstance(shape, tuple):
8710*da0073e9SAndroid Build Coastguard Worker            for curr_dim, dim_size in enumerate(shape):
8711*da0073e9SAndroid Build Coastguard Worker                for k in range(1, dim_size + 1):
8712*da0073e9SAndroid Build Coastguard Worker                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
8713*da0073e9SAndroid Build Coastguard Worker                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
8714*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(topk_values, topk_values_cpu)
8715*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(topk_indices, topk_indices_cpu)
8716*da0073e9SAndroid Build Coastguard Worker        else:
8717*da0073e9SAndroid Build Coastguard Worker            for k in range(1, shape):
8718*da0073e9SAndroid Build Coastguard Worker                topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
8719*da0073e9SAndroid Build Coastguard Worker                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
8720*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(topk_values, topk_values_cpu)
8721*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(topk_indices, topk_indices_cpu)
8722*da0073e9SAndroid Build Coastguard Worker
8723*da0073e9SAndroid Build Coastguard Worker    def test_topk(self):
8724*da0073e9SAndroid Build Coastguard Worker        largest_vals = [True, False]
8725*da0073e9SAndroid Build Coastguard Worker        shapes = [
8726*da0073e9SAndroid Build Coastguard Worker            # Zero Element Tensors
8727*da0073e9SAndroid Build Coastguard Worker            0,
8728*da0073e9SAndroid Build Coastguard Worker            (1, 0),
8729*da0073e9SAndroid Build Coastguard Worker            (0, 1),
8730*da0073e9SAndroid Build Coastguard Worker            (1, 0, 1),
8731*da0073e9SAndroid Build Coastguard Worker            # Multiple Element Tensors
8732*da0073e9SAndroid Build Coastguard Worker            1,
8733*da0073e9SAndroid Build Coastguard Worker            2,
8734*da0073e9SAndroid Build Coastguard Worker            (5, 1),
8735*da0073e9SAndroid Build Coastguard Worker            (1, 5),
8736*da0073e9SAndroid Build Coastguard Worker            (5, 9, 7, 4),
8737*da0073e9SAndroid Build Coastguard Worker        ]
8738*da0073e9SAndroid Build Coastguard Worker
8739*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
8740*da0073e9SAndroid Build Coastguard Worker            for largest_val in largest_vals:
8741*da0073e9SAndroid Build Coastguard Worker                with self.subTest(shape=shape, largest_val=largest_val):
8742*da0073e9SAndroid Build Coastguard Worker                    self._test_topk(shape, largest_val)
8743*da0073e9SAndroid Build Coastguard Worker
8744*da0073e9SAndroid Build Coastguard Workerclass TestNNMPS(NNTestCase):
8745*da0073e9SAndroid Build Coastguard Worker
8746*da0073e9SAndroid Build Coastguard Worker    def _create_basic_net(self):
8747*da0073e9SAndroid Build Coastguard Worker        class Layer(nn.Module):
8748*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8749*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8750*da0073e9SAndroid Build Coastguard Worker                self.layer_dummy_param = Parameter(torch.empty(3, 5))
8751*da0073e9SAndroid Build Coastguard Worker                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))
8752*da0073e9SAndroid Build Coastguard Worker
8753*da0073e9SAndroid Build Coastguard Worker        class Net(nn.Module):
8754*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
8755*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8756*da0073e9SAndroid Build Coastguard Worker                self.l1 = Layer()
8757*da0073e9SAndroid Build Coastguard Worker                self.dummy_param = Parameter(torch.empty(3, 5))
8758*da0073e9SAndroid Build Coastguard Worker                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))
8759*da0073e9SAndroid Build Coastguard Worker
8760*da0073e9SAndroid Build Coastguard Worker        l = Layer()
8761*da0073e9SAndroid Build Coastguard Worker        n = Net()
8762*da0073e9SAndroid Build Coastguard Worker        s = nn.Sequential(n, n)
8763*da0073e9SAndroid Build Coastguard Worker
8764*da0073e9SAndroid Build Coastguard Worker        return l, n, s
8765*da0073e9SAndroid Build Coastguard Worker
8766*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_(self):
8767*da0073e9SAndroid Build Coastguard Worker        m = self._create_basic_net()[-1]
8768*da0073e9SAndroid Build Coastguard Worker        assert len(list(m.buffers())) > 0, 'invalid test'
8769*da0073e9SAndroid Build Coastguard Worker        assert all(not b.requires_grad for b in m.buffers()) > 0, 'invalid test'
8770*da0073e9SAndroid Build Coastguard Worker        assert len(list(m.parameters())) > 0, 'invalid test'
8771*da0073e9SAndroid Build Coastguard Worker        assert all(p.requires_grad for p in m.parameters()) > 0, 'invalid test'
8772*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (False, True):
8773*da0073e9SAndroid Build Coastguard Worker            self.assertIs(m.requires_grad_(requires_grad), m)
8774*da0073e9SAndroid Build Coastguard Worker            for p in m.parameters():
8775*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(p.requires_grad, requires_grad)
8776*da0073e9SAndroid Build Coastguard Worker            for b in m.buffers():
8777*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(b.requires_grad)
8778*da0073e9SAndroid Build Coastguard Worker
8779*da0073e9SAndroid Build Coastguard Worker    def test_module_backcompat(self):
8780*da0073e9SAndroid Build Coastguard Worker        from torch.serialization import SourceChangeWarning
8781*da0073e9SAndroid Build Coastguard Worker        path = download_file('https://download.pytorch.org/test_data/linear.pt')
8782*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings():
8783*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter('ignore', SourceChangeWarning)
8784*da0073e9SAndroid Build Coastguard Worker            m = torch.load(path)
8785*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(2, 3, dtype=torch.float)
8786*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(input).size(), (2, 5))
8787*da0073e9SAndroid Build Coastguard Worker
8788*da0073e9SAndroid Build Coastguard Worker    def test_conv_backcompat(self):
8789*da0073e9SAndroid Build Coastguard Worker        from torch.serialization import SourceChangeWarning
8790*da0073e9SAndroid Build Coastguard Worker        # This file was generated by running on PyTorch 1.0.1 on Python 2:
8791*da0073e9SAndroid Build Coastguard Worker        #
8792*da0073e9SAndroid Build Coastguard Worker        #     import torch
8793*da0073e9SAndroid Build Coastguard Worker        #     from torch import nn
8794*da0073e9SAndroid Build Coastguard Worker        #     m = nn.Conv2d(1, 1, 1)
8795*da0073e9SAndroid Build Coastguard Worker        #     torch.save(m, 'legacy_conv2d.pt')
8796*da0073e9SAndroid Build Coastguard Worker        #
8797*da0073e9SAndroid Build Coastguard Worker        # NB: This Pickle also contains some Unicode data!
8798*da0073e9SAndroid Build Coastguard Worker        path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
8799*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings():
8800*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter('ignore', SourceChangeWarning)
8801*da0073e9SAndroid Build Coastguard Worker            m = torch.load(path, encoding='utf-8')
8802*da0073e9SAndroid Build Coastguard Worker        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
8803*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(m(input).size(), (1, 1, 1, 1))
8804*da0073e9SAndroid Build Coastguard Worker
8805*da0073e9SAndroid Build Coastguard Worker    def test_conv_expand(self):
8806*da0073e9SAndroid Build Coastguard Worker        device = 'mps'
8807*da0073e9SAndroid Build Coastguard Worker        input_ = torch.rand(2, 3, 16, 16, device=device)
8808*da0073e9SAndroid Build Coastguard Worker        kernel = torch.rand(1, 1, 3, 11, device=device)
8809*da0073e9SAndroid Build Coastguard Worker        tmp_kernel = kernel.expand(-1, 3, -1, -1)
8810*da0073e9SAndroid Build Coastguard Worker        output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)
8811*da0073e9SAndroid Build Coastguard Worker
8812*da0073e9SAndroid Build Coastguard Worker    # The test should not crash
8813*da0073e9SAndroid Build Coastguard Worker    def test_permute(self):
8814*da0073e9SAndroid Build Coastguard Worker        M_cpu = torch.randn(5, 5)
8815*da0073e9SAndroid Build Coastguard Worker        M_mps = M_cpu.to('mps')
8816*da0073e9SAndroid Build Coastguard Worker
8817*da0073e9SAndroid Build Coastguard Worker        output_cpu = M_cpu.permute(1, 0)
8818*da0073e9SAndroid Build Coastguard Worker        output_mps = M_mps.permute(1, 0)
8819*da0073e9SAndroid Build Coastguard Worker
8820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps)
8821*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
8822*da0073e9SAndroid Build Coastguard Worker
8823*da0073e9SAndroid Build Coastguard Worker    # Printing of non_contiguous should not crash
8824*da0073e9SAndroid Build Coastguard Worker    def test_print_non_contiguous(self):
8825*da0073e9SAndroid Build Coastguard Worker        print(torch.ones(100, 100, device='mps').nonzero())
8826*da0073e9SAndroid Build Coastguard Worker        print(torch.ones(100, 100, device='mps').nonzero().contiguous())
8827*da0073e9SAndroid Build Coastguard Worker
8828*da0073e9SAndroid Build Coastguard Worker    def test_zero_grad(self):
8829*da0073e9SAndroid Build Coastguard Worker        i = torch.randn(2, 5, requires_grad=True)
8830*da0073e9SAndroid Build Coastguard Worker        module = nn.Linear(5, 5)
8831*da0073e9SAndroid Build Coastguard Worker        for p in module.parameters():
8832*da0073e9SAndroid Build Coastguard Worker            p.requires_grad = False
8833*da0073e9SAndroid Build Coastguard Worker        module.zero_grad()
8834*da0073e9SAndroid Build Coastguard Worker
8835*da0073e9SAndroid Build Coastguard Worker        module.weight.requires_grad = True
8836*da0073e9SAndroid Build Coastguard Worker        module.zero_grad()
8837*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.weight.grad)  # uninitialized grad
8838*da0073e9SAndroid Build Coastguard Worker
8839*da0073e9SAndroid Build Coastguard Worker        module(i).sum().backward()
8840*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(module.weight.grad)
8841*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8842*da0073e9SAndroid Build Coastguard Worker        module.zero_grad()
8843*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.weight.grad)
8844*da0073e9SAndroid Build Coastguard Worker
8845*da0073e9SAndroid Build Coastguard Worker        module.bias.requires_grad = True
8846*da0073e9SAndroid Build Coastguard Worker        module.zero_grad()
8847*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.weight.grad)
8848*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.bias.grad)
8849*da0073e9SAndroid Build Coastguard Worker        module(i).sum().backward()
8850*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(module.weight.grad)
8851*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(module.bias.grad)
8852*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8853*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
8854*da0073e9SAndroid Build Coastguard Worker
8855*da0073e9SAndroid Build Coastguard Worker        # Force set to zeros.
8856*da0073e9SAndroid Build Coastguard Worker        module.zero_grad(set_to_none=False)
8857*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
8858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
8859*da0073e9SAndroid Build Coastguard Worker
8860*da0073e9SAndroid Build Coastguard Worker        module.zero_grad()
8861*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.weight.grad)
8862*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(module.bias.grad)
8863*da0073e9SAndroid Build Coastguard Worker
8864*da0073e9SAndroid Build Coastguard Worker
8865*da0073e9SAndroid Build Coastguard Worker    def test_no_grad(self):
8866*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.float, torch.double]:
8867*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
8868*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 2, 10, 10).to(dtype)
8869*da0073e9SAndroid Build Coastguard Worker            x = input
8870*da0073e9SAndroid Build Coastguard Worker            y = input.clone()
8871*da0073e9SAndroid Build Coastguard Worker
8872*da0073e9SAndroid Build Coastguard Worker            output = module(x)
8873*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(output.requires_grad)
8874*da0073e9SAndroid Build Coastguard Worker            output.backward(torch.ones(1, 5, 10, 10))
8875*da0073e9SAndroid Build Coastguard Worker
8876*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
8877*da0073e9SAndroid Build Coastguard Worker                output2 = module(y)
8878*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(output2.requires_grad)
8879*da0073e9SAndroid Build Coastguard Worker                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
8880*da0073e9SAndroid Build Coastguard Worker
8881*da0073e9SAndroid Build Coastguard Worker    def test_invalid_conv1d(self):
8882*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.float, torch.double]:
8883*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
8884*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 3, 4).to(dtype)
8885*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
8886*da0073e9SAndroid Build Coastguard Worker                                        r'Calculated padded input size per channel: \(4\). ' +
8887*da0073e9SAndroid Build Coastguard Worker                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
8888*da0073e9SAndroid Build Coastguard Worker                module(input)
8889*da0073e9SAndroid Build Coastguard Worker
8890*da0073e9SAndroid Build Coastguard Worker            # Negative stride check
8891*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
8892*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 3, 4).to(dtype)
8893*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8894*da0073e9SAndroid Build Coastguard Worker                module(input)
8895*da0073e9SAndroid Build Coastguard Worker
8896*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_discontiguous_weight(self):
8897*da0073e9SAndroid Build Coastguard Worker        # Test for https://github.com/pytorch/pytorch/issues/55781
8898*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(64, 16, 16, 16)
8899*da0073e9SAndroid Build Coastguard Worker        weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
8900*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(weight.is_contiguous())
8901*da0073e9SAndroid Build Coastguard Worker        y = torch.nn.functional.conv2d(x, weight, None)
8902*da0073e9SAndroid Build Coastguard Worker        if torch.backends.mkldnn.is_available():
8903*da0073e9SAndroid Build Coastguard Worker            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
8904*da0073e9SAndroid Build Coastguard Worker            with torch.backends.mkldnn.flags(enabled=False):
8905*da0073e9SAndroid Build Coastguard Worker                y_ = torch.nn.functional.conv2d(x, weight, None)
8906*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y, y_)
8907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.sum(), 4186112.)
8908*da0073e9SAndroid Build Coastguard Worker
8909*da0073e9SAndroid Build Coastguard Worker    def test_invalid_conv2d(self):
8910*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.bfloat16, torch.float, torch.double]:
8911*da0073e9SAndroid Build Coastguard Worker            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
8912*da0073e9SAndroid Build Coastguard Worker            input = torch.empty(1, 1, 4, 4).to(dtype)
8913*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: module(input))
8914*da0073e9SAndroid Build Coastguard Worker
8915*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
8916*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 3, 1, 1)
8917*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError,
8918*da0073e9SAndroid Build Coastguard Worker                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
8919*da0073e9SAndroid Build Coastguard Worker                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
8920*da0073e9SAndroid Build Coastguard Worker                module(input)
8921*da0073e9SAndroid Build Coastguard Worker
8922*da0073e9SAndroid Build Coastguard Worker            # Negative stride check
8923*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
8924*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 3, 4, 4).to(dtype)
8925*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8926*da0073e9SAndroid Build Coastguard Worker                module(input)
8927*da0073e9SAndroid Build Coastguard Worker
8928*da0073e9SAndroid Build Coastguard Worker            # Zero stride check
8929*da0073e9SAndroid Build Coastguard Worker            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
8930*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(1, 3, 4, 4).to(dtype)
8931*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8932*da0073e9SAndroid Build Coastguard Worker                module(input)
8933*da0073e9SAndroid Build Coastguard Worker
8934*da0073e9SAndroid Build Coastguard Worker            # Input and weights on different devices
8935*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError,
8936*da0073e9SAndroid Build Coastguard Worker                                   'must be on the same device',
8937*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
8938*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(RuntimeError,
8939*da0073e9SAndroid Build Coastguard Worker                                   'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
8940*da0073e9SAndroid Build Coastguard Worker                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))
8941*da0073e9SAndroid Build Coastguard Worker
8942*da0073e9SAndroid Build Coastguard Worker
8943*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_valid_padding(self, device='mps'):
8944*da0073e9SAndroid Build Coastguard Worker        # Test F.conv2d padding='valid' is the same as no padding
8945*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
8946*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
8947*da0073e9SAndroid Build Coastguard Worker
8948*da0073e9SAndroid Build Coastguard Worker        expect = F.conv2d(x, y)
8949*da0073e9SAndroid Build Coastguard Worker        actual = F.conv2d(x, y, padding='valid')
8950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expect.to('cpu'), actual.to('cpu'))
8951*da0073e9SAndroid Build Coastguard Worker
8952*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_backward_collision(self):
8953*da0073e9SAndroid Build Coastguard Worker        # Test for https://github.com/pytorch/pytorch/issues/112998
8954*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
8955*da0073e9SAndroid Build Coastguard Worker        m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
8956*da0073e9SAndroid Build Coastguard Worker        m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
8957*da0073e9SAndroid Build Coastguard Worker        y1, y2 = m1(x), m2(x)
8958*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1.shape, y2.shape)
8959*da0073e9SAndroid Build Coastguard Worker        y1.sum().backward()
8960*da0073e9SAndroid Build Coastguard Worker        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8961*da0073e9SAndroid Build Coastguard Worker        y2.sum().backward()
8962*da0073e9SAndroid Build Coastguard Worker
8963*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
8964*da0073e9SAndroid Build Coastguard Worker    def test_conv3d_backward_collision(self):
8965*da0073e9SAndroid Build Coastguard Worker        # Conv3D is only available from MacOS 13.2 onwards
8966*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
8967*da0073e9SAndroid Build Coastguard Worker        m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
8968*da0073e9SAndroid Build Coastguard Worker        m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
8969*da0073e9SAndroid Build Coastguard Worker        y1, y2 = m1(x), m2(x)
8970*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1.shape, y2.shape)
8971*da0073e9SAndroid Build Coastguard Worker        y1.sum().backward()
8972*da0073e9SAndroid Build Coastguard Worker        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8973*da0073e9SAndroid Build Coastguard Worker        y2.sum().backward()
8974*da0073e9SAndroid Build Coastguard Worker
8975*da0073e9SAndroid Build Coastguard Worker    def test_gemm_permute_transpose(self):
8976*da0073e9SAndroid Build Coastguard Worker        batch_size = 32
8977*da0073e9SAndroid Build Coastguard Worker        n = 20
8978*da0073e9SAndroid Build Coastguard Worker        hidden = 768
8979*da0073e9SAndroid Build Coastguard Worker        num_attention_heads = 12
8980*da0073e9SAndroid Build Coastguard Worker        attention_head_size = hidden // num_attention_heads
8981*da0073e9SAndroid Build Coastguard Worker
8982*da0073e9SAndroid Build Coastguard Worker        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
8983*da0073e9SAndroid Build Coastguard Worker            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
8984*da0073e9SAndroid Build Coastguard Worker            x = x.view(new_x_shape)
8985*da0073e9SAndroid Build Coastguard Worker            return x.permute(0, 2, 1, 3)
8986*da0073e9SAndroid Build Coastguard Worker
8987*da0073e9SAndroid Build Coastguard Worker        def attention2(key, *, workaround=False, device):
8988*da0073e9SAndroid Build Coastguard Worker            key = transpose_for_scores(key)
8989*da0073e9SAndroid Build Coastguard Worker            res = key.transpose(-1, -2)
8990*da0073e9SAndroid Build Coastguard Worker            return res
8991*da0073e9SAndroid Build Coastguard Worker
8992*da0073e9SAndroid Build Coastguard Worker        A = torch.randn(batch_size, n, hidden)
8993*da0073e9SAndroid Build Coastguard Worker        A_mps = A.detach().clone().to("mps")
8994*da0073e9SAndroid Build Coastguard Worker
8995*da0073e9SAndroid Build Coastguard Worker        r1 = attention2(A, device="cpu")
8996*da0073e9SAndroid Build Coastguard Worker        r2 = attention2(A_mps, device="mps")
8997*da0073e9SAndroid Build Coastguard Worker
8998*da0073e9SAndroid Build Coastguard Worker        r2_cpu = r2.to("cpu")
8999*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, r2_cpu)
9000*da0073e9SAndroid Build Coastguard Worker
9001*da0073e9SAndroid Build Coastguard Worker    def test_group_norm_backward(self, device='mps'):
9002*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
9003*da0073e9SAndroid Build Coastguard Worker        shape = [1, 4, 16, 16]
9004*da0073e9SAndroid Build Coastguard Worker        x = torch.full(shape, 7.0, device=device)
9005*da0073e9SAndroid Build Coastguard Worker
9006*da0073e9SAndroid Build Coastguard Worker        target = torch.ones((1, 3, 128, 128), device=device)
9007*da0073e9SAndroid Build Coastguard Worker
9008*da0073e9SAndroid Build Coastguard Worker        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
9009*da0073e9SAndroid Build Coastguard Worker        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
9010*da0073e9SAndroid Build Coastguard Worker        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)
9011*da0073e9SAndroid Build Coastguard Worker
9012*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
9013*da0073e9SAndroid Build Coastguard Worker            x = x.detach().requires_grad_()
9014*da0073e9SAndroid Build Coastguard Worker            out = 5.5 * x
9015*da0073e9SAndroid Build Coastguard Worker            out = conv_in(out)
9016*da0073e9SAndroid Build Coastguard Worker            out = out + norm(out)
9017*da0073e9SAndroid Build Coastguard Worker            out = out + norm(out)
9018*da0073e9SAndroid Build Coastguard Worker            out = out + norm(out)
9019*da0073e9SAndroid Build Coastguard Worker            out = F.interpolate(out, scale_factor=8.0, mode="nearest")
9020*da0073e9SAndroid Build Coastguard Worker            out = norm(out)
9021*da0073e9SAndroid Build Coastguard Worker            out = conv_out(out)
9022*da0073e9SAndroid Build Coastguard Worker
9023*da0073e9SAndroid Build Coastguard Worker            loss = (out - target).norm(dim=-1).sum()
9024*da0073e9SAndroid Build Coastguard Worker            grad = -torch.autograd.grad(loss, x)[0]
9025*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
9026*da0073e9SAndroid Build Coastguard Worker
9027*da0073e9SAndroid Build Coastguard Worker
9028*da0073e9SAndroid Build Coastguard Worker    # def test_conv2d_same_padding(self, device='mps'):
9029*da0073e9SAndroid Build Coastguard Worker        # x = torch.rand(1, 1, 10, 11, device=device)
9030*da0073e9SAndroid Build Coastguard Worker        # y = torch.rand(1, 1, 4, 5, device=device)
9031*da0073e9SAndroid Build Coastguard Worker        # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
9032*da0073e9SAndroid Build Coastguard Worker        # actual = F.conv2d(x, y, padding='same')
9033*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
9034*da0073e9SAndroid Build Coastguard Worker
9035*da0073e9SAndroid Build Coastguard Worker        # # With dilation
9036*da0073e9SAndroid Build Coastguard Worker        # y = torch.rand(1, 1, 3, 4, device=device)
9037*da0073e9SAndroid Build Coastguard Worker        # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
9038*da0073e9SAndroid Build Coastguard Worker        # actual = F.conv2d(x, y, padding='same', dilation=2)
9039*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(expect, actual)
9040*da0073e9SAndroid Build Coastguard Worker
9041*da0073e9SAndroid Build Coastguard Worker        # # Dilation with asymmetric padding
9042*da0073e9SAndroid Build Coastguard Worker        # y = torch.rand(1, 1, 4, 4, device=device)
9043*da0073e9SAndroid Build Coastguard Worker        # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
9044*da0073e9SAndroid Build Coastguard Worker        # actual = F.conv2d(x, y, padding='same', dilation=3)
9045*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(expect, actual)
9046*da0073e9SAndroid Build Coastguard Worker
9047*da0073e9SAndroid Build Coastguard Worker
9048*da0073e9SAndroid Build Coastguard Workerclass TestPad(TestCaseMPS):
9049*da0073e9SAndroid Build Coastguard Worker    def test_constant_pad(self):
9050*da0073e9SAndroid Build Coastguard Worker        m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
9051*da0073e9SAndroid Build Coastguard Worker        input_cpu = torch.randn(1, 16, 16, 16)
9052*da0073e9SAndroid Build Coastguard Worker        input_mps = input_cpu.detach().clone().to("mps")
9053*da0073e9SAndroid Build Coastguard Worker        r_cpu = m(input_cpu)
9054*da0073e9SAndroid Build Coastguard Worker        r_mps = m(input_mps)
9055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_cpu, r_mps.to("cpu"))
9056*da0073e9SAndroid Build Coastguard Worker
9057*da0073e9SAndroid Build Coastguard Worker        # Arbitrary input dimensions
9058*da0073e9SAndroid Build Coastguard Worker        pad = (1, 1, 0, 0, 0, 0)
9059*da0073e9SAndroid Build Coastguard Worker        value = 3.5
9060*da0073e9SAndroid Build Coastguard Worker        input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
9061*da0073e9SAndroid Build Coastguard Worker        input_mps = input_cpu.detach().clone().to("mps")
9062*da0073e9SAndroid Build Coastguard Worker        r_cpu = F.pad(input_cpu, pad=pad, value=value)
9063*da0073e9SAndroid Build Coastguard Worker        r_mps = F.pad(input_mps, pad=pad, value=value)
9064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r_cpu, r_mps.to("cpu"))
9065*da0073e9SAndroid Build Coastguard Worker
9066*da0073e9SAndroid Build Coastguard Worker    def test_circular_pad(self):
9067*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/80856
9068*da0073e9SAndroid Build Coastguard Worker        k_cpu = torch.ones(3, 3, 9, 9)
9069*da0073e9SAndroid Build Coastguard Worker        k_mps = k_cpu.detach().clone().to("mps")
9070*da0073e9SAndroid Build Coastguard Worker
9071*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.rand(1, 3, 32, 32)
9072*da0073e9SAndroid Build Coastguard Worker        x_mps = x_cpu.detach().clone().to("mps")
9073*da0073e9SAndroid Build Coastguard Worker
9074*da0073e9SAndroid Build Coastguard Worker        x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular')
9075*da0073e9SAndroid Build Coastguard Worker        x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular')
9076*da0073e9SAndroid Build Coastguard Worker
9077*da0073e9SAndroid Build Coastguard Worker        y_cpu = F.conv2d(x_pad_cpu, k_cpu)
9078*da0073e9SAndroid Build Coastguard Worker        y_mps = F.conv2d(x_pad_mps, k_mps)
9079*da0073e9SAndroid Build Coastguard Worker
9080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_cpu, y_mps.cpu())
9081*da0073e9SAndroid Build Coastguard Worker
9082*da0073e9SAndroid Build Coastguard Worker    def test_constant_pad_4d_warning(self):
9083*da0073e9SAndroid Build Coastguard Worker        inputCPU = torch.rand((1, 2, 2, 2, 1, 1))
9084*da0073e9SAndroid Build Coastguard Worker        inputMPS = inputCPU.detach().clone().to('mps')
9085*da0073e9SAndroid Build Coastguard Worker        outputCPU = F.pad(inputCPU, [0, 0, 0, 0, 0, 0, 1, 0])
9086*da0073e9SAndroid Build Coastguard Worker        outputMPS = F.pad(inputMPS, [0, 0, 0, 0, 0, 0, 1, 0])
9087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(outputCPU, outputMPS)
9088*da0073e9SAndroid Build Coastguard Worker
9089*da0073e9SAndroid Build Coastguard Worker    def test_pad(self):
9090*da0073e9SAndroid Build Coastguard Worker        def helper(shape, padding, op, value=0):
9091*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
9092*da0073e9SAndroid Build Coastguard Worker            inputCPU.retain_grad()
9093*da0073e9SAndroid Build Coastguard Worker            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
9094*da0073e9SAndroid Build Coastguard Worker
9095*da0073e9SAndroid Build Coastguard Worker            if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]):
9096*da0073e9SAndroid Build Coastguard Worker                padCriteria = op(padding, value)
9097*da0073e9SAndroid Build Coastguard Worker            else:
9098*da0073e9SAndroid Build Coastguard Worker                padCriteria = op(padding)
9099*da0073e9SAndroid Build Coastguard Worker            outputCPU = padCriteria(inputCPU)
9100*da0073e9SAndroid Build Coastguard Worker            outputMPS = padCriteria(inputMPS)
9101*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(outputCPU, outputMPS)
9102*da0073e9SAndroid Build Coastguard Worker
9103*da0073e9SAndroid Build Coastguard Worker            # backward pass (chose 0.6 just to have the grad_output != 1)
9104*da0073e9SAndroid Build Coastguard Worker            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
9105*da0073e9SAndroid Build Coastguard Worker            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
9106*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inputCPU.grad, inputMPS.grad)
9107*da0073e9SAndroid Build Coastguard Worker
9108*da0073e9SAndroid Build Coastguard Worker        # 1D Padding
9109*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 3), 2, nn.ReflectionPad1d)
9110*da0073e9SAndroid Build Coastguard Worker        # verify if a change in shape of input would cause problems with graph caching
9111*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
9112*da0073e9SAndroid Build Coastguard Worker        # Replication 1D
9113*da0073e9SAndroid Build Coastguard Worker        helper((2, 1, 6), 3, nn.ReplicationPad1d)
9114*da0073e9SAndroid Build Coastguard Worker        # Constant Pad 1D
9115*da0073e9SAndroid Build Coastguard Worker        helper((2, 3, 4), 2, nn.ConstantPad1d)
9116*da0073e9SAndroid Build Coastguard Worker        # Constant Pad 1D with single dimension input
9117*da0073e9SAndroid Build Coastguard Worker        helper((16), (1, 2), nn.ConstantPad1d)
9118*da0073e9SAndroid Build Coastguard Worker
9119*da0073e9SAndroid Build Coastguard Worker        # 2D Padding
9120*da0073e9SAndroid Build Coastguard Worker        helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
9121*da0073e9SAndroid Build Coastguard Worker        # verify if a change in shape of input would cause problems with graph caching
9122*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
9123*da0073e9SAndroid Build Coastguard Worker        # this should make the padding (2, 2, 2, 2)
9124*da0073e9SAndroid Build Coastguard Worker        helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
9125*da0073e9SAndroid Build Coastguard Worker        # verify if a change in shape of padding would cause problems with graph caching
9126*da0073e9SAndroid Build Coastguard Worker        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
9127*da0073e9SAndroid Build Coastguard Worker        # Constant Pad 2D
9128*da0073e9SAndroid Build Coastguard Worker        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
9129*da0073e9SAndroid Build Coastguard Worker        # input size < pad size
9130*da0073e9SAndroid Build Coastguard Worker        helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
9131*da0073e9SAndroid Build Coastguard Worker        # pad dims < input dims
9132*da0073e9SAndroid Build Coastguard Worker        helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
9133*da0073e9SAndroid Build Coastguard Worker        # pad dims == input dims
9134*da0073e9SAndroid Build Coastguard Worker        helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
9135*da0073e9SAndroid Build Coastguard Worker        # input.numel() == 0 but output.numel() > 0
9136*da0073e9SAndroid Build Coastguard Worker        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
9137*da0073e9SAndroid Build Coastguard Worker        # pad dims < input dims - 2
9138*da0073e9SAndroid Build Coastguard Worker        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
9139*da0073e9SAndroid Build Coastguard Worker
9140*da0073e9SAndroid Build Coastguard Worker        # 3D Padding
9141*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
9142*da0073e9SAndroid Build Coastguard Worker        # verify if a change in shape of padding would cause problems with graph caching
9143*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
9144*da0073e9SAndroid Build Coastguard Worker        # case where input_d == pad_front/back for ReplicationPad3d
9145*da0073e9SAndroid Build Coastguard Worker        helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d)
9146*da0073e9SAndroid Build Coastguard Worker        # Constant Pad 3D
9147*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
9148*da0073e9SAndroid Build Coastguard Worker        # input size < pad size
9149*da0073e9SAndroid Build Coastguard Worker        helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
9150*da0073e9SAndroid Build Coastguard Worker        # check the workaround for the right padding bug in Monterey
9151*da0073e9SAndroid Build Coastguard Worker        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
9152*da0073e9SAndroid Build Coastguard Worker
9153*da0073e9SAndroid Build Coastguard Worker    def test_constant_pad_nd_preserves_memory_format(self):
9154*da0073e9SAndroid Build Coastguard Worker        nchw_tensor = torch.rand((1, 2, 5, 3))
9155*da0073e9SAndroid Build Coastguard Worker        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
9156*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))
9157*da0073e9SAndroid Build Coastguard Worker
9158*da0073e9SAndroid Build Coastguard Worker        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
9159*da0073e9SAndroid Build Coastguard Worker        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
9160*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
9161*da0073e9SAndroid Build Coastguard Worker
9162*da0073e9SAndroid Build Coastguard Worker
9163*da0073e9SAndroid Build Coastguard Workerclass TestLinalgMPS(TestCaseMPS):
9164*da0073e9SAndroid Build Coastguard Worker    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
9165*da0073e9SAndroid Build Coastguard Worker        dtype = t.dtype
9166*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype
9167*da0073e9SAndroid Build Coastguard Worker        alpha = 1.2 if alpha is None else alpha
9168*da0073e9SAndroid Build Coastguard Worker        beta = 0.8 if beta is None else beta
9169*da0073e9SAndroid Build Coastguard Worker        res1 = f(t, m, v, alpha=alpha, beta=beta)
9170*da0073e9SAndroid Build Coastguard Worker        res2 = torch.full_like(res1, math.nan)
9171*da0073e9SAndroid Build Coastguard Worker        if transpose_out:
9172*da0073e9SAndroid Build Coastguard Worker            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
9173*da0073e9SAndroid Build Coastguard Worker        f(t, m, v, alpha=alpha, beta=beta, out=res2)
9174*da0073e9SAndroid Build Coastguard Worker        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
9175*da0073e9SAndroid Build Coastguard Worker        if beta != 0:
9176*da0073e9SAndroid Build Coastguard Worker            res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
9177*da0073e9SAndroid Build Coastguard Worker        res3 = torch.from_numpy(res3).to(dtype)
9178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
9179*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res3)
9180*da0073e9SAndroid Build Coastguard Worker
9181*da0073e9SAndroid Build Coastguard Worker    def test_addmm(self, device="mps", dtype=torch.float32):
9182*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(10, 25, device=device).to(dtype)
9183*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 50, device=device).to(dtype)
9184*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(50, 25, device=device).to(dtype)
9185*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(torch.addmm, M, m1, m2)
9186*da0073e9SAndroid Build Coastguard Worker
9187*da0073e9SAndroid Build Coastguard Worker        # Test beta=0, M=nan
9188*da0073e9SAndroid Build Coastguard Worker        M = torch.full((10, 25), math.nan, device=device).to(dtype)
9189*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, 50, device=device).to(dtype)
9190*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(50, 25, device=device).to(dtype)
9191*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)
9192*da0073e9SAndroid Build Coastguard Worker
9193*da0073e9SAndroid Build Coastguard Worker        # Test transpose
9194*da0073e9SAndroid Build Coastguard Worker        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
9195*da0073e9SAndroid Build Coastguard Worker            def maybe_transpose(cond, m):
9196*da0073e9SAndroid Build Coastguard Worker                if not cond:
9197*da0073e9SAndroid Build Coastguard Worker                    return m
9198*da0073e9SAndroid Build Coastguard Worker                return m.t().clone(memory_format=torch.contiguous_format).t()
9199*da0073e9SAndroid Build Coastguard Worker
9200*da0073e9SAndroid Build Coastguard Worker        M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
9201*da0073e9SAndroid Build Coastguard Worker        m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
9202*da0073e9SAndroid Build Coastguard Worker        m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
9203*da0073e9SAndroid Build Coastguard Worker        self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
9204*da0073e9SAndroid Build Coastguard Worker
9205*da0073e9SAndroid Build Coastguard Worker    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
9206*da0073e9SAndroid Build Coastguard Worker        dtype = t.dtype
9207*da0073e9SAndroid Build Coastguard Worker        numpy_dtype = dtype
9208*da0073e9SAndroid Build Coastguard Worker        alpha = 1.2 if alpha is None else alpha
9209*da0073e9SAndroid Build Coastguard Worker        beta = 0.8 if beta is None else beta
9210*da0073e9SAndroid Build Coastguard Worker        res1 = f(t, m, v, alpha=alpha, beta=beta)
9211*da0073e9SAndroid Build Coastguard Worker        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
9212*da0073e9SAndroid Build Coastguard Worker        if beta != 0:
9213*da0073e9SAndroid Build Coastguard Worker            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
9214*da0073e9SAndroid Build Coastguard Worker        res2 = torch.from_numpy(res2).to(dtype)
9215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
9216*da0073e9SAndroid Build Coastguard Worker
9217*da0073e9SAndroid Build Coastguard Worker    def test_addr(self, device="mps", dtype=torch.float32):
9218*da0073e9SAndroid Build Coastguard Worker        M = torch.randn(10, 25, device=device).to(dtype)
9219*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, device=device).to(dtype)
9220*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(25, device=device).to(dtype)
9221*da0073e9SAndroid Build Coastguard Worker        self._test_addr(torch.addr, M, m1, m2)
9222*da0073e9SAndroid Build Coastguard Worker
9223*da0073e9SAndroid Build Coastguard Worker        # Test beta=0, M=nan
9224*da0073e9SAndroid Build Coastguard Worker        M = torch.full((10, 25), math.nan, device=device).to(dtype)
9225*da0073e9SAndroid Build Coastguard Worker        m1 = torch.randn(10, device=device).to(dtype)
9226*da0073e9SAndroid Build Coastguard Worker        m2 = torch.randn(25, device=device).to(dtype)
9227*da0073e9SAndroid Build Coastguard Worker        self._test_addr(torch.addr, M, m1, m2, beta=0)
9228*da0073e9SAndroid Build Coastguard Worker
9229*da0073e9SAndroid Build Coastguard Worker    def test_matrix_rank(self, device="mps", dtype=torch.float32):
9230*da0073e9SAndroid Build Coastguard Worker        matrix_rank = torch.linalg.matrix_rank
9231*da0073e9SAndroid Build Coastguard Worker
9232*da0073e9SAndroid Build Coastguard Worker        def run_test(shape0, shape1, batch):
9233*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
9234*da0073e9SAndroid Build Coastguard Worker            rank_a = matrix_rank(a)
9235*da0073e9SAndroid Build Coastguard Worker
9236*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, matrix_rank(a.mH))
9237*da0073e9SAndroid Build Coastguard Worker            aaH = torch.matmul(a, a.mH)
9238*da0073e9SAndroid Build Coastguard Worker            rank_aaH = matrix_rank(aaH)
9239*da0073e9SAndroid Build Coastguard Worker            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
9240*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, rank_aaH_hermitian)
9241*da0073e9SAndroid Build Coastguard Worker            aHa = torch.matmul(a.mH, a)
9242*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
9243*da0073e9SAndroid Build Coastguard Worker
9244*da0073e9SAndroid Build Coastguard Worker            # check against NumPy
9245*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
9246*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
9247*da0073e9SAndroid Build Coastguard Worker
9248*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
9249*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
9250*da0073e9SAndroid Build Coastguard Worker
9251*da0073e9SAndroid Build Coastguard Worker            # hermitian flag for NumPy was added in 1.14.0
9252*da0073e9SAndroid Build Coastguard Worker            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
9253*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(rank_aaH_hermitian,
9254*da0073e9SAndroid Build Coastguard Worker                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
9255*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(matrix_rank(aaH, 0.01, True),
9256*da0073e9SAndroid Build Coastguard Worker                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
9257*da0073e9SAndroid Build Coastguard Worker
9258*da0073e9SAndroid Build Coastguard Worker            # check out= variant
9259*da0073e9SAndroid Build Coastguard Worker            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
9260*da0073e9SAndroid Build Coastguard Worker            ans = matrix_rank(a, out=out)
9261*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
9262*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, rank_a)
9263*da0073e9SAndroid Build Coastguard Worker
9264*da0073e9SAndroid Build Coastguard Worker        shapes = (3, 13)
9265*da0073e9SAndroid Build Coastguard Worker        batches = ((), (0, ), (4, ), (3, 5, ))
9266*da0073e9SAndroid Build Coastguard Worker        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
9267*da0073e9SAndroid Build Coastguard Worker            # escape only when NotImplementedError of downstream function is raised
9268*da0073e9SAndroid Build Coastguard Worker            # TODO: remove this once the required function is implemented
9269*da0073e9SAndroid Build Coastguard Worker            try:
9270*da0073e9SAndroid Build Coastguard Worker                run_test(shape0, shape1, batch)
9271*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
9272*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
9273*da0073e9SAndroid Build Coastguard Worker                        NotImplementedError,
9274*da0073e9SAndroid Build Coastguard Worker                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
9275*da0073e9SAndroid Build Coastguard Worker                    raise e
9276*da0073e9SAndroid Build Coastguard Worker
9277*da0073e9SAndroid Build Coastguard Worker    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
9278*da0073e9SAndroid Build Coastguard Worker        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
9279*da0073e9SAndroid Build Coastguard Worker
9280*da0073e9SAndroid Build Coastguard Worker        def run_test_main(A, hermitian):
9281*da0073e9SAndroid Build Coastguard Worker            # Testing against definition for pseudo-inverses
9282*da0073e9SAndroid Build Coastguard Worker            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
9283*da0073e9SAndroid Build Coastguard Worker            np_A = A.cpu().numpy()
9284*da0073e9SAndroid Build Coastguard Worker            np_A_pinv = A_pinv.cpu().numpy()
9285*da0073e9SAndroid Build Coastguard Worker            if A.numel() > 0:
9286*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
9287*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
9288*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
9289*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
9290*da0073e9SAndroid Build Coastguard Worker            else:
9291*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
9292*da0073e9SAndroid Build Coastguard Worker
9293*da0073e9SAndroid Build Coastguard Worker            # Check out= variant
9294*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(A_pinv)
9295*da0073e9SAndroid Build Coastguard Worker            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
9296*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, out)
9297*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ans, A_pinv)
9298*da0073e9SAndroid Build Coastguard Worker
9299*da0073e9SAndroid Build Coastguard Worker        def run_test_numpy(A, hermitian):
9300*da0073e9SAndroid Build Coastguard Worker            # Check against NumPy output
9301*da0073e9SAndroid Build Coastguard Worker            # Test float rcond, and specific value for each matrix
9302*da0073e9SAndroid Build Coastguard Worker            rconds = [float(torch.rand(1)), ]
9303*da0073e9SAndroid Build Coastguard Worker            # Test different types of rcond tensor
9304*da0073e9SAndroid Build Coastguard Worker            for rcond_type in MPS_DTYPES:
9305*da0073e9SAndroid Build Coastguard Worker                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
9306*da0073e9SAndroid Build Coastguard Worker            # Test broadcasting of rcond
9307*da0073e9SAndroid Build Coastguard Worker            if A.ndim > 2:
9308*da0073e9SAndroid Build Coastguard Worker                rconds.append(torch.rand(A.shape[-3], device=device))
9309*da0073e9SAndroid Build Coastguard Worker            for rcond in rconds:
9310*da0073e9SAndroid Build Coastguard Worker                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
9311*da0073e9SAndroid Build Coastguard Worker                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
9312*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
9313*da0073e9SAndroid Build Coastguard Worker                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
9314*da0073e9SAndroid Build Coastguard Worker                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
9315*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected, atol=precision, rtol=precision)
9316*da0073e9SAndroid Build Coastguard Worker
9317*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
9318*da0073e9SAndroid Build Coastguard Worker                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
9319*da0073e9SAndroid Build Coastguard Worker                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
9320*da0073e9SAndroid Build Coastguard Worker                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
9321*da0073e9SAndroid Build Coastguard Worker            A = torch.randn(*sizes, dtype=dtype, device=device)
9322*da0073e9SAndroid Build Coastguard Worker            hermitian = False
9323*da0073e9SAndroid Build Coastguard Worker            run_test_main(A, hermitian)
9324*da0073e9SAndroid Build Coastguard Worker            run_test_numpy(A, hermitian)
9325*da0073e9SAndroid Build Coastguard Worker
9326*da0073e9SAndroid Build Coastguard Worker        # Check hermitian = True
9327*da0073e9SAndroid Build Coastguard Worker        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
9328*da0073e9SAndroid Build Coastguard Worker                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
9329*da0073e9SAndroid Build Coastguard Worker            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
9330*da0073e9SAndroid Build Coastguard Worker            hermitian = True
9331*da0073e9SAndroid Build Coastguard Worker            # escape only when NotImplementedError of downstream function is raised
9332*da0073e9SAndroid Build Coastguard Worker            # TODO: remove this once the required function is implemented
9333*da0073e9SAndroid Build Coastguard Worker            try:
9334*da0073e9SAndroid Build Coastguard Worker                run_test_main(A, hermitian)
9335*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
9336*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
9337*da0073e9SAndroid Build Coastguard Worker                        NotImplementedError,
9338*da0073e9SAndroid Build Coastguard Worker                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9339*da0073e9SAndroid Build Coastguard Worker                    raise e
9340*da0073e9SAndroid Build Coastguard Worker            try:
9341*da0073e9SAndroid Build Coastguard Worker                run_test_numpy(A, hermitian)
9342*da0073e9SAndroid Build Coastguard Worker            except NotImplementedError as e:
9343*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
9344*da0073e9SAndroid Build Coastguard Worker                        NotImplementedError,
9345*da0073e9SAndroid Build Coastguard Worker                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9346*da0073e9SAndroid Build Coastguard Worker                    raise e
9347*da0073e9SAndroid Build Coastguard Worker
9348*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [1, 32, 64])
9349*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [48, 64])
9350*da0073e9SAndroid Build Coastguard Worker    @parametrize("q_group", [32, 64, 128, 256])
9351*da0073e9SAndroid Build Coastguard Worker    @parametrize("num_groups", [1, 2])
9352*da0073e9SAndroid Build Coastguard Worker    def test__int4_mm(self, m, n, q_group, num_groups):
9353*da0073e9SAndroid Build Coastguard Worker        k = q_group * num_groups
9354*da0073e9SAndroid Build Coastguard Worker        inner_k_tiles = 2
9355*da0073e9SAndroid Build Coastguard Worker
9356*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
9357*da0073e9SAndroid Build Coastguard Worker        a_f32 = torch.rand((m, k), device="mps")
9358*da0073e9SAndroid Build Coastguard Worker        b_f32 = torch.rand((k, n), device="mps")
9359*da0073e9SAndroid Build Coastguard Worker
9360*da0073e9SAndroid Build Coastguard Worker        def convert_weight_to_int4pack(b):
9361*da0073e9SAndroid Build Coastguard Worker            b_int32, b_scales_and_zeros = _group_quantize_tensor(
9362*da0073e9SAndroid Build Coastguard Worker                b.to("cpu"), n_bit=4, q_group_size=q_group
9363*da0073e9SAndroid Build Coastguard Worker            )
9364*da0073e9SAndroid Build Coastguard Worker            b_int32 = b_int32.to("mps")
9365*da0073e9SAndroid Build Coastguard Worker            b_scales_and_zeros = b_scales_and_zeros.to("mps")
9366*da0073e9SAndroid Build Coastguard Worker            b_int4pack = torch._convert_weight_to_int4pack(
9367*da0073e9SAndroid Build Coastguard Worker                b_int32, inner_k_tiles
9368*da0073e9SAndroid Build Coastguard Worker            )
9369*da0073e9SAndroid Build Coastguard Worker
9370*da0073e9SAndroid Build Coastguard Worker            return b_int4pack, b_scales_and_zeros
9371*da0073e9SAndroid Build Coastguard Worker
9372*da0073e9SAndroid Build Coastguard Worker        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
9373*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int4pack_mm(
9374*da0073e9SAndroid Build Coastguard Worker                a, b_int4pack, q_group, b_scales_and_zeros
9375*da0073e9SAndroid Build Coastguard Worker            )
9376*da0073e9SAndroid Build Coastguard Worker
9377*da0073e9SAndroid Build Coastguard Worker        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)
9378*da0073e9SAndroid Build Coastguard Worker
9379*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
9380*da0073e9SAndroid Build Coastguard Worker            a = a_f32.to(dtype=dtype)
9381*da0073e9SAndroid Build Coastguard Worker            b = b_f32.to(dtype=dtype)
9382*da0073e9SAndroid Build Coastguard Worker            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
9383*da0073e9SAndroid Build Coastguard Worker            ref = torch.mm(a, b)
9384*da0073e9SAndroid Build Coastguard Worker            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
9385*da0073e9SAndroid Build Coastguard Worker
9386*da0073e9SAndroid Build Coastguard Worker            mean_err = ((res - ref).abs() / ref).mean()
9387*da0073e9SAndroid Build Coastguard Worker            self.assertLess(mean_err, 0.05)
9388*da0073e9SAndroid Build Coastguard Worker
9389*da0073e9SAndroid Build Coastguard Worker    @parametrize("m", [1, 32, 64])
9390*da0073e9SAndroid Build Coastguard Worker    @parametrize("k", [32, 64])
9391*da0073e9SAndroid Build Coastguard Worker    @parametrize("n", [32, 64])
9392*da0073e9SAndroid Build Coastguard Worker    def test__int8_mm(self, m, k, n):
9393*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1)
9394*da0073e9SAndroid Build Coastguard Worker        a_f32 = torch.rand((m, k), device="mps")
9395*da0073e9SAndroid Build Coastguard Worker        b_f32 = torch.rand((n, k), device="mps")
9396*da0073e9SAndroid Build Coastguard Worker
9397*da0073e9SAndroid Build Coastguard Worker        def convert_weight_to_int8pack(b):
9398*da0073e9SAndroid Build Coastguard Worker            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
9399*da0073e9SAndroid Build Coastguard Worker                b, -128, 127, torch.int8
9400*da0073e9SAndroid Build Coastguard Worker            )
9401*da0073e9SAndroid Build Coastguard Worker            return b_int8pack, b_scales
9402*da0073e9SAndroid Build Coastguard Worker
9403*da0073e9SAndroid Build Coastguard Worker        def weight_int8pack_mm(a, b_int8pack, b_scales):
9404*da0073e9SAndroid Build Coastguard Worker            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)
9405*da0073e9SAndroid Build Coastguard Worker
9406*da0073e9SAndroid Build Coastguard Worker        b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32)
9407*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
9408*da0073e9SAndroid Build Coastguard Worker            a = a_f32.to(dtype=dtype)
9409*da0073e9SAndroid Build Coastguard Worker            b = b_f32.to(dtype=dtype)
9410*da0073e9SAndroid Build Coastguard Worker            b_scales = b_scales_f32.to(dtype=dtype)
9411*da0073e9SAndroid Build Coastguard Worker            res = weight_int8pack_mm(a, b_int8pack, b_scales)
9412*da0073e9SAndroid Build Coastguard Worker            ref = torch.mm(a, b.transpose(0, 1))
9413*da0073e9SAndroid Build Coastguard Worker
9414*da0073e9SAndroid Build Coastguard Worker            mean_err = ((res - ref).abs() / ref).mean()
9415*da0073e9SAndroid Build Coastguard Worker            self.assertLess(mean_err, 0.05)
9416*da0073e9SAndroid Build Coastguard Worker
9417*da0073e9SAndroid Build Coastguard Worker
9418*da0073e9SAndroid Build Coastguard Workerclass TestSDPA(TestCaseMPS):
9419*da0073e9SAndroid Build Coastguard Worker    def _compare_tensors(self, y, ref):
9420*da0073e9SAndroid Build Coastguard Worker        denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype))
9421*da0073e9SAndroid Build Coastguard Worker        err = ((y - ref).abs() / denom).mean().item()
9422*da0073e9SAndroid Build Coastguard Worker        self.assertLess(err, 0.01)
9423*da0073e9SAndroid Build Coastguard Worker
9424*da0073e9SAndroid Build Coastguard Worker    def _test_sdpa_no_mask(
9425*da0073e9SAndroid Build Coastguard Worker        self,
9426*da0073e9SAndroid Build Coastguard Worker        is_causal: bool,
9427*da0073e9SAndroid Build Coastguard Worker        dtype: torch.dtype,
9428*da0073e9SAndroid Build Coastguard Worker        L: int = 1,
9429*da0073e9SAndroid Build Coastguard Worker        S: int = 72,
9430*da0073e9SAndroid Build Coastguard Worker        NH: int = 32,
9431*da0073e9SAndroid Build Coastguard Worker        HS: int = 128,
9432*da0073e9SAndroid Build Coastguard Worker        requires_grad: bool = False
9433*da0073e9SAndroid Build Coastguard Worker    ):
9434*da0073e9SAndroid Build Coastguard Worker
9435*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1729)
9436*da0073e9SAndroid Build Coastguard Worker        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
9437*da0073e9SAndroid Build Coastguard Worker            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
9438*da0073e9SAndroid Build Coastguard Worker            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9439*da0073e9SAndroid Build Coastguard Worker            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9440*da0073e9SAndroid Build Coastguard Worker            q_cpu = q.cpu().detach().cpu().requires_grad_(requires_grad)
9441*da0073e9SAndroid Build Coastguard Worker            k_cpu = k.cpu()
9442*da0073e9SAndroid Build Coastguard Worker            v_cpu = v.cpu()
9443*da0073e9SAndroid Build Coastguard Worker
9444*da0073e9SAndroid Build Coastguard Worker            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
9445*da0073e9SAndroid Build Coastguard Worker            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)
9446*da0073e9SAndroid Build Coastguard Worker
9447*da0073e9SAndroid Build Coastguard Worker            self._compare_tensors(y.cpu(), y_ref)
9448*da0073e9SAndroid Build Coastguard Worker
9449*da0073e9SAndroid Build Coastguard Worker            if requires_grad and torch.is_grad_enabled():
9450*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9451*da0073e9SAndroid Build Coastguard Worker                y_ref.sum().backward()
9452*da0073e9SAndroid Build Coastguard Worker
9453*da0073e9SAndroid Build Coastguard Worker                self._compare_tensors(q.grad.cpu(), q_cpu.grad)
9454*da0073e9SAndroid Build Coastguard Worker
9455*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_no_causal_fp32(self):
9456*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(False, torch.float32)
9457*da0073e9SAndroid Build Coastguard Worker
9458*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_no_causal_fp16(self):
9459*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(False, torch.float16)
9460*da0073e9SAndroid Build Coastguard Worker
9461*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_causal_fp32(self):
9462*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(True, torch.float32)
9463*da0073e9SAndroid Build Coastguard Worker
9464*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_causal_fp16(self):
9465*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(True, torch.float16)
9466*da0073e9SAndroid Build Coastguard Worker
9467*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_causal_fp16_L7(self):
9468*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(True, torch.float16, 7)
9469*da0073e9SAndroid Build Coastguard Worker
9470*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_causal_fp16_L7_S17(self):
9471*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(True, torch.float16, 7, 17)
9472*da0073e9SAndroid Build Coastguard Worker
9473*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
9474*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
9475*da0073e9SAndroid Build Coastguard Worker
9476*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_no_mask_no_causal_fp32_grad(self):
9477*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
9478*da0073e9SAndroid Build Coastguard Worker
9479*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
9480*da0073e9SAndroid Build Coastguard Worker            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
9481*da0073e9SAndroid Build Coastguard Worker
9482*da0073e9SAndroid Build Coastguard Worker    def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
9483*da0073e9SAndroid Build Coastguard Worker        torch.manual_seed(1729)
9484*da0073e9SAndroid Build Coastguard Worker        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
9485*da0073e9SAndroid Build Coastguard Worker        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
9486*da0073e9SAndroid Build Coastguard Worker            i = 42
9487*da0073e9SAndroid Build Coastguard Worker
9488*da0073e9SAndroid Build Coastguard Worker            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
9489*da0073e9SAndroid Build Coastguard Worker            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9490*da0073e9SAndroid Build Coastguard Worker            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9491*da0073e9SAndroid Build Coastguard Worker
9492*da0073e9SAndroid Build Coastguard Worker            input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
9493*da0073e9SAndroid Build Coastguard Worker            mask = causal_mask[None, None, input_pos]
9494*da0073e9SAndroid Build Coastguard Worker
9495*da0073e9SAndroid Build Coastguard Worker            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
9496*da0073e9SAndroid Build Coastguard Worker            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
9497*da0073e9SAndroid Build Coastguard Worker
9498*da0073e9SAndroid Build Coastguard Worker            self._compare_tensors(y.cpu(), y_ref)
9499*da0073e9SAndroid Build Coastguard Worker
9500*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_mask_fp32(self):
9501*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_mask(torch.float32)
9502*da0073e9SAndroid Build Coastguard Worker
9503*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_mask_fp16(self):
9504*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_mask(torch.float16)
9505*da0073e9SAndroid Build Coastguard Worker
9506*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_mask_fp16_L6(self):
9507*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_mask(torch.float16, 6)
9508*da0073e9SAndroid Build Coastguard Worker
9509*da0073e9SAndroid Build Coastguard Worker    def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
9510*da0073e9SAndroid Build Coastguard Worker        self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
9511*da0073e9SAndroid Build Coastguard Worker
9512*da0073e9SAndroid Build Coastguard Worker
class TestGatherScatter(TestCaseMPS):
    """Gather/scatter behaviour of MPS tensors compared against the CPU."""

    def test_slicing_with_step(self):
        # Strided-slice assignment; regression test for
        # https://github.com/pytorch/pytorch/issues/78886
        mps_tensor = torch.zeros(10, dtype=torch.float32, device="mps")
        mps_tensor[::2] = 1.0

        cpu_tensor = torch.zeros(10, dtype=torch.float32, device="cpu")
        cpu_tensor[::2] = 1.0

        self.assertEqual(cpu_tensor, mps_tensor)

    def test_cast_gather_scatter(self):
        # Repeated random trials: casting a gathered uint8 tensor through
        # long and float, then scaling in place, must agree with the CPU.
        for _ in range(50):
            data = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                s_mps = torch.tensor(data, dtype=torch.uint8, device="mps").unsqueeze(0)
                s_cpu = torch.tensor(data, dtype=torch.uint8, device="cpu").unsqueeze(0)
                s_mps = s_mps.long()
                s_cpu = s_cpu.long()
                self.assertEqual(s_mps.cpu(), s_cpu)

                s_mps = s_mps.float()
                s_cpu = s_cpu.float()
                self.assertEqual(s_mps.cpu(), s_cpu)

                s_mps /= 255
                s_cpu /= 255
                self.assertEqual(s_mps.cpu(), s_cpu)

    def test_slicing_replace_column(self):
        # https://github.com/pytorch/pytorch/issues/78074
        def check(rows):
            cpu = torch.tensor(rows)
            mps = cpu.to('mps')

            cpu[:, 0] = 7
            mps[:, 0] = 7

            self.assertEqual(cpu, mps)

        check([[1, 2, 3], [4, 5, 6]])
        check([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        check([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

    def test_inplace_scatter(self):
        # https://github.com/pytorch/pytorch/issues/79672
        a_mps = torch.ones((2, 2),).to(torch.device("mps"))
        b_mps = torch.ones((2, 2),).to(torch.device("mps"))

        a_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
        b_cpu = torch.ones((2, 2),).to(torch.device("cpu"))

        # In-place scatter through a column slice ...
        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        # ... and the equivalent out-of-place form assigned back.
        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)
9573*da0073e9SAndroid Build Coastguard Worker
9574*da0073e9SAndroid Build Coastguard Worker# These tests were taken from test/test_view_ops.py
# They are a subset of those tests, as currently only this subset is working.
9576*da0073e9SAndroid Build Coastguard Worker# This whole `class` will be removed when we add generic device testing. There
9577*da0073e9SAndroid Build Coastguard Worker# are no additional tests added apart from what is part of test_view_ops.py
9578*da0073e9SAndroid Build Coastguard Workerclass TestViewOpsMPS(TestCaseMPS):
9579*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
9580*da0073e9SAndroid Build Coastguard Worker
    def test_permute_slicing(self):
        """Regression test: printing / zeroing a permuted-and-scaled MPS tensor
        used to crash; see issue #94190 and the fix in PR#94259."""
        # test the fix for crash reported in
        # https://github.com/pytorch/pytorch/issues/94190
        cpu_x = (torch.randn([3, 2, 2]).float())
        mps_x = cpu_x.detach().clone().to('mps')
        cpu_out = cpu_x.permute((2, 0, 1)) * 2.0
        mps_out = mps_x.permute((2, 0, 1)) * 2.0
        # this print caused a crash prior to fix PR#94259
        print(torch.zeros_like(mps_out))
        # test the fix for fill_scalar_mps() mentioned in issue #94190
        self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
        self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))
9593*da0073e9SAndroid Build Coastguard Worker
9594*da0073e9SAndroid Build Coastguard Worker    def is_view_of(self, base, other):
9595*da0073e9SAndroid Build Coastguard Worker        if (not other._is_view() or
9596*da0073e9SAndroid Build Coastguard Worker                other is base or
9597*da0073e9SAndroid Build Coastguard Worker                other._base is not base or
9598*da0073e9SAndroid Build Coastguard Worker                base.device != other.device):
9599*da0073e9SAndroid Build Coastguard Worker            return False
9600*da0073e9SAndroid Build Coastguard Worker        # Note: only validates storage on native device types
9601*da0073e9SAndroid Build Coastguard Worker        # because some accelerators, like XLA, do not expose storage
9602*da0073e9SAndroid Build Coastguard Worker        if base.device.type == 'mps':
9603*da0073e9SAndroid Build Coastguard Worker            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
9604*da0073e9SAndroid Build Coastguard Worker                return False
9605*da0073e9SAndroid Build Coastguard Worker
9606*da0073e9SAndroid Build Coastguard Worker        return True
9607*da0073e9SAndroid Build Coastguard Worker
9608*da0073e9SAndroid Build Coastguard Worker    # Returns true if v1 and v2 are views of the same base
9609*da0073e9SAndroid Build Coastguard Worker    def is_view_of_same_base(self, v1, v2):
9610*da0073e9SAndroid Build Coastguard Worker        if (not v1._is_view() or v1 is v2):
9611*da0073e9SAndroid Build Coastguard Worker            return False
9612*da0073e9SAndroid Build Coastguard Worker        return self.is_view_of(v1._base, v2)
9613*da0073e9SAndroid Build Coastguard Worker
9614*da0073e9SAndroid Build Coastguard Worker    # Performs transpose if contiguous=True, else returns the input tensor as is
9615*da0073e9SAndroid Build Coastguard Worker    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
9616*da0073e9SAndroid Build Coastguard Worker        if contiguous:
9617*da0073e9SAndroid Build Coastguard Worker            return x
9618*da0073e9SAndroid Build Coastguard Worker        else:
9619*da0073e9SAndroid Build Coastguard Worker            return x.transpose(dim0, dim1)
9620*da0073e9SAndroid Build Coastguard Worker
9621*da0073e9SAndroid Build Coastguard Worker    def test_diagonal_view(self, device="mps"):
9622*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 5), device=device)
9623*da0073e9SAndroid Build Coastguard Worker        v = torch.diagonal(t)
9624*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9625*da0073e9SAndroid Build Coastguard Worker
9626*da0073e9SAndroid Build Coastguard Worker        v[0] = 0
9627*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 0], v[0])
9628*da0073e9SAndroid Build Coastguard Worker
9629*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((3, 3, 3), device="mps")
9630*da0073e9SAndroid Build Coastguard Worker        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
9631*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9632*da0073e9SAndroid Build Coastguard Worker
9633*da0073e9SAndroid Build Coastguard Worker        v[0, 0] = 0
9634*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 0, 1], v[0, 0])
9635*da0073e9SAndroid Build Coastguard Worker
9636*da0073e9SAndroid Build Coastguard Worker    def test_select_view(self, device="mps") -> None:
9637*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 5), device=device)
9638*da0073e9SAndroid Build Coastguard Worker        v = t.select(0, 2)
9639*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9640*da0073e9SAndroid Build Coastguard Worker
9641*da0073e9SAndroid Build Coastguard Worker        v[0] = 0
9642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[2, 0], v[0])
9643*da0073e9SAndroid Build Coastguard Worker
9644*da0073e9SAndroid Build Coastguard Worker    def test_unbind_view(self, device="mps") -> None:
9645*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((5, 5), device=device)
9646*da0073e9SAndroid Build Coastguard Worker        tup = torch.unbind(t)
9647*da0073e9SAndroid Build Coastguard Worker
9648*da0073e9SAndroid Build Coastguard Worker        for idx, v in enumerate(tup):
9649*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, v))
9650*da0073e9SAndroid Build Coastguard Worker
9651*da0073e9SAndroid Build Coastguard Worker            v[0] = idx + 1
9652*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[idx, 0], v[0])
9653*da0073e9SAndroid Build Coastguard Worker
9654*da0073e9SAndroid Build Coastguard Worker    def test_expand_view(self, device="mps") -> None:
9655*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 1), device=device)
9656*da0073e9SAndroid Build Coastguard Worker        v = t.expand(5, 5)
9657*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9658*da0073e9SAndroid Build Coastguard Worker
9659*da0073e9SAndroid Build Coastguard Worker        v[2, 2] = 0
9660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[2, 0], v[2, 2])
9661*da0073e9SAndroid Build Coastguard Worker
9662*da0073e9SAndroid Build Coastguard Worker    def test_expand_as_view(self, device="mps"):
9663*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 1), device=device)
9664*da0073e9SAndroid Build Coastguard Worker        e = torch.empty((5, 5), device=device)
9665*da0073e9SAndroid Build Coastguard Worker        v = t.expand_as(e)
9666*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9667*da0073e9SAndroid Build Coastguard Worker
9668*da0073e9SAndroid Build Coastguard Worker        v[2, 2] = 0
9669*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[2, 0], v[2, 2])
9670*da0073e9SAndroid Build Coastguard Worker
9671*da0073e9SAndroid Build Coastguard Worker    def test_narrow_view(self, device="mps"):
9672*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 5), device=device)
9673*da0073e9SAndroid Build Coastguard Worker        v = torch.narrow(t, 1, 2, 2)
9674*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9675*da0073e9SAndroid Build Coastguard Worker
9676*da0073e9SAndroid Build Coastguard Worker        v[0, 0] = 0
9677*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 2], v[0, 0])
9678*da0073e9SAndroid Build Coastguard Worker
9679*da0073e9SAndroid Build Coastguard Worker    def test_permute_view(self, device="mps") -> None:
9680*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 5), device=device)
9681*da0073e9SAndroid Build Coastguard Worker        v = t.permute(1, 0)
9682*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9683*da0073e9SAndroid Build Coastguard Worker
9684*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9686*da0073e9SAndroid Build Coastguard Worker
9687*da0073e9SAndroid Build Coastguard Worker    def test_transpose_view(self, device="mps"):
9688*da0073e9SAndroid Build Coastguard Worker        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
9689*da0073e9SAndroid Build Coastguard Worker            t = torch.ones((5, 5), device=device)
9690*da0073e9SAndroid Build Coastguard Worker            v = fn(t, 0, 1)
9691*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, v))
9692*da0073e9SAndroid Build Coastguard Worker
9693*da0073e9SAndroid Build Coastguard Worker            v[0, 1] = 0
9694*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[1, 0], v[0, 1])
9695*da0073e9SAndroid Build Coastguard Worker
9696*da0073e9SAndroid Build Coastguard Worker    def test_transpose_inplace_view(self, device="mps"):
9697*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9698*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9699*da0073e9SAndroid Build Coastguard Worker        v = v.swapdims_(0, 1)
9700*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9701*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9702*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9703*da0073e9SAndroid Build Coastguard Worker
9704*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9705*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9706*da0073e9SAndroid Build Coastguard Worker        v = v.swapaxes_(0, 1)
9707*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9708*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9710*da0073e9SAndroid Build Coastguard Worker
9711*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9712*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9713*da0073e9SAndroid Build Coastguard Worker        v = v.transpose_(0, 1)
9714*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9715*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9717*da0073e9SAndroid Build Coastguard Worker
9718*da0073e9SAndroid Build Coastguard Worker    def test_t_view(self, device="mps"):
9719*da0073e9SAndroid Build Coastguard Worker        t = torch.ones((5, 5), device=device)
9720*da0073e9SAndroid Build Coastguard Worker        v = t.t()
9721*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9722*da0073e9SAndroid Build Coastguard Worker
9723*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9724*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9725*da0073e9SAndroid Build Coastguard Worker
9726*da0073e9SAndroid Build Coastguard Worker    def test_inplace_view_add(self):
9727*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/96153
9728*da0073e9SAndroid Build Coastguard Worker        t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
9729*da0073e9SAndroid Build Coastguard Worker        t_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3)
9730*da0073e9SAndroid Build Coastguard Worker        t_mps = t_mps + 1
9731*da0073e9SAndroid Build Coastguard Worker        t_cpu = t_cpu + 1
9732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_mps, t_cpu)
9733*da0073e9SAndroid Build Coastguard Worker
9734*da0073e9SAndroid Build Coastguard Worker    def test_t_inplace_view(self, device="mps"):
9735*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9736*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9737*da0073e9SAndroid Build Coastguard Worker        v = v.t_()
9738*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9739*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 0], v[0, 1])
9741*da0073e9SAndroid Build Coastguard Worker
9742*da0073e9SAndroid Build Coastguard Worker    def test_T_view(self, device="mps"):
9743*da0073e9SAndroid Build Coastguard Worker        for op in ("T", "H", "mT", "mH"):
9744*da0073e9SAndroid Build Coastguard Worker            t = torch.ones((5, 5), device=device)
9745*da0073e9SAndroid Build Coastguard Worker            v = getattr(t, op)
9746*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, v))
9747*da0073e9SAndroid Build Coastguard Worker
9748*da0073e9SAndroid Build Coastguard Worker            v[0, 1] = 0
9749*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[1, 0], v[0, 1])
9750*da0073e9SAndroid Build Coastguard Worker
9751*da0073e9SAndroid Build Coastguard Worker    def test_unfold_view(self, device="mps"):
9752*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(10, device=device)
9753*da0073e9SAndroid Build Coastguard Worker        v = t.unfold(0, 3, 2)
9754*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9755*da0073e9SAndroid Build Coastguard Worker
9756*da0073e9SAndroid Build Coastguard Worker        v[1, 0] = 0
9757*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[2], v[1, 0])
9758*da0073e9SAndroid Build Coastguard Worker
9759*da0073e9SAndroid Build Coastguard Worker    def test_squeeze_view(self, device="mps"):
9760*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 1, 5, device=device)
9761*da0073e9SAndroid Build Coastguard Worker        v = torch.squeeze(t)
9762*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9763*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9764*da0073e9SAndroid Build Coastguard Worker        self.assertIs(t, v._base)
9765*da0073e9SAndroid Build Coastguard Worker
9766*da0073e9SAndroid Build Coastguard Worker    def test_squeeze_inplace_view(self, device="mps"):
9767*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9768*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9769*da0073e9SAndroid Build Coastguard Worker        v = v.squeeze_()
9770*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9771*da0073e9SAndroid Build Coastguard Worker        v[0, 1] = 0
9772*da0073e9SAndroid Build Coastguard Worker        self.assertIs(t, v._base)
9773*da0073e9SAndroid Build Coastguard Worker
9774*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze_view(self, device="mps"):
9775*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9776*da0073e9SAndroid Build Coastguard Worker        v = torch.unsqueeze(t, 1)
9777*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9778*da0073e9SAndroid Build Coastguard Worker
9779*da0073e9SAndroid Build Coastguard Worker        v[0, 0, 1] = 0
9780*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 1], v[0, 0, 1])
9781*da0073e9SAndroid Build Coastguard Worker
9782*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze_inplace_view(self, device="mps"):
9783*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9784*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9785*da0073e9SAndroid Build Coastguard Worker        v = v.unsqueeze_(1)
9786*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9787*da0073e9SAndroid Build Coastguard Worker        v[0, 0, 1] = 0
9788*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 1], v[0, 0, 1])
9789*da0073e9SAndroid Build Coastguard Worker
9790*da0073e9SAndroid Build Coastguard Worker    def test_as_strided_view(self, device="mps"):
9791*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9792*da0073e9SAndroid Build Coastguard Worker        v = torch.as_strided(t, (25,), (1,))
9793*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9794*da0073e9SAndroid Build Coastguard Worker
9795*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9796*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9797*da0073e9SAndroid Build Coastguard Worker
9798*da0073e9SAndroid Build Coastguard Worker    def test_as_strided_inplace_view(self, device="mps"):
9799*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9800*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(t)
9801*da0073e9SAndroid Build Coastguard Worker        v = v.as_strided_((25,), (1,))
9802*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9803*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9804*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9805*da0073e9SAndroid Build Coastguard Worker
9806*da0073e9SAndroid Build Coastguard Worker    def test_view_view(self, device="mps"):
9807*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9808*da0073e9SAndroid Build Coastguard Worker        v = t.view(25)
9809*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9810*da0073e9SAndroid Build Coastguard Worker
9811*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9813*da0073e9SAndroid Build Coastguard Worker
9814*da0073e9SAndroid Build Coastguard Worker    def test_view_as_view(self, device="mps"):
9815*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9816*da0073e9SAndroid Build Coastguard Worker        e = torch.empty((25,))
9817*da0073e9SAndroid Build Coastguard Worker        v = t.view_as(e)
9818*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9819*da0073e9SAndroid Build Coastguard Worker
9820*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9821*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9822*da0073e9SAndroid Build Coastguard Worker
9823*da0073e9SAndroid Build Coastguard Worker    def test_contiguous_self(self, device="mps"):
9824*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9825*da0073e9SAndroid Build Coastguard Worker        s = t.contiguous()
9826*da0073e9SAndroid Build Coastguard Worker        self.assertIs(s, t)
9827*da0073e9SAndroid Build Coastguard Worker
9828*da0073e9SAndroid Build Coastguard Worker    def test_contiguous_nonview(self, device="mps"):
9829*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9830*da0073e9SAndroid Build Coastguard Worker        nv = t.t().contiguous()
9831*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(self.is_view_of(t, nv))
9832*da0073e9SAndroid Build Coastguard Worker
9833*da0073e9SAndroid Build Coastguard Worker        nv[0, 0] = 0
9834*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t[0, 0], nv[0, 0])
9835*da0073e9SAndroid Build Coastguard Worker
9836*da0073e9SAndroid Build Coastguard Worker    def test_reshape_view(self, device="mps"):
9837*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9838*da0073e9SAndroid Build Coastguard Worker        v = torch.reshape(t, (25,))
9839*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9840*da0073e9SAndroid Build Coastguard Worker
9841*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9843*da0073e9SAndroid Build Coastguard Worker
9844*da0073e9SAndroid Build Coastguard Worker    def test_reshape_as_view(self, device="mps"):
9845*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9846*da0073e9SAndroid Build Coastguard Worker        e = torch.empty((25,), device=device)
9847*da0073e9SAndroid Build Coastguard Worker        v = t.reshape_as(e)
9848*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9849*da0073e9SAndroid Build Coastguard Worker
9850*da0073e9SAndroid Build Coastguard Worker        v[6] = 0
9851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[1, 1], v[6])
9852*da0073e9SAndroid Build Coastguard Worker
9853*da0073e9SAndroid Build Coastguard Worker    def test_reshape_nonview(self, device="mps"):
9854*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9855*da0073e9SAndroid Build Coastguard Worker        nv = torch.reshape(t.t(), (25,))
9856*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(self.is_view_of(t, nv))
9857*da0073e9SAndroid Build Coastguard Worker
9858*da0073e9SAndroid Build Coastguard Worker        nv[6] = 0
9859*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(t[1, 1], nv[6])
9860*da0073e9SAndroid Build Coastguard Worker
    def test_flatten_view(self, device="mps"):
        """flatten() should return a view (of the same base) when the flattened
        dims are laid out compatibly in memory, and writes must propagate."""
        def test_writes_propagate(t, v):
            # Write through the view at the all-zeros index and check that the
            # corresponding element of the base tensor changed too.
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        self.assertTrue(self.is_view_of(t, v))
        test_writes_propagate(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of(t, v))

        # After transposing dims 2 and 3, flattening dims 0-1 is still
        # expected to produce a view of the same base.
        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of_same_base(t, v))

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        # (note the last stride is 0, so elements alias along that dim)
        t = torch.ones(720, device=device) \
            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        test_writes_propagate(t, v1)
        self.assertTrue(self.is_view_of_same_base(t, v1))
        test_writes_propagate(t, v2)
        self.assertTrue(self.is_view_of_same_base(t, v2))
        test_writes_propagate(t, v3)
        self.assertTrue(self.is_view_of_same_base(t, v3))
9897*da0073e9SAndroid Build Coastguard Worker
9898*da0073e9SAndroid Build Coastguard Worker    def test_flatten_nonview(self, device="mps"):
9899*da0073e9SAndroid Build Coastguard Worker        def assert_is_nonview(t, nv):
9900*da0073e9SAndroid Build Coastguard Worker            idx_t = (0,) * t.ndim
9901*da0073e9SAndroid Build Coastguard Worker            idx_nv = (0,) * nv.ndim
9902*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(nv._is_view())
9903*da0073e9SAndroid Build Coastguard Worker            nv[idx_nv] = 0
9904*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(t[idx_t], nv[idx_nv])
9905*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
9906*da0073e9SAndroid Build Coastguard Worker        nv = t.flatten(1, 3)
9907*da0073e9SAndroid Build Coastguard Worker        assert_is_nonview(t, nv)
9908*da0073e9SAndroid Build Coastguard Worker
9909*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(2, 2, device=device).T
9910*da0073e9SAndroid Build Coastguard Worker        nv = t.flatten()
9911*da0073e9SAndroid Build Coastguard Worker        assert_is_nonview(t, nv)
9912*da0073e9SAndroid Build Coastguard Worker
9913*da0073e9SAndroid Build Coastguard Worker        # flatten returns the original object if start_dim=end_dim
9914*da0073e9SAndroid Build Coastguard Worker        t = t = torch.ones(2, 2, device=device)
9915*da0073e9SAndroid Build Coastguard Worker        nv = t.flatten(1, 1)
9916*da0073e9SAndroid Build Coastguard Worker        self.assertIs(t, nv)
9917*da0073e9SAndroid Build Coastguard Worker
9918*da0073e9SAndroid Build Coastguard Worker    def test_basic_indexing_slice_view(self, device="mps"):
9919*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9920*da0073e9SAndroid Build Coastguard Worker        v = t[:2, :3]
9921*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9922*da0073e9SAndroid Build Coastguard Worker
9923*da0073e9SAndroid Build Coastguard Worker        v[0, 0] = 0
9924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 0], v[0, 0])
9925*da0073e9SAndroid Build Coastguard Worker
9926*da0073e9SAndroid Build Coastguard Worker    def test_basic_indexing_ellipses_view(self, device="mps"):
9927*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9928*da0073e9SAndroid Build Coastguard Worker        v = t[..., :2]
9929*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9930*da0073e9SAndroid Build Coastguard Worker
9931*da0073e9SAndroid Build Coastguard Worker        v[0, 0] = 0
9932*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 0], v[0, 0])
9933*da0073e9SAndroid Build Coastguard Worker
9934*da0073e9SAndroid Build Coastguard Worker    def test_basic_indexing_newaxis_view(self, device="mps"):
9935*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(5, 5, device=device)
9936*da0073e9SAndroid Build Coastguard Worker        v = t[None, :2, 3]
9937*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(t, v))
9938*da0073e9SAndroid Build Coastguard Worker
9939*da0073e9SAndroid Build Coastguard Worker        v[0, 0] = 0
9940*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t[0, 3], v[0, 0])
9941*da0073e9SAndroid Build Coastguard Worker
9942*da0073e9SAndroid Build Coastguard Worker    def test_chunk_view(self, device="mps"):
9943*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(3, 3, device=device)
9944*da0073e9SAndroid Build Coastguard Worker        l = torch.chunk(t, 3)
9945*da0073e9SAndroid Build Coastguard Worker
9946*da0073e9SAndroid Build Coastguard Worker        for idx, v in enumerate(l):
9947*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, v))
9948*da0073e9SAndroid Build Coastguard Worker
9949*da0073e9SAndroid Build Coastguard Worker            v[0, 0] = idx + 1
9950*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[idx, 0], v[0, 0])
9951*da0073e9SAndroid Build Coastguard Worker
9952*da0073e9SAndroid Build Coastguard Worker    def test_split_view(self, device="mps"):
9953*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(3, 3, device=device)
9954*da0073e9SAndroid Build Coastguard Worker        l = torch.split(t, [1, 1, 1])
9955*da0073e9SAndroid Build Coastguard Worker
9956*da0073e9SAndroid Build Coastguard Worker        for idx, v in enumerate(l):
9957*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, v))
9958*da0073e9SAndroid Build Coastguard Worker
9959*da0073e9SAndroid Build Coastguard Worker            v[0, 0] = idx + 1
9960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t[idx, 0], v[0, 0])
9961*da0073e9SAndroid Build Coastguard Worker
9962*da0073e9SAndroid Build Coastguard Worker    def test_movedim_view(self, device="mps"):
9963*da0073e9SAndroid Build Coastguard Worker        def run_test(device, op):
9964*da0073e9SAndroid Build Coastguard Worker            t = torch.zeros(3, 3, device=device)
9965*da0073e9SAndroid Build Coastguard Worker            out = op(t)
9966*da0073e9SAndroid Build Coastguard Worker
9967*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(self.is_view_of(t, out))
9968*da0073e9SAndroid Build Coastguard Worker
9969*da0073e9SAndroid Build Coastguard Worker            # Randomly change values in output
9970*da0073e9SAndroid Build Coastguard Worker            # and verify that original is changed
9971*da0073e9SAndroid Build Coastguard Worker            # as well.
9972*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
9973*da0073e9SAndroid Build Coastguard Worker                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
9974*da0073e9SAndroid Build Coastguard Worker                out[idx_1, idx_2] = random.random()
9975*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
9976*da0073e9SAndroid Build Coastguard Worker
9977*da0073e9SAndroid Build Coastguard Worker        for fn in [torch.movedim, torch.moveaxis]:
9978*da0073e9SAndroid Build Coastguard Worker            op = partial(fn, source=(0, 1), destination=(1, 0))
9979*da0073e9SAndroid Build Coastguard Worker            run_test(device, op)
9980*da0073e9SAndroid Build Coastguard Worker
9981*da0073e9SAndroid Build Coastguard Worker            op = partial(fn, source=0, destination=1)
9982*da0073e9SAndroid Build Coastguard Worker            run_test(device, op)
9983*da0073e9SAndroid Build Coastguard Worker
9984*da0073e9SAndroid Build Coastguard Worker    # Testing that the generated view_copy kernel and its derivative are implemented correctly
9985*da0073e9SAndroid Build Coastguard Worker    def test_view_copy(self, device="mps"):
9986*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, device=device, requires_grad=True)
9987*da0073e9SAndroid Build Coastguard Worker        a_ref = a.clone().detach().requires_grad_()
9988*da0073e9SAndroid Build Coastguard Worker        a_view = a_ref.view(2, 2)
9989*da0073e9SAndroid Build Coastguard Worker        a_view_copy = torch.view_copy(a, (2, 2))
9990*da0073e9SAndroid Build Coastguard Worker
9991*da0073e9SAndroid Build Coastguard Worker        # view_copy ops don't preserve view relationship
9992*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(self.is_view_of(a_ref, a_view))
9993*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(self.is_view_of(a, a_view_copy))
9994*da0073e9SAndroid Build Coastguard Worker
9995*da0073e9SAndroid Build Coastguard Worker        a_view_copy.sum().backward()
9996*da0073e9SAndroid Build Coastguard Worker        a_view.sum().backward()
9997*da0073e9SAndroid Build Coastguard Worker
9998*da0073e9SAndroid Build Coastguard Worker        # forward and backward give the same shape + result
9999*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_view_copy, a_view)
10000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, a_ref.grad)
10001*da0073e9SAndroid Build Coastguard Worker
10002*da0073e9SAndroid Build Coastguard Worker    def test_view_copy_out(self, device="mps"):
10003*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, device=device)
10004*da0073e9SAndroid Build Coastguard Worker        out = torch.empty(2, device=device)
10005*da0073e9SAndroid Build Coastguard Worker
10006*da0073e9SAndroid Build Coastguard Worker        torch.diagonal_copy(a, out=out)
10007*da0073e9SAndroid Build Coastguard Worker        expected = torch.diagonal_copy(a)
10008*da0073e9SAndroid Build Coastguard Worker
10009*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected, out)
10010*da0073e9SAndroid Build Coastguard Worker
10011*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, device=device)
10012*da0073e9SAndroid Build Coastguard Worker        out1 = torch.empty(2, device=device)
10013*da0073e9SAndroid Build Coastguard Worker        out2 = torch.empty(2, device=device)
10014*da0073e9SAndroid Build Coastguard Worker
10015*da0073e9SAndroid Build Coastguard Worker        torch.split_copy(a, 2, out=(out1, out2))
10016*da0073e9SAndroid Build Coastguard Worker        expected1, expected2 = torch.split_copy(a, 2)
10017*da0073e9SAndroid Build Coastguard Worker
10018*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected1, out1)
10019*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected2, out2)
10020*da0073e9SAndroid Build Coastguard Worker
10021*da0073e9SAndroid Build Coastguard Worker    def test_detached_view_copy(self, device="mps"):
10022*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/86052
10023*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(2)
10024*da0073e9SAndroid Build Coastguard Worker        # .detach() makes y not a view, but contig tensor
10025*da0073e9SAndroid Build Coastguard Worker        # with non-zero offset
10026*da0073e9SAndroid Build Coastguard Worker        y = x[1].detach()
10027*da0073e9SAndroid Build Coastguard Worker        z = y.to(device)
10028*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, z.cpu())
10029*da0073e9SAndroid Build Coastguard Worker
10030*da0073e9SAndroid Build Coastguard Worker    def test_empty_reshape(self, device="mps"):
10031*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(0, 6, device=device)
10032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
10033*da0073e9SAndroid Build Coastguard Worker        # should be viewable -- i.e. data_ptr is the same.
10034*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())
10035*da0073e9SAndroid Build Coastguard Worker
10036*da0073e9SAndroid Build Coastguard Worker        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
10037*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
10038*da0073e9SAndroid Build Coastguard Worker
10039*da0073e9SAndroid Build Coastguard Worker    def test_expand(self, device="mps"):
10040*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(1, 8, 1, device=device)
10041*da0073e9SAndroid Build Coastguard Worker        tensor2 = torch.rand(5, device=device)
10042*da0073e9SAndroid Build Coastguard Worker        template = torch.rand(4, 8, 5, device=device)
10043*da0073e9SAndroid Build Coastguard Worker        target = template.size()
10044*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.expand_as(template).size(), target)
10045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
10046*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.expand(target).size(), target)
10047*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor2.expand_as(template).size(), target)
10048*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
10049*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor2.expand(target).size(), target)
10050*da0073e9SAndroid Build Coastguard Worker
10051*da0073e9SAndroid Build Coastguard Worker        # test double expand
10052*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))
10053*da0073e9SAndroid Build Coastguard Worker
10054*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous
10055*da0073e9SAndroid Build Coastguard Worker        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
10056*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(noncontig.is_contiguous())
10057*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))
10058*da0073e9SAndroid Build Coastguard Worker
10059*da0073e9SAndroid Build Coastguard Worker        # make sure it's compatible with unsqueeze
10060*da0073e9SAndroid Build Coastguard Worker        expanded = tensor2.expand(1, 1, 5)
10061*da0073e9SAndroid Build Coastguard Worker        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
10062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expanded, unsqueezed)
10063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expanded.stride(), unsqueezed.stride())
10064*da0073e9SAndroid Build Coastguard Worker
10065*da0073e9SAndroid Build Coastguard Worker        # test -1 as target size
10066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
10067*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))
10068*da0073e9SAndroid Build Coastguard Worker
10069*da0073e9SAndroid Build Coastguard Worker        # test expanding empty to empty
10070*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
10071*da0073e9SAndroid Build Coastguard Worker
10072*da0073e9SAndroid Build Coastguard Worker    def test_view_empty(self, device="mps"):
10073*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(0, 6, device=device)
10074*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)
10075*da0073e9SAndroid Build Coastguard Worker
10076*da0073e9SAndroid Build Coastguard Worker    def test_reshape(self, device="mps"):
10077*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, device=device)
10078*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
10079*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
10080*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
10081*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
10082*da0073e9SAndroid Build Coastguard Worker
10083*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
10084*da0073e9SAndroid Build Coastguard Worker        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
10085*da0073e9SAndroid Build Coastguard Worker        if device != "meta":
10086*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
10087*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
10088*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
10089*da0073e9SAndroid Build Coastguard Worker
10090*da0073e9SAndroid Build Coastguard Worker        s = torch.randn((), device=device)
10091*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
10092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.reshape(-1).shape, (1,))
10093*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: s.reshape(2))
10094*da0073e9SAndroid Build Coastguard Worker
10095*da0073e9SAndroid Build Coastguard Worker        empty = torch.tensor([], device=device)
10096*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty, empty.reshape(-1))
10097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty, empty.reshape([0]))
10098*da0073e9SAndroid Build Coastguard Worker        # TODO: fix these once we have multi-dimensional empty tensors
10099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
10100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
10101*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: empty.reshape(1))
10102*da0073e9SAndroid Build Coastguard Worker
10103*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, device=device)
10104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
10105*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
10106*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))
10107*da0073e9SAndroid Build Coastguard Worker
10108*da0073e9SAndroid Build Coastguard Worker    def test_narrow(self, device="mps"):
10109*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
10110*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
10111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
10112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
10113*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
10114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
10115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
10116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
10117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))
10118*da0073e9SAndroid Build Coastguard Worker
10119*da0073e9SAndroid Build Coastguard Worker    def test_narrow_tensor(self, device="mps"):
10120*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
10121*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
10122*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
10123*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor(0.), 1)
10124*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
10125*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor([0]), 1)
10126*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
10127*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor([0, 1]), 1)
10128*da0073e9SAndroid Build Coastguard Worker
10129*da0073e9SAndroid Build Coastguard Worker    def test_t(self, device="mps"):
10130*da0073e9SAndroid Build Coastguard Worker        # Test 0D tensors
10131*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(())
10132*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
10133*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
10134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
10135*da0073e9SAndroid Build Coastguard Worker
10136*da0073e9SAndroid Build Coastguard Worker        # Test 1D tensors
10137*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4)
10138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
10139*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
10140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
10141*da0073e9SAndroid Build Coastguard Worker
10142*da0073e9SAndroid Build Coastguard Worker        # Test 2D tensors
10143*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
10144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.t(), x.transpose(0, 1))
10145*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
10146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.t(), x.transpose(0, 1))
10147*da0073e9SAndroid Build Coastguard Worker
10148*da0073e9SAndroid Build Coastguard Worker        # Test 3D tensor
10149*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2, 2))
10150*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
10151*da0073e9SAndroid Build Coastguard Worker            x.t()
10152*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
10153*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
10154*da0073e9SAndroid Build Coastguard Worker            x.t()
10155*da0073e9SAndroid Build Coastguard Worker
10156*da0073e9SAndroid Build Coastguard Worker    def test_split(self, device="mps"):
10157*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(7, 4)
10158*da0073e9SAndroid Build Coastguard Worker        split_size = 3
10159*da0073e9SAndroid Build Coastguard Worker        dim = 0
10160*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([3, 4], [3, 4], [1, 4])
10161*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_size, dim)
10162*da0073e9SAndroid Build Coastguard Worker        start = 0
10163*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
10164*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
10165*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
10166*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
10167*da0073e9SAndroid Build Coastguard Worker
10168*da0073e9SAndroid Build Coastguard Worker        # Variable sections split
10169*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(20, 10)
10170*da0073e9SAndroid Build Coastguard Worker        dim = 0
10171*da0073e9SAndroid Build Coastguard Worker        split_sizes = [5, 5, 10]
10172*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([[5, 10], [5, 10], [10, 10]])
10173*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_sizes, dim)
10174*da0073e9SAndroid Build Coastguard Worker        start = 0
10175*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
10176*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
10177*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
10178*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
10179*da0073e9SAndroid Build Coastguard Worker
10180*da0073e9SAndroid Build Coastguard Worker        split_sizes = [2, 2, 6]
10181*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([20, 2], [20, 2], [20, 6])
10182*da0073e9SAndroid Build Coastguard Worker        dim = 1
10183*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_sizes, dim)
10184*da0073e9SAndroid Build Coastguard Worker        start = 0
10185*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
10186*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
10187*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
10188*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
10189*da0073e9SAndroid Build Coastguard Worker
10190*da0073e9SAndroid Build Coastguard Worker    def test_chunk(self, device="mps"):
10191*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(4, 7)
10192*da0073e9SAndroid Build Coastguard Worker        num_chunks = 3
10193*da0073e9SAndroid Build Coastguard Worker        dim = 1
10194*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([4, 3], [4, 3], [4, 1])
10195*da0073e9SAndroid Build Coastguard Worker        splits = tensor.chunk(num_chunks, dim)
10196*da0073e9SAndroid Build Coastguard Worker        start = 0
10197*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
10198*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
10199*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
10200*da0073e9SAndroid Build Coastguard Worker                             atol=0, rtol=0)
10201*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
10202*da0073e9SAndroid Build Coastguard Worker
10203*da0073e9SAndroid Build Coastguard Worker        # Invalid chunk sizes
10204*da0073e9SAndroid Build Coastguard Worker        error_regex = 'chunk expects.*greater than 0'
10205*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, error_regex):
10206*da0073e9SAndroid Build Coastguard Worker            tensor.chunk(0)
10207*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, error_regex):
10208*da0073e9SAndroid Build Coastguard Worker            tensor.chunk(-2)
10209*da0073e9SAndroid Build Coastguard Worker
10210*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze(self, device="mps") -> None:
10211*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4)
10212*da0073e9SAndroid Build Coastguard Worker        y = x.unsqueeze(1)
10213*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.view(2, 1, 3, 4))
10214*da0073e9SAndroid Build Coastguard Worker        y = x.clone().unsqueeze_(2)
10215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.view(2, 3, 1, 4))
10216*da0073e9SAndroid Build Coastguard Worker
10217*da0073e9SAndroid Build Coastguard Worker        x = x[:, 1]
10218*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.is_contiguous())
10219*da0073e9SAndroid Build Coastguard Worker        y = x.unsqueeze(1)
10220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.contiguous().view(2, 1, 4))
10221*da0073e9SAndroid Build Coastguard Worker        y = x.clone().unsqueeze_(2)
10222*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.contiguous().view(2, 4, 1))
10223*da0073e9SAndroid Build Coastguard Worker
10224*da0073e9SAndroid Build Coastguard Worker    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
10225*da0073e9SAndroid Build Coastguard Worker    def test_big_transpose(self, device="mps"):
10226*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(456, 789, device=device)
10227*da0073e9SAndroid Build Coastguard Worker        t1 = t.t().contiguous()
10228*da0073e9SAndroid Build Coastguard Worker        t2 = torch.from_numpy(t.cpu().numpy().transpose())
10229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t1, t2)
10230*da0073e9SAndroid Build Coastguard Worker
10231*da0073e9SAndroid Build Coastguard Worker    def test_T(self, device="mps"):
10232*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4, device=device)
10233*da0073e9SAndroid Build Coastguard Worker        t1 = a.T
10234*da0073e9SAndroid Build Coastguard Worker        t2 = a.permute(2, 1, 0)
10235*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t2, t1)
10236*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, device=device)
10237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b, b.T)
10238*da0073e9SAndroid Build Coastguard Worker
10239*da0073e9SAndroid Build Coastguard Worker    def test_transposes(self, device="mps", dtype=torch.float32):
10240*da0073e9SAndroid Build Coastguard Worker        for op in ("T", "H", "mT", "mH", "adjoint"):
10241*da0073e9SAndroid Build Coastguard Worker            shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
10242*da0073e9SAndroid Build Coastguard Worker            for shape in shapes:
10243*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, device=device, dtype=dtype)
10244*da0073e9SAndroid Build Coastguard Worker                t1 = getattr(a, op)
10245*da0073e9SAndroid Build Coastguard Worker                if op == "adjoint":
10246*da0073e9SAndroid Build Coastguard Worker                    t1 = t1()
10247*da0073e9SAndroid Build Coastguard Worker                t2 = a
10248*da0073e9SAndroid Build Coastguard Worker                if a.ndim != 0:
10249*da0073e9SAndroid Build Coastguard Worker                    t2 = t2.transpose(-2, -1)
10250*da0073e9SAndroid Build Coastguard Worker                if op[-1] == "H" or op == "adjoint":
10251*da0073e9SAndroid Build Coastguard Worker                    t2 = t2.conj()
10252*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(t2, t1)
10253*da0073e9SAndroid Build Coastguard Worker
10254*da0073e9SAndroid Build Coastguard Worker    def test_transposes_errors(self, device="mps", dtype=torch.float32):
10255*da0073e9SAndroid Build Coastguard Worker        for op in ("H", "mT", "mH", "adjoint"):
10256*da0073e9SAndroid Build Coastguard Worker            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
10257*da0073e9SAndroid Build Coastguard Worker            for shape in shapes:
10258*da0073e9SAndroid Build Coastguard Worker                a = make_tensor(shape, device=device, dtype=dtype)
10259*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
10260*da0073e9SAndroid Build Coastguard Worker                    t1 = getattr(a, op)
10261*da0073e9SAndroid Build Coastguard Worker                    if op == "adjoint":
10262*da0073e9SAndroid Build Coastguard Worker                        t1 = t1()
10263*da0073e9SAndroid Build Coastguard Worker
10264*da0073e9SAndroid Build Coastguard Worker    def test_python_types(self, device="mps"):
10265*da0073e9SAndroid Build Coastguard Worker        a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
10266*da0073e9SAndroid Build Coastguard Worker        a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
10267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a1.dtype, a2.dtype)
10268*da0073e9SAndroid Build Coastguard Worker
10269*da0073e9SAndroid Build Coastguard Worker        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
10270*da0073e9SAndroid Build Coastguard Worker        b2 = torch.arange(10, 20, dtype=int, device=device)
10271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b1.dtype, b2.dtype)
10272*da0073e9SAndroid Build Coastguard Worker
10273*da0073e9SAndroid Build Coastguard Worker        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
10274*da0073e9SAndroid Build Coastguard Worker        c2 = torch.tensor([True, False], dtype=bool, device=device)
10275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c1.dtype, c2.dtype)
10276*da0073e9SAndroid Build Coastguard Worker
10277*da0073e9SAndroid Build Coastguard Worker    # TODO: is resize best put in test_view_ops?
10278*da0073e9SAndroid Build Coastguard Worker    def test_resize_as_preserves_strides(self, device="mps"):
10279*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2, 3).t()
10280*da0073e9SAndroid Build Coastguard Worker        old_strides = x.stride()
10281*da0073e9SAndroid Build Coastguard Worker        x.resize_as_(x)
10282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.stride(), old_strides)
10283*da0073e9SAndroid Build Coastguard Worker
10284*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_resize_as(self, device="mps"):
10285*da0073e9SAndroid Build Coastguard Worker        def test_helper(shape, memory_format, device="mps"):
10286*da0073e9SAndroid Build Coastguard Worker            xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
10287*da0073e9SAndroid Build Coastguard Worker            flat = torch.randn(xc.numel(), device=device)
10288*da0073e9SAndroid Build Coastguard Worker            flat.resize_as_(xc, memory_format=torch.preserve_format)
10289*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
10290*da0073e9SAndroid Build Coastguard Worker
10291*da0073e9SAndroid Build Coastguard Worker        test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
10292*da0073e9SAndroid Build Coastguard Worker        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")
10293*da0073e9SAndroid Build Coastguard Worker
10294*da0073e9SAndroid Build Coastguard Worker    def test_memory_format_resize_(self, device="mps"):
10295*da0073e9SAndroid Build Coastguard Worker        def test_helper(shape, numel, memory_format, device="mps"):
10296*da0073e9SAndroid Build Coastguard Worker            flat = torch.randn(numel, device=device)
10297*da0073e9SAndroid Build Coastguard Worker            flat.resize_(shape, memory_format=memory_format)
10298*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
10299*da0073e9SAndroid Build Coastguard Worker
10300*da0073e9SAndroid Build Coastguard Worker        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
10301*da0073e9SAndroid Build Coastguard Worker        test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
10302*da0073e9SAndroid Build Coastguard Worker
10303*da0073e9SAndroid Build Coastguard Worker    # TODO: OpInfo this
10304*da0073e9SAndroid Build Coastguard Worker    def _test_atleast(self, device, torch_fn):
10305*da0073e9SAndroid Build Coastguard Worker        # 0-dim
10306*da0073e9SAndroid Build Coastguard Worker        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
10307*da0073e9SAndroid Build Coastguard Worker
10308*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: torch_fn(x), s)
10309*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: torch_fn(x), s)
10310*da0073e9SAndroid Build Coastguard Worker
10311*da0073e9SAndroid Build Coastguard Worker        # 1-dim
10312*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4, dtype=torch.double, requires_grad=True)
10313*da0073e9SAndroid Build Coastguard Worker
10314*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: torch_fn(x), a)
10315*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: torch_fn(x), a)
10316*da0073e9SAndroid Build Coastguard Worker
10317*da0073e9SAndroid Build Coastguard Worker        # 2,3,4-dim
10318*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
10319*da0073e9SAndroid Build Coastguard Worker        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
10320*da0073e9SAndroid Build Coastguard Worker        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
10321*da0073e9SAndroid Build Coastguard Worker
10322*da0073e9SAndroid Build Coastguard Worker        input_tuple = (s, a, b, c, d)
10323*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
10324*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
10325*da0073e9SAndroid Build Coastguard Worker
10326*da0073e9SAndroid Build Coastguard Worker    def test_atleast_gradient(self, device="mps"):
10327*da0073e9SAndroid Build Coastguard Worker        self._test_atleast(device, torch.atleast_1d)
10328*da0073e9SAndroid Build Coastguard Worker        self._test_atleast(device, torch.atleast_2d)
10329*da0073e9SAndroid Build Coastguard Worker        self._test_atleast(device, torch.atleast_3d)
10330*da0073e9SAndroid Build Coastguard Worker
10331*da0073e9SAndroid Build Coastguard Worker    def test_view(self, device="mps"):
10332*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(15, device=device)
10333*da0073e9SAndroid Build Coastguard Worker        template = torch.rand(3, 5, device=device)
10334*da0073e9SAndroid Build Coastguard Worker        empty = torch.empty(0, device=device)
10335*da0073e9SAndroid Build Coastguard Worker        target = template.size()
10336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.view_as(template).size(), target)
10337*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.view(3, 5).size(), target)
10338*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
10339*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.view(-1, 5).size(), target)
10340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.view(3, -1).size(), target)
10341*da0073e9SAndroid Build Coastguard Worker        tensor_view = tensor.view(5, 3)
10342*da0073e9SAndroid Build Coastguard Worker        tensor_view.fill_(random.uniform(0, 1))
10343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view_as(empty), empty)
10344*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view(0), empty)
10345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
10346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
10347*da0073e9SAndroid Build Coastguard Worker
10348*da0073e9SAndroid Build Coastguard Worker        # test size inference with empty tensors
10349*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
10350*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
10351*da0073e9SAndroid Build Coastguard Worker
10352*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
10353*da0073e9SAndroid Build Coastguard Worker            empty.view(-1, 0)
10354*da0073e9SAndroid Build Coastguard Worker
10355*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
10356*da0073e9SAndroid Build Coastguard Worker            empty.view(3, 0, -1, 0)
10357*da0073e9SAndroid Build Coastguard Worker
10358*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
10359*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
10360*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
10361*da0073e9SAndroid Build Coastguard Worker
10362*da0073e9SAndroid Build Coastguard Worker    def test_contiguous(self, device="mps"):
10363*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 16, 5, 5, device=device)
10364*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.is_contiguous())
10365*da0073e9SAndroid Build Coastguard Worker        stride = list(x.stride())
10366*da0073e9SAndroid Build Coastguard Worker        stride[0] = 20
10367*da0073e9SAndroid Build Coastguard Worker        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
10368*da0073e9SAndroid Build Coastguard Worker        x.set_(x.storage(), 0, x.size(), stride)
10369*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.is_contiguous())
10370*da0073e9SAndroid Build Coastguard Worker
10371*da0073e9SAndroid Build Coastguard Worker    def test_resize_mps_dtypes(self, device="mps"):
10372*da0073e9SAndroid Build Coastguard Worker        shape = (2, 2)
10373*da0073e9SAndroid Build Coastguard Worker        for dt in MPS_DTYPES:
10374*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
10375*da0073e9SAndroid Build Coastguard Worker            x.resize_(shape)
10376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(shape, x.shape)
10377*da0073e9SAndroid Build Coastguard Worker
10378*da0073e9SAndroid Build Coastguard Worker    def test_resize_as_mps_dtypes(self, device="mps"):
10379*da0073e9SAndroid Build Coastguard Worker        for dt in MPS_DTYPES:
10380*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
10381*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
10382*da0073e9SAndroid Build Coastguard Worker            x.resize_as_(y)
10383*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.shape, x.shape)
10384*da0073e9SAndroid Build Coastguard Worker
10385*da0073e9SAndroid Build Coastguard Worker    def test_resize_overflow(self, device="mps"):
10386*da0073e9SAndroid Build Coastguard Worker        x = torch.empty((), dtype=torch.float64)
10387*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
10388*da0073e9SAndroid Build Coastguard Worker            x.resize_([2, 4, 2**29, 2**29])
10389*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'overflow'):
10390*da0073e9SAndroid Build Coastguard Worker            x.resize_([8, 8, 2**29, 2**29])
10391*da0073e9SAndroid Build Coastguard Worker
10392*da0073e9SAndroid Build Coastguard Worker    def test_view_all_dtypes_and_devices(self, device="mps"):
10393*da0073e9SAndroid Build Coastguard Worker        for dt in (torch.float, torch.bool):
10394*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
10395*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.view(6).shape, [6])
10396*da0073e9SAndroid Build Coastguard Worker
10397*da0073e9SAndroid Build Coastguard Workerclass TestConvolutionMPS(TestCaseMPS):
10398*da0073e9SAndroid Build Coastguard Worker    def test_conv1d_all_strides_paddings(self):
10399*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/82921
10400*da0073e9SAndroid Build Coastguard Worker        def helper(stride, padding):
10401*da0073e9SAndroid Build Coastguard Worker            y_cpu = torch.randn(1, 57, 40)
10402*da0073e9SAndroid Build Coastguard Worker            conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
10403*da0073e9SAndroid Build Coastguard Worker            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10404*da0073e9SAndroid Build Coastguard Worker            x_cpu = conv_cpu(y_cpu)
10405*da0073e9SAndroid Build Coastguard Worker
10406*da0073e9SAndroid Build Coastguard Worker            y_gpu = y_cpu.to(device='mps')
10407*da0073e9SAndroid Build Coastguard Worker            x_gpu = conv_gpu(y_gpu)
10408*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_gpu.cpu())
10409*da0073e9SAndroid Build Coastguard Worker        for stride in range(1, 4):
10410*da0073e9SAndroid Build Coastguard Worker            for padding in range(1, 4):
10411*da0073e9SAndroid Build Coastguard Worker                helper(stride, padding)
10412*da0073e9SAndroid Build Coastguard Worker
10413*da0073e9SAndroid Build Coastguard Worker
10414*da0073e9SAndroid Build Coastguard Worker    def test_conv1d_channels_last(self):
10415*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/81557
10416*da0073e9SAndroid Build Coastguard Worker        model_cpu = torch.nn.Conv1d(1, 128, 3)
10417*da0073e9SAndroid Build Coastguard Worker        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
10418*da0073e9SAndroid Build Coastguard Worker        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
10419*da0073e9SAndroid Build Coastguard Worker        out_cpu = model_cpu(a_cpu)
10420*da0073e9SAndroid Build Coastguard Worker
10421*da0073e9SAndroid Build Coastguard Worker        a_mps = a_cpu.detach().clone().to("mps")
10422*da0073e9SAndroid Build Coastguard Worker        model_mps = model_cpu.to("mps")
10423*da0073e9SAndroid Build Coastguard Worker        out_mps = model_mps(a_mps)
10424*da0073e9SAndroid Build Coastguard Worker
10425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
10426*da0073e9SAndroid Build Coastguard Worker
10427*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_1d_all_strides(self):
10428*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/82711
10429*da0073e9SAndroid Build Coastguard Worker        def helper(stride):
10430*da0073e9SAndroid Build Coastguard Worker            y_cpu = torch.ones(1, 1, 2)
10431*da0073e9SAndroid Build Coastguard Worker            deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
10432*da0073e9SAndroid Build Coastguard Worker            deconv_cpu.weight.data = torch.ones(1, 1, 2)
10433*da0073e9SAndroid Build Coastguard Worker            deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
10434*da0073e9SAndroid Build Coastguard Worker            x_cpu = deconv_cpu(y_cpu)
10435*da0073e9SAndroid Build Coastguard Worker
10436*da0073e9SAndroid Build Coastguard Worker            y_gpu = y_cpu.to(device='mps')
10437*da0073e9SAndroid Build Coastguard Worker            x_gpu = deconv_gpu(y_gpu)
10438*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_gpu.cpu())
10439*da0073e9SAndroid Build Coastguard Worker        [helper(stride) for stride in [1, 2, 3]]
10440*da0073e9SAndroid Build Coastguard Worker
10441*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_1d_nn_functional(self):
10442*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/82563
10443*da0073e9SAndroid Build Coastguard Worker        tin = torch.rand((1, 512, 1245), dtype=torch.float32)
10444*da0073e9SAndroid Build Coastguard Worker        tparams = torch.rand((512, 256, 16), dtype=torch.float32)
10445*da0073e9SAndroid Build Coastguard Worker        tbias = torch.rand((256), dtype=torch.float32)
10446*da0073e9SAndroid Build Coastguard Worker
10447*da0073e9SAndroid Build Coastguard Worker        device = 'cpu'
10448*da0073e9SAndroid Build Coastguard Worker        tcpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)
10449*da0073e9SAndroid Build Coastguard Worker
10450*da0073e9SAndroid Build Coastguard Worker        device = 'mps'
10451*da0073e9SAndroid Build Coastguard Worker        tgpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)
10452*da0073e9SAndroid Build Coastguard Worker
10453*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
10454*da0073e9SAndroid Build Coastguard Worker
10455*da0073e9SAndroid Build Coastguard Worker    def test_conv_backward_1d_channels_last(self):
10456*da0073e9SAndroid Build Coastguard Worker        def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
10457*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/84511
10458*da0073e9SAndroid Build Coastguard Worker            conv_cpu = torch.nn.Conv1d(
10459*da0073e9SAndroid Build Coastguard Worker                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
10460*da0073e9SAndroid Build Coastguard Worker            conv_mps = torch.nn.Conv1d(
10461*da0073e9SAndroid Build Coastguard Worker                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
10462*da0073e9SAndroid Build Coastguard Worker            conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
10463*da0073e9SAndroid Build Coastguard Worker            conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)
10464*da0073e9SAndroid Build Coastguard Worker
10465*da0073e9SAndroid Build Coastguard Worker
10466*da0073e9SAndroid Build Coastguard Worker            data = torch.rand(shape, dtype=torch.float32)
10467*da0073e9SAndroid Build Coastguard Worker            x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
10468*da0073e9SAndroid Build Coastguard Worker            x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
10469*da0073e9SAndroid Build Coastguard Worker            res_cpu = conv_cpu(x_cpu)
10470*da0073e9SAndroid Build Coastguard Worker            res_mps = conv_mps(x_mps)
10471*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_cpu, res_mps)
10472*da0073e9SAndroid Build Coastguard Worker            res_cpu = res_cpu.sum().backward()
10473*da0073e9SAndroid Build Coastguard Worker            res_mps = res_mps.sum().backward()
10474*da0073e9SAndroid Build Coastguard Worker
10475*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
10476*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu.grad, x_mps.grad)
10477*da0073e9SAndroid Build Coastguard Worker
10478*da0073e9SAndroid Build Coastguard Worker        helper(shape=(1, 176, 1))
10479*da0073e9SAndroid Build Coastguard Worker        helper(shape=(2, 12, 1))
10480*da0073e9SAndroid Build Coastguard Worker        helper(shape=(3, 176, 1))
10481*da0073e9SAndroid Build Coastguard Worker        helper(shape=(4, 376, 1))
10482*da0073e9SAndroid Build Coastguard Worker        helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
10483*da0073e9SAndroid Build Coastguard Worker        helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
10484*da0073e9SAndroid Build Coastguard Worker
10485*da0073e9SAndroid Build Coastguard Worker    def test_conv1d_contiguous(self):
10486*da0073e9SAndroid Build Coastguard Worker        model_cpu = torch.nn.Conv1d(1, 128, 3)
10487*da0073e9SAndroid Build Coastguard Worker        a_cpu = torch.ones(128, 1, 176)
10488*da0073e9SAndroid Build Coastguard Worker        out_cpu = model_cpu(a_cpu)
10489*da0073e9SAndroid Build Coastguard Worker
10490*da0073e9SAndroid Build Coastguard Worker        a_mps = a_cpu.detach().clone().to("mps")
10491*da0073e9SAndroid Build Coastguard Worker        model_mps = model_cpu.to("mps")
10492*da0073e9SAndroid Build Coastguard Worker        out_mps = model_mps(a_mps)
10493*da0073e9SAndroid Build Coastguard Worker
10494*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cpu.shape, out_mps.shape)
10495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_cpu, out_mps.cpu())
10496*da0073e9SAndroid Build Coastguard Worker
10497*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_all_strides_paddings(self):
10498*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/83180
10499*da0073e9SAndroid Build Coastguard Worker        def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
10500*da0073e9SAndroid Build Coastguard Worker            x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
10501*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()
10502*da0073e9SAndroid Build Coastguard Worker
10503*da0073e9SAndroid Build Coastguard Worker            if permute_data:
10504*da0073e9SAndroid Build Coastguard Worker                x_cpu.permute(0, 2, 3, 1)
10505*da0073e9SAndroid Build Coastguard Worker                x_mps.permute(0, 2, 3, 1)
10506*da0073e9SAndroid Build Coastguard Worker
10507*da0073e9SAndroid Build Coastguard Worker            for strideX in range(1, 4):
10508*da0073e9SAndroid Build Coastguard Worker                for strideY in range(1, 4):
10509*da0073e9SAndroid Build Coastguard Worker                    conv_cpu = torch.nn.Conv2d(
10510*da0073e9SAndroid Build Coastguard Worker                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
10511*da0073e9SAndroid Build Coastguard Worker                    conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()
10512*da0073e9SAndroid Build Coastguard Worker
10513*da0073e9SAndroid Build Coastguard Worker                    conv_mps = torch.nn.Conv2d(
10514*da0073e9SAndroid Build Coastguard Worker                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
10515*da0073e9SAndroid Build Coastguard Worker                    conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10516*da0073e9SAndroid Build Coastguard Worker                    conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10517*da0073e9SAndroid Build Coastguard Worker
10518*da0073e9SAndroid Build Coastguard Worker                    res_cpu = conv_cpu(x_cpu)
10519*da0073e9SAndroid Build Coastguard Worker                    res_mps = conv_mps(x_mps)
10520*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
10521*da0073e9SAndroid Build Coastguard Worker                    res_cpu = res_cpu.sum().backward()
10522*da0073e9SAndroid Build Coastguard Worker                    res_mps = res_mps.sum().backward()
10523*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
10524*da0073e9SAndroid Build Coastguard Worker
10525*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
10526*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
10527*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(x_cpu.grad, x_mps.grad)
10528*da0073e9SAndroid Build Coastguard Worker
10529*da0073e9SAndroid Build Coastguard Worker        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
10530*da0073e9SAndroid Build Coastguard Worker            for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
10531*da0073e9SAndroid Build Coastguard Worker                for permute_data in [True, False]:
10532*da0073e9SAndroid Build Coastguard Worker                    helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
10533*da0073e9SAndroid Build Coastguard Worker                    helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
10534*da0073e9SAndroid Build Coastguard Worker                    helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
10535*da0073e9SAndroid Build Coastguard Worker
10536*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_2d_strided(self):
10537*da0073e9SAndroid Build Coastguard Worker        def helper(m_cpu, memory_format):
10538*da0073e9SAndroid Build Coastguard Worker            m_mps = copy.deepcopy(m_cpu).requires_grad_()
10539*da0073e9SAndroid Build Coastguard Worker            m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10540*da0073e9SAndroid Build Coastguard Worker            m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10541*da0073e9SAndroid Build Coastguard Worker
10542*da0073e9SAndroid Build Coastguard Worker            input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
10543*da0073e9SAndroid Build Coastguard Worker            input_mps = input_cpu.detach().clone().to("mps")
10544*da0073e9SAndroid Build Coastguard Worker
10545*da0073e9SAndroid Build Coastguard Worker            output_cpu = m_cpu(input_cpu)
10546*da0073e9SAndroid Build Coastguard Worker            output_mps = m_mps(input_mps)
10547*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output_cpu, output_mps)
10548*da0073e9SAndroid Build Coastguard Worker
10549*da0073e9SAndroid Build Coastguard Worker        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
10550*da0073e9SAndroid Build Coastguard Worker            # With square kernels and equal stride
10551*da0073e9SAndroid Build Coastguard Worker            helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)
10552*da0073e9SAndroid Build Coastguard Worker
10553*da0073e9SAndroid Build Coastguard Worker            # non-square kernels and unequal stride and with padding
10554*da0073e9SAndroid Build Coastguard Worker            helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)
10555*da0073e9SAndroid Build Coastguard Worker
10556*da0073e9SAndroid Build Coastguard Worker    def test_conv_transpose_2d_specified_output(self):
10557*da0073e9SAndroid Build Coastguard Worker        input_cpu = torch.randn(1, 16, 12, 12)
10558*da0073e9SAndroid Build Coastguard Worker        input_mps = input_cpu.detach().clone().to("mps")
10559*da0073e9SAndroid Build Coastguard Worker
10560*da0073e9SAndroid Build Coastguard Worker        downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
10561*da0073e9SAndroid Build Coastguard Worker        downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
10562*da0073e9SAndroid Build Coastguard Worker        downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10563*da0073e9SAndroid Build Coastguard Worker        downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10564*da0073e9SAndroid Build Coastguard Worker
10565*da0073e9SAndroid Build Coastguard Worker        upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
10566*da0073e9SAndroid Build Coastguard Worker        upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
10567*da0073e9SAndroid Build Coastguard Worker        upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10568*da0073e9SAndroid Build Coastguard Worker        upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10569*da0073e9SAndroid Build Coastguard Worker
10570*da0073e9SAndroid Build Coastguard Worker        h_cpu = downsample_cpu(input_cpu)
10571*da0073e9SAndroid Build Coastguard Worker        h_mps = downsample_mps(input_mps)
10572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(h_cpu, h_mps)
10573*da0073e9SAndroid Build Coastguard Worker
10574*da0073e9SAndroid Build Coastguard Worker        size_cpu = h_cpu.size()
10575*da0073e9SAndroid Build Coastguard Worker        size_mps = h_mps.size()
10576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(size_cpu, size_mps)
10577*da0073e9SAndroid Build Coastguard Worker
10578*da0073e9SAndroid Build Coastguard Worker        output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
10579*da0073e9SAndroid Build Coastguard Worker        output_mps = upsample_mps(h_mps, output_size=input_mps.size())
10580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu, output_mps)
10581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output_cpu.size(), output_mps.size())
10582*da0073e9SAndroid Build Coastguard Worker
10583*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_single_stride(self):
10584*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.randn(2, 2, 3, 6)
10585*da0073e9SAndroid Build Coastguard Worker        y_gpu = y_cpu.to(device='mps')
10586*da0073e9SAndroid Build Coastguard Worker        for stride in range(1, 4):
10587*da0073e9SAndroid Build Coastguard Worker            conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
10588*da0073e9SAndroid Build Coastguard Worker            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10589*da0073e9SAndroid Build Coastguard Worker            x_cpu = conv_cpu(y_cpu)
10590*da0073e9SAndroid Build Coastguard Worker            x_gpu = conv_gpu(y_gpu)
10591*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
10592*da0073e9SAndroid Build Coastguard Worker
10593*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
10594*da0073e9SAndroid Build Coastguard Worker    def test_conv3d_single_stride(self):
10595*da0073e9SAndroid Build Coastguard Worker        # Conv3d is only available from MacOS 13.2 onwards
10596*da0073e9SAndroid Build Coastguard Worker        y_cpu = torch.randn(2, 2, 3, 6)
10597*da0073e9SAndroid Build Coastguard Worker        y_gpu = y_cpu.to(device='mps')
10598*da0073e9SAndroid Build Coastguard Worker        for stride in range(1, 4):
10599*da0073e9SAndroid Build Coastguard Worker            conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
10600*da0073e9SAndroid Build Coastguard Worker            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10601*da0073e9SAndroid Build Coastguard Worker            x_cpu = conv_cpu(y_cpu)
10602*da0073e9SAndroid Build Coastguard Worker            x_gpu = conv_gpu(y_gpu)
10603*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
10604*da0073e9SAndroid Build Coastguard Worker
10605*da0073e9SAndroid Build Coastguard Worker    def test_grid_sample(self):
10606*da0073e9SAndroid Build Coastguard Worker        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
10607*da0073e9SAndroid Build Coastguard Worker            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
10608*da0073e9SAndroid Build Coastguard Worker                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
10609*da0073e9SAndroid Build Coastguard Worker                    # grid_dim_contig_order specifies the dimension order that can
10610*da0073e9SAndroid Build Coastguard Worker                    # make grid to be contiguous.
10611*da0073e9SAndroid Build Coastguard Worker                    # i.e., grid.permute(grid_dim_contig_order) is contiguous.
10612*da0073e9SAndroid Build Coastguard Worker                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
10613*da0073e9SAndroid Build Coastguard Worker                    #       initialized with contiguous tensor of shape [N, 2, H, W]
10614*da0073e9SAndroid Build Coastguard Worker                    #       and permuted to [N, H, W, 2] afterwards.
10615*da0073e9SAndroid Build Coastguard Worker                    grid_shape = [N, H, W, 2]
10616*da0073e9SAndroid Build Coastguard Worker                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
10617*da0073e9SAndroid Build Coastguard Worker                    grid_fwd_permute = [None, None, None, None]
10618*da0073e9SAndroid Build Coastguard Worker                    for i, d in enumerate(grid_dim_contig_order):
10619*da0073e9SAndroid Build Coastguard Worker                        grid_fwd_permute[d] = i
10620*da0073e9SAndroid Build Coastguard Worker
10621*da0073e9SAndroid Build Coastguard Worker                    def get_grid(device='cpu', data=None):
10622*da0073e9SAndroid Build Coastguard Worker                        if data is not None:
10623*da0073e9SAndroid Build Coastguard Worker                            assert list(data.shape) == grid_shape
10624*da0073e9SAndroid Build Coastguard Worker                            data = data.permute(grid_dim_contig_order).to(device)
10625*da0073e9SAndroid Build Coastguard Worker                        else:
10626*da0073e9SAndroid Build Coastguard Worker                            data = torch.randn(grid_init_shape, device=device)
10627*da0073e9SAndroid Build Coastguard Worker                        grid = data.permute(grid_fwd_permute)
10628*da0073e9SAndroid Build Coastguard Worker                        assert grid.permute(grid_dim_contig_order).is_contiguous()
10629*da0073e9SAndroid Build Coastguard Worker                        return grid
10630*da0073e9SAndroid Build Coastguard Worker
10631*da0073e9SAndroid Build Coastguard Worker                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
10632*da0073e9SAndroid Build Coastguard Worker                    grid_cpu = get_grid().requires_grad_()
10633*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
10634*da0073e9SAndroid Build Coastguard Worker                                            align_corners=align_corners)
10635*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))
10636*da0073e9SAndroid Build Coastguard Worker
10637*da0073e9SAndroid Build Coastguard Worker                    gradients = torch.randn_like(out_cpu)
10638*da0073e9SAndroid Build Coastguard Worker                    out_cpu.backward(gradients)
10639*da0073e9SAndroid Build Coastguard Worker
10640*da0073e9SAndroid Build Coastguard Worker
10641*da0073e9SAndroid Build Coastguard Worker                    # Compare against unvectorized CPU fallback
10642*da0073e9SAndroid Build Coastguard Worker
10643*da0073e9SAndroid Build Coastguard Worker                    # NOTE [ grid_sample CPU fallback ]
10644*da0073e9SAndroid Build Coastguard Worker                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
10645*da0073e9SAndroid Build Coastguard Worker                    # 32-bit floats. So we also have a fallback that is used only for float tensors
10646*da0073e9SAndroid Build Coastguard Worker                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
10647*da0073e9SAndroid Build Coastguard Worker                    # also export the fallback and test it here to ensure feature parity with
10648*da0073e9SAndroid Build Coastguard Worker                    # the vectorized version.
10649*da0073e9SAndroid Build Coastguard Worker                    input_fallback = input_cpu.float().detach_().requires_grad_()
10650*da0073e9SAndroid Build Coastguard Worker                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
10651*da0073e9SAndroid Build Coastguard Worker                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
10652*da0073e9SAndroid Build Coastguard Worker                        input_fallback, grid_fallback,
10653*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
10654*da0073e9SAndroid Build Coastguard Worker                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
10655*da0073e9SAndroid Build Coastguard Worker                        align_corners)
10656*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)
10657*da0073e9SAndroid Build Coastguard Worker
10658*da0073e9SAndroid Build Coastguard Worker                    out_fallback.backward(gradients.float())
10659*da0073e9SAndroid Build Coastguard Worker                    if input_requires_grad:
10660*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
10661*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)
10662*da0073e9SAndroid Build Coastguard Worker
10663*da0073e9SAndroid Build Coastguard Worker                    input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
10664*da0073e9SAndroid Build Coastguard Worker                    grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
10665*da0073e9SAndroid Build Coastguard Worker                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
10666*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_cpu, out_mps)
10667*da0073e9SAndroid Build Coastguard Worker                    out_mps.backward(gradients.to("mps"))
10668*da0073e9SAndroid Build Coastguard Worker                    if input_requires_grad:
10669*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(input_cpu.grad, input_mps.grad)
10670*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)
10671*da0073e9SAndroid Build Coastguard Worker
10672*da0073e9SAndroid Build Coastguard Worker                    # check that zero-dimensional input strides don't error out
10673*da0073e9SAndroid Build Coastguard Worker                    base_input = torch.randn(N, C, 1, IW)
10674*da0073e9SAndroid Build Coastguard Worker                    input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
10675*da0073e9SAndroid Build Coastguard Worker                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
10676*da0073e9SAndroid Build Coastguard Worker                                            align_corners=align_corners)
10677*da0073e9SAndroid Build Coastguard Worker
10678*da0073e9SAndroid Build Coastguard Worker                    input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
10679*da0073e9SAndroid Build Coastguard Worker                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
10680*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(out_cpu, out_mps)
10681*da0073e9SAndroid Build Coastguard Worker
10682*da0073e9SAndroid Build Coastguard Worker            # test same size output
10683*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)
10684*da0073e9SAndroid Build Coastguard Worker
10685*da0073e9SAndroid Build Coastguard Worker            # test larger output
10686*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
10687*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
10688*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
10689*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
10690*da0073e9SAndroid Build Coastguard Worker            H = random.randint(IH + 1, 12)
10691*da0073e9SAndroid Build Coastguard Worker            W = random.randint(IW + 1, 12)
10692*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10693*da0073e9SAndroid Build Coastguard Worker
10694*da0073e9SAndroid Build Coastguard Worker            # test smaller output
10695*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
10696*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
10697*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
10698*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
10699*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, IH)
10700*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, IW)
10701*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10702*da0073e9SAndroid Build Coastguard Worker
10703*da0073e9SAndroid Build Coastguard Worker            # test 1x1 inpput
10704*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
10705*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
10706*da0073e9SAndroid Build Coastguard Worker            IH = 1
10707*da0073e9SAndroid Build Coastguard Worker            IW = 1
10708*da0073e9SAndroid Build Coastguard Worker            H = random.randint(2, 5)
10709*da0073e9SAndroid Build Coastguard Worker            W = random.randint(2, 5)
10710*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)
10711*da0073e9SAndroid Build Coastguard Worker
10712*da0073e9SAndroid Build Coastguard Worker            # testing empty grid
10713*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
10714*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
10715*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
10716*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
10717*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
10718*da0073e9SAndroid Build Coastguard Worker            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)
10719*da0073e9SAndroid Build Coastguard Worker
10720*da0073e9SAndroid Build Coastguard Worker            # testing empty channel
10721*da0073e9SAndroid Build Coastguard Worker            N = random.randint(2, 8)
10722*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
10723*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
10724*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
10725*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
10726*da0073e9SAndroid Build Coastguard Worker            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)
10727*da0073e9SAndroid Build Coastguard Worker
10728*da0073e9SAndroid Build Coastguard Worker            # testing empty batch
10729*da0073e9SAndroid Build Coastguard Worker            C = random.randint(2, 8)
10730*da0073e9SAndroid Build Coastguard Worker            IH = random.randint(2, 8)
10731*da0073e9SAndroid Build Coastguard Worker            IW = random.randint(2, 8)
10732*da0073e9SAndroid Build Coastguard Worker            H = random.randint(3, IH + 2)
10733*da0073e9SAndroid Build Coastguard Worker            W = random.randint(3, IW + 2)
10734*da0073e9SAndroid Build Coastguard Worker            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)
10735*da0073e9SAndroid Build Coastguard Worker
10736*da0073e9SAndroid Build Coastguard Worker        for mode in ('bilinear', 'nearest'):
10737*da0073e9SAndroid Build Coastguard Worker            for padding_mode in ('zeros', 'reflection'):
10738*da0073e9SAndroid Build Coastguard Worker                for align_corners in (True, False):
10739*da0073e9SAndroid Build Coastguard Worker                    # test known input
10740*da0073e9SAndroid Build Coastguard Worker                    input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
10741*da0073e9SAndroid Build Coastguard Worker                    grid = torch.tensor(
10742*da0073e9SAndroid Build Coastguard Worker                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
10743*da0073e9SAndroid Build Coastguard Worker                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
10744*da0073e9SAndroid Build Coastguard Worker                    if mode == 'bilinear':
10745*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
10746*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10747*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10748*da0073e9SAndroid Build Coastguard Worker                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
10749*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
10750*da0073e9SAndroid Build Coastguard Worker                            else:
10751*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10752*da0073e9SAndroid Build Coastguard Worker                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
10753*da0073e9SAndroid Build Coastguard Worker                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
10754*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
10755*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10756*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10757*da0073e9SAndroid Build Coastguard Worker                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
10758*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
10759*da0073e9SAndroid Build Coastguard Worker                            else:
10760*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10761*da0073e9SAndroid Build Coastguard Worker                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
10762*da0073e9SAndroid Build Coastguard Worker                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
10763*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
10764*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10765*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10766*da0073e9SAndroid Build Coastguard Worker                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
10767*da0073e9SAndroid Build Coastguard Worker                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
10768*da0073e9SAndroid Build Coastguard Worker                            else:
10769*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10770*da0073e9SAndroid Build Coastguard Worker                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
10771*da0073e9SAndroid Build Coastguard Worker                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
10772*da0073e9SAndroid Build Coastguard Worker                        else:
10773*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
10774*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'nearest':
10775*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
10776*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10777*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10778*da0073e9SAndroid Build Coastguard Worker                                    [[0., 8., 5., 7., 9.],
10779*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10780*da0073e9SAndroid Build Coastguard Worker                            else:
10781*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10782*da0073e9SAndroid Build Coastguard Worker                                    [[0., 8., 5., 7., 0.],
10783*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10784*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
10785*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10786*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10787*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
10788*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10789*da0073e9SAndroid Build Coastguard Worker                            else:
10790*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10791*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
10792*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10793*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
10794*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10795*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10796*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
10797*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10798*da0073e9SAndroid Build Coastguard Worker                            else:
10799*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10800*da0073e9SAndroid Build Coastguard Worker                                    [[1., 8., 5., 7., 9.],
10801*da0073e9SAndroid Build Coastguard Worker                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10802*da0073e9SAndroid Build Coastguard Worker                        else:
10803*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
10804*da0073e9SAndroid Build Coastguard Worker                    elif mode == 'bicubic':
10805*da0073e9SAndroid Build Coastguard Worker                        if padding_mode == 'zeros':
10806*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10807*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10808*da0073e9SAndroid Build Coastguard Worker                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
10809*da0073e9SAndroid Build Coastguard Worker                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
10810*da0073e9SAndroid Build Coastguard Worker                            else:
10811*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10812*da0073e9SAndroid Build Coastguard Worker                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
10813*da0073e9SAndroid Build Coastguard Worker                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
10814*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'border':
10815*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10816*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10817*da0073e9SAndroid Build Coastguard Worker                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
10818*da0073e9SAndroid Build Coastguard Worker                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
10819*da0073e9SAndroid Build Coastguard Worker                            else:
10820*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10821*da0073e9SAndroid Build Coastguard Worker                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
10822*da0073e9SAndroid Build Coastguard Worker                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
10823*da0073e9SAndroid Build Coastguard Worker                        elif padding_mode == 'reflection':
10824*da0073e9SAndroid Build Coastguard Worker                            if align_corners:
10825*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10826*da0073e9SAndroid Build Coastguard Worker                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
10827*da0073e9SAndroid Build Coastguard Worker                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
10828*da0073e9SAndroid Build Coastguard Worker                            else:
10829*da0073e9SAndroid Build Coastguard Worker                                groundtruth = torch.tensor(
10830*da0073e9SAndroid Build Coastguard Worker                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
10831*da0073e9SAndroid Build Coastguard Worker                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
10832*da0073e9SAndroid Build Coastguard Worker                        else:
10833*da0073e9SAndroid Build Coastguard Worker                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
10834*da0073e9SAndroid Build Coastguard Worker
10835*da0073e9SAndroid Build Coastguard Worker                    else:
10836*da0073e9SAndroid Build Coastguard Worker                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
10837*da0073e9SAndroid Build Coastguard Worker                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
10838*da0073e9SAndroid Build Coastguard Worker                                           align_corners=align_corners)
10839*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
10840*da0073e9SAndroid Build Coastguard Worker                                     msg=f"groundtruth comparison failed for mode={mode}, "
10841*da0073e9SAndroid Build Coastguard Worker                                     f"padding_mode={padding_mode}")
10842*da0073e9SAndroid Build Coastguard Worker
10843*da0073e9SAndroid Build Coastguard Workerclass TestAdvancedIndexing(TestCaseMPS):
10844*da0073e9SAndroid Build Coastguard Worker    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
10845*da0073e9SAndroid Build Coastguard Worker    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
10846*da0073e9SAndroid Build Coastguard Worker
10847*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
10848*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_no_warning(self):
10849*da0073e9SAndroid Build Coastguard Worker        device = "mps"
10850*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((2, 2), device=device)
10851*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
10852*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
10853*da0073e9SAndroid Build Coastguard Worker            torch.nonzero(t)
10854*da0073e9SAndroid Build Coastguard Worker            t.nonzero()
10855*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 0)
10856*da0073e9SAndroid Build Coastguard Worker
10857*da0073e9SAndroid Build Coastguard Worker    def test_nonzero(self):
10858*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
10859*da0073e9SAndroid Build Coastguard Worker            device = "mps"
10860*da0073e9SAndroid Build Coastguard Worker            shapes = [
10861*da0073e9SAndroid Build Coastguard Worker                torch.Size((12,)),
10862*da0073e9SAndroid Build Coastguard Worker                torch.Size((12, 1)),
10863*da0073e9SAndroid Build Coastguard Worker                torch.Size((1, 12)),
10864*da0073e9SAndroid Build Coastguard Worker                torch.Size((6, 2)),
10865*da0073e9SAndroid Build Coastguard Worker                torch.Size((3, 2, 2)),
10866*da0073e9SAndroid Build Coastguard Worker                torch.Size((5, 5, 5)),
10867*da0073e9SAndroid Build Coastguard Worker            ]
10868*da0073e9SAndroid Build Coastguard Worker
10869*da0073e9SAndroid Build Coastguard Worker            def gen_nontrivial_input(shape, dtype, device):
10870*da0073e9SAndroid Build Coastguard Worker                if dtype != torch.bfloat16:
10871*da0073e9SAndroid Build Coastguard Worker                    return torch.randint(2, shape, device=device, dtype=dtype)
10872*da0073e9SAndroid Build Coastguard Worker                else:
10873*da0073e9SAndroid Build Coastguard Worker                    # windows does not work for bfloat16 randing
10874*da0073e9SAndroid Build Coastguard Worker                    return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
10875*da0073e9SAndroid Build Coastguard Worker
10876*da0073e9SAndroid Build Coastguard Worker            for shape in shapes:
10877*da0073e9SAndroid Build Coastguard Worker                tensor = gen_nontrivial_input(shape, dtype, device)
10878*da0073e9SAndroid Build Coastguard Worker                dst1 = torch.nonzero(tensor, as_tuple=False)
10879*da0073e9SAndroid Build Coastguard Worker                dst2 = tensor.nonzero(as_tuple=False)
10880*da0073e9SAndroid Build Coastguard Worker                dst3 = torch.empty([], dtype=torch.long, device=device)
10881*da0073e9SAndroid Build Coastguard Worker                dst3 = dst3.resize_(0)
10882*da0073e9SAndroid Build Coastguard Worker                torch.nonzero(tensor, out=dst3)
10883*da0073e9SAndroid Build Coastguard Worker                np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
10884*da0073e9SAndroid Build Coastguard Worker                np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
10885*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
10886*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
10887*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
10888*da0073e9SAndroid Build Coastguard Worker                tup1 = torch.nonzero(tensor, as_tuple=True)
10889*da0073e9SAndroid Build Coastguard Worker                tup2 = tensor.nonzero(as_tuple=True)
10890*da0073e9SAndroid Build Coastguard Worker                tup1 = torch.stack(tup1).t().cpu()
10891*da0073e9SAndroid Build Coastguard Worker                tup2 = torch.stack(tup2).t().cpu()
10892*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tup1, np_result, atol=0, rtol=0)
10893*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tup2, np_result, atol=0, rtol=0)
10894*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in self.supported_dtypes]
10895*da0073e9SAndroid Build Coastguard Worker
10896*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_astuple_out(self):
10897*da0073e9SAndroid Build Coastguard Worker        device = "mps"
10898*da0073e9SAndroid Build Coastguard Worker        t = torch.randn((3, 3, 3), device=device)
10899*da0073e9SAndroid Build Coastguard Worker        out = torch.empty([], dtype=torch.long, device=device)
10900*da0073e9SAndroid Build Coastguard Worker        out = out.resize_(0)
10901*da0073e9SAndroid Build Coastguard Worker
10902*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
10903*da0073e9SAndroid Build Coastguard Worker            torch.nonzero(t, as_tuple=True, out=out)
10904*da0073e9SAndroid Build Coastguard Worker
10905*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
10906*da0073e9SAndroid Build Coastguard Worker
10907*da0073e9SAndroid Build Coastguard Worker        # Verifies that JIT script cannot handle the as_tuple kwarg
10908*da0073e9SAndroid Build Coastguard Worker        # See Issue https://github.com/pytorch/pytorch/issues/45499.
10909*da0073e9SAndroid Build Coastguard Worker        def _foo(t):
10910*da0073e9SAndroid Build Coastguard Worker            tuple_result = torch.nonzero(t, as_tuple=True)
10911*da0073e9SAndroid Build Coastguard Worker            nontuple_result = torch.nonzero(t, as_tuple=False)
10912*da0073e9SAndroid Build Coastguard Worker            out = torch.empty_like(nontuple_result)
10913*da0073e9SAndroid Build Coastguard Worker            torch.nonzero(t, as_tuple=False, out=out)
10914*da0073e9SAndroid Build Coastguard Worker            return tuple_result, nontuple_result, out
10915*da0073e9SAndroid Build Coastguard Worker
10916*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
10917*da0073e9SAndroid Build Coastguard Worker            scripted_foo = torch.jit.script(_foo)
10918*da0073e9SAndroid Build Coastguard Worker
10919*da0073e9SAndroid Build Coastguard Worker        # Verifies that JIT tracing works fine
10920*da0073e9SAndroid Build Coastguard Worker        traced_foo = torch.jit.trace(_foo, t)
10921*da0073e9SAndroid Build Coastguard Worker        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
10922*da0073e9SAndroid Build Coastguard Worker        expected_tuple = torch.nonzero(t, as_tuple=True)
10923*da0073e9SAndroid Build Coastguard Worker        expected_nontuple = torch.nonzero(t)
10924*da0073e9SAndroid Build Coastguard Worker
10925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_tuple, expected_tuple)
10926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_nontuple, expected_nontuple)
10927*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced_out, expected_nontuple)
10928*da0073e9SAndroid Build Coastguard Worker
10929*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_discontiguous(self):
10930*da0073e9SAndroid Build Coastguard Worker        device = "mps"
10931*da0073e9SAndroid Build Coastguard Worker        shape = (4, 4)
10932*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randint(2, shape, device=device)
10933*da0073e9SAndroid Build Coastguard Worker        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
10934*da0073e9SAndroid Build Coastguard Worker        dst1 = tensor.nonzero(as_tuple=False)
10935*da0073e9SAndroid Build Coastguard Worker        dst2 = tensor_nc.nonzero(as_tuple=False)
10936*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst1, dst2, atol=0, rtol=0)
10937*da0073e9SAndroid Build Coastguard Worker        dst3 = torch.empty_like(dst1)
10938*da0073e9SAndroid Build Coastguard Worker        data_ptr = dst3.data_ptr()
10939*da0073e9SAndroid Build Coastguard Worker        # expect dst3 storage to be reused
10940*da0073e9SAndroid Build Coastguard Worker        torch.nonzero(tensor, out=dst3)
10941*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(data_ptr, dst3.data_ptr())
10942*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst1, dst3, atol=0, rtol=0)
10943*da0073e9SAndroid Build Coastguard Worker        # discontiguous out
10944*da0073e9SAndroid Build Coastguard Worker        dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
10945*da0073e9SAndroid Build Coastguard Worker        data_ptr = dst4.data_ptr()
10946*da0073e9SAndroid Build Coastguard Worker        strides = dst4.stride()
10947*da0073e9SAndroid Build Coastguard Worker        torch.nonzero(tensor, out=dst4)
10948*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(data_ptr, dst4.data_ptr())
10949*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dst1, dst4, atol=0, rtol=0)
10950*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(strides, dst4.stride())
10951*da0073e9SAndroid Build Coastguard Worker
10952*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_non_diff(self):
10953*da0073e9SAndroid Build Coastguard Worker        device = "mps"
10954*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, requires_grad=True, device=device)
10955*da0073e9SAndroid Build Coastguard Worker        nz = x.nonzero()
10956*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nz.requires_grad)
10957*da0073e9SAndroid Build Coastguard Worker
10958*da0073e9SAndroid Build Coastguard Worker    def test_nonzero_multi_threading(self):
10959*da0073e9SAndroid Build Coastguard Worker        # Test that MPS doesn't crash if nonzero called concurrently
10960*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/100285
10961*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 3, device="mps")
10962*da0073e9SAndroid Build Coastguard Worker        t1 = threading.Thread(target=torch.nonzero, args=(x,))
10963*da0073e9SAndroid Build Coastguard Worker        t2 = threading.Thread(target=torch.nonzero, args=(x,))
10964*da0073e9SAndroid Build Coastguard Worker        t1.start()
10965*da0073e9SAndroid Build Coastguard Worker        t2.start()
10966*da0073e9SAndroid Build Coastguard Worker
10967*da0073e9SAndroid Build Coastguard Worker    def test_sliced_view_cast(self):
10968*da0073e9SAndroid Build Coastguard Worker        # This used to crash on MacOS Sequoia
10969*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/137800
10970*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(16, 16, device='mps', dtype=torch.float16)
10971*da0073e9SAndroid Build Coastguard Worker        y = x[:, 0:2].view(torch.float32) + 1
10972*da0073e9SAndroid Build Coastguard Worker
10973*da0073e9SAndroid Build Coastguard Worker    def test_masked_select(self):
10974*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
10975*da0073e9SAndroid Build Coastguard Worker        x_mps = x.to("mps")
10976*da0073e9SAndroid Build Coastguard Worker        mask = x.ge(0.5)
10977*da0073e9SAndroid Build Coastguard Worker        mask_mps = x_mps.ge(0.5)
10978*da0073e9SAndroid Build Coastguard Worker
10979*da0073e9SAndroid Build Coastguard Worker        res = torch.masked_select(x, mask)
10980*da0073e9SAndroid Build Coastguard Worker        res_mps = torch.masked_select(x_mps, mask_mps)
10981*da0073e9SAndroid Build Coastguard Worker
10982*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, res_mps)
10983*da0073e9SAndroid Build Coastguard Worker
10984*da0073e9SAndroid Build Coastguard Worker    # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
10985*da0073e9SAndroid Build Coastguard Worker    def test_indexing_get(self):
10986*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
10987*da0073e9SAndroid Build Coastguard Worker            x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
10988*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
10989*da0073e9SAndroid Build Coastguard Worker
10990*da0073e9SAndroid Build Coastguard Worker            y_cpu = x_cpu[[0, 1, 2], [0, 1, 0]]
10991*da0073e9SAndroid Build Coastguard Worker            y_mps = x_mps[[0, 1, 2], [0, 1, 0]]
10992*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_cpu, y_mps, str(dtype))
10993*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in self.supported_dtypes]
10994*da0073e9SAndroid Build Coastguard Worker
10995*da0073e9SAndroid Build Coastguard Worker    def test_indexing_select_corners(self):
10996*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
10997*da0073e9SAndroid Build Coastguard Worker            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
10998*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
10999*da0073e9SAndroid Build Coastguard Worker
11000*da0073e9SAndroid Build Coastguard Worker            rows_cpu = torch.tensor([[0, 0], [3, 3]])
11001*da0073e9SAndroid Build Coastguard Worker            rows_mps = rows_cpu.detach().clone().to("mps")
11002*da0073e9SAndroid Build Coastguard Worker
11003*da0073e9SAndroid Build Coastguard Worker            cols_cpu = torch.tensor([[0, 2], [0, 2]])
11004*da0073e9SAndroid Build Coastguard Worker            cols_mps = cols_cpu.detach().clone().to("mps")
11005*da0073e9SAndroid Build Coastguard Worker
11006*da0073e9SAndroid Build Coastguard Worker            res_cpu = x_cpu[rows_cpu, cols_cpu]
11007*da0073e9SAndroid Build Coastguard Worker            res_mps = x_mps[rows_mps, cols_mps]
11008*da0073e9SAndroid Build Coastguard Worker
11009*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_cpu, res_mps, str(dtype))
11010*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in self.supported_dtypes]
11011*da0073e9SAndroid Build Coastguard Worker
11012*da0073e9SAndroid Build Coastguard Worker    # FIXME: uint8 fails for this testcase, needs further debugging
11013*da0073e9SAndroid Build Coastguard Worker    def test_slicing_using_advanced_index_for_column(self):
11014*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
11015*da0073e9SAndroid Build Coastguard Worker            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
11016*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
11017*da0073e9SAndroid Build Coastguard Worker
11018*da0073e9SAndroid Build Coastguard Worker            z_cpu = x_cpu[1:4, 1:3]
11019*da0073e9SAndroid Build Coastguard Worker            z_mps = x_mps[1:4, 1:3]
11020*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(z_cpu, z_mps, str(dtype))
11021*da0073e9SAndroid Build Coastguard Worker
11022*da0073e9SAndroid Build Coastguard Worker            # using advanced index for column
11023*da0073e9SAndroid Build Coastguard Worker            y_cpu = x_cpu[1:4, [1, 2]]
11024*da0073e9SAndroid Build Coastguard Worker            y_mps = x_mps[1:4, [1, 2]]
11025*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_cpu, y_mps, str(dtype))
11026*da0073e9SAndroid Build Coastguard Worker        # FIXME: use supported_dtypes once uint8 is fixed
11027*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
11028*da0073e9SAndroid Build Coastguard Worker
11029*da0073e9SAndroid Build Coastguard Worker    def test_boolean_array_indexing(self):
11030*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
11031*da0073e9SAndroid Build Coastguard Worker            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
11032*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
11033*da0073e9SAndroid Build Coastguard Worker
11034*da0073e9SAndroid Build Coastguard Worker            res_cpu = x_cpu[x_cpu > 5]
11035*da0073e9SAndroid Build Coastguard Worker            res_mps = x_mps[x_mps > 5]
11036*da0073e9SAndroid Build Coastguard Worker
11037*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res_cpu, res_mps, str(dtype))
11038*da0073e9SAndroid Build Coastguard Worker        for dtype in self.supported_dtypes:
11039*da0073e9SAndroid Build Coastguard Worker            # MPS support binary op with uint8 natively starting from macOS 13.0
11040*da0073e9SAndroid Build Coastguard Worker            if product_version < 13.0 and dtype == torch.uint8:
11041*da0073e9SAndroid Build Coastguard Worker                continue
11042*da0073e9SAndroid Build Coastguard Worker            helper(dtype)
11043*da0073e9SAndroid Build Coastguard Worker
11044*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing_3D_get(self):
11045*da0073e9SAndroid Build Coastguard Worker        def helper(x_cpu):
11046*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
11047*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
11048*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
11049*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])
11050*da0073e9SAndroid Build Coastguard Worker
11051*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
11052*da0073e9SAndroid Build Coastguard Worker                               [0.5, 0.6, 0.7, 0.8],
11053*da0073e9SAndroid Build Coastguard Worker                               [0.9, 1.0, 1.1, 1.2],
11054*da0073e9SAndroid Build Coastguard Worker                               [1.3, 1.4, 1.5, 1.6]],
11055*da0073e9SAndroid Build Coastguard Worker
11056*da0073e9SAndroid Build Coastguard Worker                              [[2.0, 2.1, 2.2, 2.3],
11057*da0073e9SAndroid Build Coastguard Worker                               [2.4, 2.5, 2.6, 2.7],
11058*da0073e9SAndroid Build Coastguard Worker                               [2.8, 2.9, 3.0, 3.1],
11059*da0073e9SAndroid Build Coastguard Worker                               [3.2, 3.3, 3.4, 3.5]],
11060*da0073e9SAndroid Build Coastguard Worker
11061*da0073e9SAndroid Build Coastguard Worker                              [[4.0, 4.1, 4.2, 4.3],
11062*da0073e9SAndroid Build Coastguard Worker                               [4.4, 4.5, 4.6, 4.7],
11063*da0073e9SAndroid Build Coastguard Worker                               [4.8, 4.9, 5.0, 5.1],
11064*da0073e9SAndroid Build Coastguard Worker                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
11065*da0073e9SAndroid Build Coastguard Worker        helper(x_cpu)
11066*da0073e9SAndroid Build Coastguard Worker        for idx in range(len(self.supported_np_dtypes)):
11067*da0073e9SAndroid Build Coastguard Worker            # torch.randn / torch.rand don't work with all dtypes
11068*da0073e9SAndroid Build Coastguard Worker            # Generate input data for all dtypes on Numpy them move to torch
11069*da0073e9SAndroid Build Coastguard Worker            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
11070*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
11071*da0073e9SAndroid Build Coastguard Worker
11072*da0073e9SAndroid Build Coastguard Worker            helper(inputCPU)
11073*da0073e9SAndroid Build Coastguard Worker
11074*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing_3D_put(self):
11075*da0073e9SAndroid Build Coastguard Worker        def helper(x_cpu):
11076*da0073e9SAndroid Build Coastguard Worker            dtype = x_cpu.dtype
11077*da0073e9SAndroid Build Coastguard Worker            x_mps = x_cpu.detach().clone().to("mps")
11078*da0073e9SAndroid Build Coastguard Worker
11079*da0073e9SAndroid Build Coastguard Worker            out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
11080*da0073e9SAndroid Build Coastguard Worker            out_tensor_cpu_view = out_tensor_cpu[1:]
11081*da0073e9SAndroid Build Coastguard Worker
11082*da0073e9SAndroid Build Coastguard Worker            out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
11083*da0073e9SAndroid Build Coastguard Worker            out_tensor_mps_view = out_tensor_mps[1:]
11084*da0073e9SAndroid Build Coastguard Worker
11085*da0073e9SAndroid Build Coastguard Worker            x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
11086*da0073e9SAndroid Build Coastguard Worker            x_mps[[1, 2], 3, :] = out_tensor_mps_view
11087*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_mps)
11088*da0073e9SAndroid Build Coastguard Worker
11089*da0073e9SAndroid Build Coastguard Worker            x_cpu[[0, 2], :, :] = out_tensor_cpu_view
11090*da0073e9SAndroid Build Coastguard Worker            x_mps[[0, 2], :, :] = out_tensor_mps_view
11091*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_mps)
11092*da0073e9SAndroid Build Coastguard Worker
11093*da0073e9SAndroid Build Coastguard Worker            x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
11094*da0073e9SAndroid Build Coastguard Worker            x_mps[:, [1, 0], [1]] = out_tensor_mps_view
11095*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_cpu, x_mps)
11096*da0073e9SAndroid Build Coastguard Worker
11097*da0073e9SAndroid Build Coastguard Worker        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
11098*da0073e9SAndroid Build Coastguard Worker                               [0.5, 0.6, 0.7, 0.8],
11099*da0073e9SAndroid Build Coastguard Worker                               [0.9, 1.0, 1.1, 1.2],
11100*da0073e9SAndroid Build Coastguard Worker                               [1.3, 1.4, 1.5, 1.6]],
11101*da0073e9SAndroid Build Coastguard Worker
11102*da0073e9SAndroid Build Coastguard Worker                              [[2.0, 2.1, 2.2, 2.3],
11103*da0073e9SAndroid Build Coastguard Worker                               [2.4, 2.5, 2.6, 2.7],
11104*da0073e9SAndroid Build Coastguard Worker                               [2.8, 2.9, 3.0, 3.1],
11105*da0073e9SAndroid Build Coastguard Worker                               [3.2, 3.3, 3.4, 3.5]],
11106*da0073e9SAndroid Build Coastguard Worker
11107*da0073e9SAndroid Build Coastguard Worker                              [[4.0, 4.1, 4.2, 4.3],
11108*da0073e9SAndroid Build Coastguard Worker                               [4.4, 4.5, 4.6, 4.7],
11109*da0073e9SAndroid Build Coastguard Worker                               [4.8, 4.9, 5.0, 5.1],
11110*da0073e9SAndroid Build Coastguard Worker                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
11111*da0073e9SAndroid Build Coastguard Worker        helper(x_cpu)
11112*da0073e9SAndroid Build Coastguard Worker        for idx in range(len(self.supported_np_dtypes)):
11113*da0073e9SAndroid Build Coastguard Worker            # torch.randn / torch.rand don't work with all dtypes
11114*da0073e9SAndroid Build Coastguard Worker            # Generate input data for all dtypes on Numpy them move to torch
11115*da0073e9SAndroid Build Coastguard Worker            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
11116*da0073e9SAndroid Build Coastguard Worker            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
11117*da0073e9SAndroid Build Coastguard Worker
11118*da0073e9SAndroid Build Coastguard Worker            helper(inputCPU)
11119*da0073e9SAndroid Build Coastguard Worker
11120*da0073e9SAndroid Build Coastguard Worker    def test_index_put_with_view_indices(self):
11121*da0073e9SAndroid Build Coastguard Worker        def helper(dtype):
11122*da0073e9SAndroid Build Coastguard Worker            target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
11123*da0073e9SAndroid Build Coastguard Worker            target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
11124*da0073e9SAndroid Build Coastguard Worker
11125*da0073e9SAndroid Build Coastguard Worker            indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
11126*da0073e9SAndroid Build Coastguard Worker            indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
11127*da0073e9SAndroid Build Coastguard Worker
11128*da0073e9SAndroid Build Coastguard Worker            value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
11129*da0073e9SAndroid Build Coastguard Worker            value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
11130*da0073e9SAndroid Build Coastguard Worker
11131*da0073e9SAndroid Build Coastguard Worker            target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
11132*da0073e9SAndroid Build Coastguard Worker            target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)
11133*da0073e9SAndroid Build Coastguard Worker
11134*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(target_cpu, target_mps)
11135*da0073e9SAndroid Build Coastguard Worker
11136*da0073e9SAndroid Build Coastguard Worker        [helper(dtype) for dtype in [torch.int32, torch.float]]
11137*da0073e9SAndroid Build Coastguard Worker
11138*da0073e9SAndroid Build Coastguard Worker    # tests from 'test_indexing.py'
11139*da0073e9SAndroid Build Coastguard Worker    def test_advancedindex_big(self, device="mps"):
11140*da0073e9SAndroid Build Coastguard Worker        reference = torch.arange(0, 123344, dtype=torch.int, device=device)
11141*da0073e9SAndroid Build Coastguard Worker
11142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
11143*da0073e9SAndroid Build Coastguard Worker                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
11144*da0073e9SAndroid Build Coastguard Worker
11145*da0073e9SAndroid Build Coastguard Worker    def test_set_item_to_scalar_tensor(self, device="mps"):
11146*da0073e9SAndroid Build Coastguard Worker        m = random.randint(1, 10)
11147*da0073e9SAndroid Build Coastguard Worker        n = random.randint(1, 10)
11148*da0073e9SAndroid Build Coastguard Worker        z = torch.randn([m, n], device=device)
11149*da0073e9SAndroid Build Coastguard Worker        a = 1.0
11150*da0073e9SAndroid Build Coastguard Worker        w = torch.tensor(a, requires_grad=True, device=device)
11151*da0073e9SAndroid Build Coastguard Worker        z[:, 0] = w
11152*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
11153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(w.grad, m * a)
11154*da0073e9SAndroid Build Coastguard Worker
11155*da0073e9SAndroid Build Coastguard Worker    def test_single_int(self, device="mps"):
11156*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4].shape, (7, 3))
11158*da0073e9SAndroid Build Coastguard Worker
11159*da0073e9SAndroid Build Coastguard Worker    def test_multiple_int(self, device="mps"):
11160*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4].shape, (7, 3))
11162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[4, :, 1].shape, (7,))
11163*da0073e9SAndroid Build Coastguard Worker
11164*da0073e9SAndroid Build Coastguard Worker    def test_none(self, device="mps"):
11165*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[None].shape, (1, 5, 7, 3))
11167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
11168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
11169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
11170*da0073e9SAndroid Build Coastguard Worker
11171*da0073e9SAndroid Build Coastguard Worker    def test_step(self, device="mps"):
11172*da0073e9SAndroid Build Coastguard Worker        v = torch.arange(10, device=device)
11173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::1], v)
11174*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
11175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
11176*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[::11].tolist(), [0])
11177*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
11178*da0073e9SAndroid Build Coastguard Worker
11179*da0073e9SAndroid Build Coastguard Worker    def test_step_assignment(self, device="mps"):
11180*da0073e9SAndroid Build Coastguard Worker        v = torch.zeros(4, 4, device=device)
11181*da0073e9SAndroid Build Coastguard Worker        v[0, 1::2] = torch.tensor([3., 4.], device=device)
11182*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
11183*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[1:].sum(), 0)
11184*da0073e9SAndroid Build Coastguard Worker
11185*da0073e9SAndroid Build Coastguard Worker    def test_bool_indices(self, device="mps"):
11186*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11187*da0073e9SAndroid Build Coastguard Worker        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
11188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
11189*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
11190*da0073e9SAndroid Build Coastguard Worker
11191*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
11192*da0073e9SAndroid Build Coastguard Worker        boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
11193*da0073e9SAndroid Build Coastguard Worker        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
11194*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
11195*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
11196*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[boolIndices], v[uint8Indices])
11197*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
11198*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
11199*da0073e9SAndroid Build Coastguard Worker
11200*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
11201*da0073e9SAndroid Build Coastguard Worker    def test_bool_indices_accumulate(self, device="mps"):
11202*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
11203*da0073e9SAndroid Build Coastguard Worker        mask = mask > 0
11204*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(size=(10, 10), device=device)
11205*da0073e9SAndroid Build Coastguard Worker        y.index_put_((mask, ), y[mask], accumulate=True)
11206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
11207*da0073e9SAndroid Build Coastguard Worker
11208*da0073e9SAndroid Build Coastguard Worker    def test_multiple_bool_indices(self, device="mps"):
11209*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11210*da0073e9SAndroid Build Coastguard Worker        # note: these broadcast together and are transposed to the first dim
11211*da0073e9SAndroid Build Coastguard Worker        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
11212*da0073e9SAndroid Build Coastguard Worker        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
11213*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
11214*da0073e9SAndroid Build Coastguard Worker
11215*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask(self, device="mps"):
11216*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11217*da0073e9SAndroid Build Coastguard Worker        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
11218*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
11219*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[mask].shape, (3, 7, 3))
11220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
11221*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
11222*da0073e9SAndroid Build Coastguard Worker
11223*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([1.], device=device)
11224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[v == 0], torch.tensor([], device=device))
11225*da0073e9SAndroid Build Coastguard Worker
11226*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask_accumulate(self, device="mps"):
11227*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
11228*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(size=(10, 10), device=device)
11229*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
11230*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
11231*da0073e9SAndroid Build Coastguard Worker            y.index_put_((mask, ), y[mask], accumulate=True)
11232*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
11233*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
11234*da0073e9SAndroid Build Coastguard Worker
11235*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_expanded_values(self, device="mps"):
11236*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((5, 2))
11237*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
11238*da0073e9SAndroid Build Coastguard Worker        indices = [
11239*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 2, 3]),
11240*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, ]),
11241*da0073e9SAndroid Build Coastguard Worker        ]
11242*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
11243*da0073e9SAndroid Build Coastguard Worker        values0d = torch.tensor(1.0)
11244*da0073e9SAndroid Build Coastguard Worker        values1d = torch.tensor([1.0, ])
11245*da0073e9SAndroid Build Coastguard Worker
11246*da0073e9SAndroid Build Coastguard Worker        out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
11247*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values0d, accumulate=True)
11248*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11249*da0073e9SAndroid Build Coastguard Worker
11250*da0073e9SAndroid Build Coastguard Worker        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
11251*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values1d, accumulate=True)
11252*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11253*da0073e9SAndroid Build Coastguard Worker
11254*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros(4, 3, 2)
11255*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
11256*da0073e9SAndroid Build Coastguard Worker
11257*da0073e9SAndroid Build Coastguard Worker        indices = [
11258*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, ]),
11259*da0073e9SAndroid Build Coastguard Worker            torch.arange(3)[:, None],
11260*da0073e9SAndroid Build Coastguard Worker            torch.arange(2)[None, :],
11261*da0073e9SAndroid Build Coastguard Worker        ]
11262*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
11263*da0073e9SAndroid Build Coastguard Worker        values1d = torch.tensor([-1.0, -2.0])
11264*da0073e9SAndroid Build Coastguard Worker        values2d = torch.tensor([[-1.0, -2.0], ])
11265*da0073e9SAndroid Build Coastguard Worker
11266*da0073e9SAndroid Build Coastguard Worker        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
11267*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values1d, accumulate=True)
11268*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11269*da0073e9SAndroid Build Coastguard Worker
11270*da0073e9SAndroid Build Coastguard Worker        out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
11271*da0073e9SAndroid Build Coastguard Worker        out_cpu = t.index_put_(indices, values2d, accumulate=True)
11272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11273*da0073e9SAndroid Build Coastguard Worker
11274*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_non_contiguous(self, device="mps"):
11275*da0073e9SAndroid Build Coastguard Worker        t = torch.zeros((5, 2, 2))
11276*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
11277*da0073e9SAndroid Build Coastguard Worker        t1 = t_dev[:, 0, :]
11278*da0073e9SAndroid Build Coastguard Worker        t2 = t[:, 0, :]
11279*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_contiguous())
11280*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t2.is_contiguous())
11281*da0073e9SAndroid Build Coastguard Worker
11282*da0073e9SAndroid Build Coastguard Worker        indices = [torch.tensor([0, 1]), ]
11283*da0073e9SAndroid Build Coastguard Worker        indices_dev = [i.to(device) for i in indices]
11284*da0073e9SAndroid Build Coastguard Worker        value = torch.randn(2, 2)
11285*da0073e9SAndroid Build Coastguard Worker        out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
11286*da0073e9SAndroid Build Coastguard Worker        out_cpu = t2.index_put_(indices, value, accumulate=True)
11287*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t1.is_contiguous())
11288*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t2.is_contiguous())
11289*da0073e9SAndroid Build Coastguard Worker
11290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11291*da0073e9SAndroid Build Coastguard Worker
11292*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
11293*da0073e9SAndroid Build Coastguard Worker        # TODO: replace with a better solution.
11294*da0073e9SAndroid Build Coastguard Worker        # Currently, here using torchscript to put None into indices.
11295*da0073e9SAndroid Build Coastguard Worker        # on C++ it gives indices as a list of 2 optional tensors: first is null and
11296*da0073e9SAndroid Build Coastguard Worker        # the second is a valid tensor.
11297*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
11298*da0073e9SAndroid Build Coastguard Worker        def func(x, i, v):
11299*da0073e9SAndroid Build Coastguard Worker            idx = [None, i]
11300*da0073e9SAndroid Build Coastguard Worker            x.index_put_(idx, v, accumulate=True)
11301*da0073e9SAndroid Build Coastguard Worker            return x
11302*da0073e9SAndroid Build Coastguard Worker
11303*da0073e9SAndroid Build Coastguard Worker        n = 4
11304*da0073e9SAndroid Build Coastguard Worker        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
11305*da0073e9SAndroid Build Coastguard Worker        t_dev = t.to(device)
11306*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([1, 0])
11307*da0073e9SAndroid Build Coastguard Worker        indices_dev = indices.to(device)
11308*da0073e9SAndroid Build Coastguard Worker        value0d = torch.tensor(10.0)
11309*da0073e9SAndroid Build Coastguard Worker        value1d = torch.tensor([1.0, 2.0])
11310*da0073e9SAndroid Build Coastguard Worker
11311*da0073e9SAndroid Build Coastguard Worker        out_mps = func(t_dev, indices_dev, value0d.to("mps"))
11312*da0073e9SAndroid Build Coastguard Worker        out_cpu = func(t, indices, value0d)
11313*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11314*da0073e9SAndroid Build Coastguard Worker
11315*da0073e9SAndroid Build Coastguard Worker        out_mps = func(t_dev, indices_dev, value1d.to("mps"))
11316*da0073e9SAndroid Build Coastguard Worker        out_cpu = func(t, indices, value1d)
11317*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_mps.cpu(), out_cpu)
11318*da0073e9SAndroid Build Coastguard Worker
11319*da0073e9SAndroid Build Coastguard Worker    def test_index_put_accumulate_duplicate_indices(self, device="mps"):
11320*da0073e9SAndroid Build Coastguard Worker        for i in range(1, 128):
11321*da0073e9SAndroid Build Coastguard Worker            # generate indices by random walk, this will create indices with
11322*da0073e9SAndroid Build Coastguard Worker            # lots of duplicates interleaved with each other
11323*da0073e9SAndroid Build Coastguard Worker            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
11324*da0073e9SAndroid Build Coastguard Worker
11325*da0073e9SAndroid Build Coastguard Worker            indices = delta.cumsum(0).long().to("mps")
11326*da0073e9SAndroid Build Coastguard Worker
11327*da0073e9SAndroid Build Coastguard Worker            # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
11328*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
11329*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(indices.size(0), device=device)
11330*da0073e9SAndroid Build Coastguard Worker            output = input.index_put((indices,), values, accumulate=True)
11331*da0073e9SAndroid Build Coastguard Worker
11332*da0073e9SAndroid Build Coastguard Worker            input_list = input.tolist()
11333*da0073e9SAndroid Build Coastguard Worker            indices_list = indices.tolist()
11334*da0073e9SAndroid Build Coastguard Worker            values_list = values.tolist()
11335*da0073e9SAndroid Build Coastguard Worker            for i, v in zip(indices_list, values_list):
11336*da0073e9SAndroid Build Coastguard Worker                input_list[i] += v
11337*da0073e9SAndroid Build Coastguard Worker
11338*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, input_list)
11339*da0073e9SAndroid Build Coastguard Worker
11340*da0073e9SAndroid Build Coastguard Worker    def test_index_put_deterministic(self, device="mps"):
11341*da0073e9SAndroid Build Coastguard Worker        def helper(dtype, accumulate, deterministic, num_tests=128):
11342*da0073e9SAndroid Build Coastguard Worker            acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
11343*da0073e9SAndroid Build Coastguard Worker            non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
11344*da0073e9SAndroid Build Coastguard Worker            t_idx = torch.tensor(
11345*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
11346*da0073e9SAndroid Build Coastguard Worker                 0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
11347*da0073e9SAndroid Build Coastguard Worker            )
11348*da0073e9SAndroid Build Coastguard Worker            for _ in range(num_tests):
11349*da0073e9SAndroid Build Coastguard Worker                try:
11350*da0073e9SAndroid Build Coastguard Worker                    torch.use_deterministic_algorithms(deterministic)
11351*da0073e9SAndroid Build Coastguard Worker                    t = torch.zeros(3, dtype=dtype, device=device)
11352*da0073e9SAndroid Build Coastguard Worker                    t.index_put_((t_idx,), torch.arange(len(t_idx), device=device, dtype=dtype), accumulate=accumulate)
11353*da0073e9SAndroid Build Coastguard Worker                    if accumulate:
11354*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(t, acc_expected)
11355*da0073e9SAndroid Build Coastguard Worker                    else:
11356*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(t, non_acc_expected)
11357*da0073e9SAndroid Build Coastguard Worker                finally:
11358*da0073e9SAndroid Build Coastguard Worker                    torch.use_deterministic_algorithms(False)
11359*da0073e9SAndroid Build Coastguard Worker
11360*da0073e9SAndroid Build Coastguard Worker        for accumulate, deterministic in product((False, True), (False, True)):
11361*da0073e9SAndroid Build Coastguard Worker            dtype = torch.float if accumulate else torch.long
11362*da0073e9SAndroid Build Coastguard Worker            if not accumulate and not deterministic:
11363*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
11364*da0073e9SAndroid Build Coastguard Worker                    helper(dtype, accumulate, deterministic)
11365*da0073e9SAndroid Build Coastguard Worker            else:
11366*da0073e9SAndroid Build Coastguard Worker                helper(dtype, accumulate, deterministic)
11367*da0073e9SAndroid Build Coastguard Worker
11368*da0073e9SAndroid Build Coastguard Worker    def test_multiple_byte_mask(self, device="mps"):
11369*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11370*da0073e9SAndroid Build Coastguard Worker        # note: these broadcast together and are transposed to the first dim
11371*da0073e9SAndroid Build Coastguard Worker        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
11372*da0073e9SAndroid Build Coastguard Worker        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
11373*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
11374*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
11375*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
11376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 2)
11377*da0073e9SAndroid Build Coastguard Worker
11378*da0073e9SAndroid Build Coastguard Worker    def test_byte_mask2d(self, device="mps"):
11379*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11380*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(5, 7, device=device)
11381*da0073e9SAndroid Build Coastguard Worker        num_ones = (c > 0).sum()
11382*da0073e9SAndroid Build Coastguard Worker        r = v[c > 0]
11383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r.shape, (num_ones, 3))
11384*da0073e9SAndroid Build Coastguard Worker
11385*da0073e9SAndroid Build Coastguard Worker    def test_jit_indexing(self, device="mps"):
11386*da0073e9SAndroid Build Coastguard Worker        def fn1(x):
11387*da0073e9SAndroid Build Coastguard Worker            x[x < 50] = 1.0
11388*da0073e9SAndroid Build Coastguard Worker            return x
11389*da0073e9SAndroid Build Coastguard Worker
11390*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
11391*da0073e9SAndroid Build Coastguard Worker            x[0:50] = 1.0
11392*da0073e9SAndroid Build Coastguard Worker            return x
11393*da0073e9SAndroid Build Coastguard Worker
11394*da0073e9SAndroid Build Coastguard Worker        scripted_fn1 = torch.jit.script(fn1)
11395*da0073e9SAndroid Build Coastguard Worker        scripted_fn2 = torch.jit.script(fn2)
11396*da0073e9SAndroid Build Coastguard Worker        data = torch.arange(100, device=device, dtype=torch.float)
11397*da0073e9SAndroid Build Coastguard Worker        out = scripted_fn1(data.detach().clone())
11398*da0073e9SAndroid Build Coastguard Worker        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
11399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref)
11400*da0073e9SAndroid Build Coastguard Worker        out = scripted_fn2(data.detach().clone())
11401*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ref)
11402*da0073e9SAndroid Build Coastguard Worker
11403*da0073e9SAndroid Build Coastguard Worker    def test_int_indices(self, device="mps"):
11404*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 7, 3, device=device)
11405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
11406*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
11407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
11408*da0073e9SAndroid Build Coastguard Worker
11409*da0073e9SAndroid Build Coastguard Worker    def test_index_put_src_datatype(self):
11410*da0073e9SAndroid Build Coastguard Worker        def helper(device, dtype):
11411*da0073e9SAndroid Build Coastguard Worker            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11412*da0073e9SAndroid Build Coastguard Worker            vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
11413*da0073e9SAndroid Build Coastguard Worker            indices = (torch.tensor([0, 2, 1]),)
11414*da0073e9SAndroid Build Coastguard Worker            res = src.index_put_(indices, vals, accumulate=True)
11415*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, src.shape)
11416*da0073e9SAndroid Build Coastguard Worker        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
11417*da0073e9SAndroid Build Coastguard Worker
11418*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
11419*da0073e9SAndroid Build Coastguard Worker    def test_index_src_datatype(self):
11420*da0073e9SAndroid Build Coastguard Worker        def helper(device, dtype):
11421*da0073e9SAndroid Build Coastguard Worker            orig_dtype = dtype
11422*da0073e9SAndroid Build Coastguard Worker            if dtype is torch.bool:
11423*da0073e9SAndroid Build Coastguard Worker                dtype = torch.uint8
11424*da0073e9SAndroid Build Coastguard Worker
11425*da0073e9SAndroid Build Coastguard Worker            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11426*da0073e9SAndroid Build Coastguard Worker            if orig_dtype is torch.bool:
11427*da0073e9SAndroid Build Coastguard Worker                src = src == 1
11428*da0073e9SAndroid Build Coastguard Worker            # test index
11429*da0073e9SAndroid Build Coastguard Worker            res = src[[0, 2, 1], :, :]
11430*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, src.shape)
11431*da0073e9SAndroid Build Coastguard Worker            # test index_put, no accum
11432*da0073e9SAndroid Build Coastguard Worker            src[[0, 2, 1], :, :] = res
11433*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.shape, src.shape)
11434*da0073e9SAndroid Build Coastguard Worker        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]
11435*da0073e9SAndroid Build Coastguard Worker
11436*da0073e9SAndroid Build Coastguard Worker    def test_int_indices2d(self, device="mps"):
11437*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
11438*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
11439*da0073e9SAndroid Build Coastguard Worker        rows = torch.tensor([[0, 0], [3, 3]], device=device)
11440*da0073e9SAndroid Build Coastguard Worker        columns = torch.tensor([[0, 2], [0, 2]], device=device)
11441*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
11442*da0073e9SAndroid Build Coastguard Worker
11443*da0073e9SAndroid Build Coastguard Worker    def test_int_indices_broadcast(self, device="mps"):
11444*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
11445*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
11446*da0073e9SAndroid Build Coastguard Worker        rows = torch.tensor([0, 3], device=device)
11447*da0073e9SAndroid Build Coastguard Worker        columns = torch.tensor([0, 2], device=device)
11448*da0073e9SAndroid Build Coastguard Worker        result = x[rows[:, None], columns]
11449*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
11450*da0073e9SAndroid Build Coastguard Worker
11451*da0073e9SAndroid Build Coastguard Worker    def test_empty_index(self, device="mps"):
11452*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
11453*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([], dtype=torch.long, device=device)
11454*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[idx].numel(), 0)
11455*da0073e9SAndroid Build Coastguard Worker
11456*da0073e9SAndroid Build Coastguard Worker        # empty assignment should have no effect but not throw an exception
11457*da0073e9SAndroid Build Coastguard Worker        y = x.clone()
11458*da0073e9SAndroid Build Coastguard Worker        y[idx] = -1
11459*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
11460*da0073e9SAndroid Build Coastguard Worker
11461*da0073e9SAndroid Build Coastguard Worker        mask = torch.zeros(4, 3, device=device).bool()
11462*da0073e9SAndroid Build Coastguard Worker        y[mask] = -1
11463*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, y)
11464*da0073e9SAndroid Build Coastguard Worker
11465*da0073e9SAndroid Build Coastguard Worker    def test_empty_ndim_index(self, device="mps"):
11466*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, device=device)
11467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
11468*da0073e9SAndroid Build Coastguard Worker
11469*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, 5, device=device)
11470*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
11471*da0073e9SAndroid Build Coastguard Worker                         x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])
11472*da0073e9SAndroid Build Coastguard Worker
11473*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(10, 0, device=device)
11474*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[[1, 2]].shape, (2, 0))
11475*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[[], []].shape, (0,))
11476*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
11477*da0073e9SAndroid Build Coastguard Worker            x[:, [0, 1]]
11478*da0073e9SAndroid Build Coastguard Worker
11479*da0073e9SAndroid Build Coastguard Worker    def test_empty_ndim_index_bool(self, device="mps"):
11480*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, device=device)
11481*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
11482*da0073e9SAndroid Build Coastguard Worker
11483*da0073e9SAndroid Build Coastguard Worker    def test_empty_slice(self, device="mps"):
11484*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, 5, device=device)
11485*da0073e9SAndroid Build Coastguard Worker        y = x[:, :, :, 1]
11486*da0073e9SAndroid Build Coastguard Worker        z = y[:, 1:1, :]
11487*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((2, 0, 4), z.shape)
11488*da0073e9SAndroid Build Coastguard Worker        # this isn't technically necessary, but matches NumPy stride calculations.
11489*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((60, 20, 5), z.stride())
11490*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.is_contiguous())
11491*da0073e9SAndroid Build Coastguard Worker
11492*da0073e9SAndroid Build Coastguard Worker    def test_index_getitem_copy_bools_slices(self, device="mps"):
11493*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(1, dtype=torch.uint8, device=device)
11494*da0073e9SAndroid Build Coastguard Worker        false = torch.tensor(0, dtype=torch.uint8, device=device)
11495*da0073e9SAndroid Build Coastguard Worker
11496*da0073e9SAndroid Build Coastguard Worker        tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]
11497*da0073e9SAndroid Build Coastguard Worker
11498*da0073e9SAndroid Build Coastguard Worker        for a in tensors:
11499*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
11500*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.empty(0, *a.shape), a[False])
11501*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
11502*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.empty(0, *a.shape), a[false])
11503*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.data_ptr(), a[None].data_ptr())
11504*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.data_ptr(), a[...].data_ptr())
11505*da0073e9SAndroid Build Coastguard Worker
11506*da0073e9SAndroid Build Coastguard Worker    def test_index_setitem_bools_slices(self, device="mps"):
11507*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(1, dtype=torch.uint8, device=device)
11508*da0073e9SAndroid Build Coastguard Worker        false = torch.tensor(0, dtype=torch.uint8, device=device)
11509*da0073e9SAndroid Build Coastguard Worker
11510*da0073e9SAndroid Build Coastguard Worker        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
11511*da0073e9SAndroid Build Coastguard Worker
11512*da0073e9SAndroid Build Coastguard Worker        for a in tensors:
11513*da0073e9SAndroid Build Coastguard Worker            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
11514*da0073e9SAndroid Build Coastguard Worker            # (some of these ops already prefix a 1 to the size)
11515*da0073e9SAndroid Build Coastguard Worker            neg_ones = torch.ones_like(a) * -1
11516*da0073e9SAndroid Build Coastguard Worker            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
11517*da0073e9SAndroid Build Coastguard Worker            a[True] = neg_ones_expanded
11518*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones)
11519*da0073e9SAndroid Build Coastguard Worker            a[False] = 5
11520*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones)
11521*da0073e9SAndroid Build Coastguard Worker            a[true] = neg_ones_expanded * 2
11522*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 2)
11523*da0073e9SAndroid Build Coastguard Worker            a[false] = 5
11524*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 2)
11525*da0073e9SAndroid Build Coastguard Worker            a[None] = neg_ones_expanded * 3
11526*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 3)
11527*da0073e9SAndroid Build Coastguard Worker            a[...] = neg_ones_expanded * 4
11528*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, neg_ones * 4)
11529*da0073e9SAndroid Build Coastguard Worker            if a.dim() == 0:
11530*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(IndexError):
11531*da0073e9SAndroid Build Coastguard Worker                    a[:] = neg_ones_expanded * 5
11532*da0073e9SAndroid Build Coastguard Worker
11533*da0073e9SAndroid Build Coastguard Worker    def test_index_scalar_with_bool_mask(self, device="mps"):
11534*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1, device=device)
11535*da0073e9SAndroid Build Coastguard Worker        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
11536*da0073e9SAndroid Build Coastguard Worker        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
11537*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask], a[boolMask])
11538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11539*da0073e9SAndroid Build Coastguard Worker
11540*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(True, dtype=torch.bool, device=device)
11541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask], a[boolMask])
11542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11543*da0073e9SAndroid Build Coastguard Worker
11544*da0073e9SAndroid Build Coastguard Worker    def test_setitem_expansion_error(self, device="mps"):
11545*da0073e9SAndroid Build Coastguard Worker        true = torch.tensor(True, device=device)
11546*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
11547*da0073e9SAndroid Build Coastguard Worker        # check prefix with  non-1s doesn't work
11548*da0073e9SAndroid Build Coastguard Worker        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
11549*da0073e9SAndroid Build Coastguard Worker        # NumPy: ValueError
11550*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11551*da0073e9SAndroid Build Coastguard Worker            a[True] = a_expanded
11552*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11553*da0073e9SAndroid Build Coastguard Worker            a[true] = a_expanded
11554*da0073e9SAndroid Build Coastguard Worker
11555*da0073e9SAndroid Build Coastguard Worker    def test_getitem_scalars(self, device="mps"):
11556*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, dtype=torch.int64, device=device)
11557*da0073e9SAndroid Build Coastguard Worker        one = torch.tensor(1, dtype=torch.int64, device=device)
11558*da0073e9SAndroid Build Coastguard Worker
11559*da0073e9SAndroid Build Coastguard Worker        # non-scalar indexed with scalars
11560*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
11561*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0], a[zero])
11562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0][1], a[zero][one])
11563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 1], a[zero, one])
11564*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, one], a[zero, 1])
11565*da0073e9SAndroid Build Coastguard Worker
11566*da0073e9SAndroid Build Coastguard Worker        # indexing by a scalar should slice (not copy)
11567*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
11568*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
11569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
11570*da0073e9SAndroid Build Coastguard Worker
11571*da0073e9SAndroid Build Coastguard Worker        # scalar indexed with scalar
11572*da0073e9SAndroid Build Coastguard Worker        r = torch.randn((), device=device)
11573*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
11574*da0073e9SAndroid Build Coastguard Worker            r[:]
11575*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
11576*da0073e9SAndroid Build Coastguard Worker            r[zero]
11577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r, r[...])
11578*da0073e9SAndroid Build Coastguard Worker
11579*da0073e9SAndroid Build Coastguard Worker    def test_setitem_scalars(self, device="mps"):
11580*da0073e9SAndroid Build Coastguard Worker        zero = torch.tensor(0, dtype=torch.int64)
11581*da0073e9SAndroid Build Coastguard Worker
11582*da0073e9SAndroid Build Coastguard Worker        # non-scalar indexed with scalars
11583*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, device=device)
11584*da0073e9SAndroid Build Coastguard Worker        a_set_with_number = a.clone()
11585*da0073e9SAndroid Build Coastguard Worker        a_set_with_scalar = a.clone()
11586*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, device=device)
11587*da0073e9SAndroid Build Coastguard Worker
11588*da0073e9SAndroid Build Coastguard Worker        a_set_with_number[0] = b
11589*da0073e9SAndroid Build Coastguard Worker        a_set_with_scalar[zero] = b
11590*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_set_with_number, a_set_with_scalar)
11591*da0073e9SAndroid Build Coastguard Worker        a[1, zero] = 7.7
11592*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(7.7, a[1, 0])
11593*da0073e9SAndroid Build Coastguard Worker
11594*da0073e9SAndroid Build Coastguard Worker        # scalar indexed with scalars
11595*da0073e9SAndroid Build Coastguard Worker        r = torch.randn((), device=device)
11596*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
11597*da0073e9SAndroid Build Coastguard Worker            r[:] = 8.8
11598*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(IndexError):
11599*da0073e9SAndroid Build Coastguard Worker            r[zero] = 8.8
11600*da0073e9SAndroid Build Coastguard Worker        r[...] = 9.9
11601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(9.9, r)
11602*da0073e9SAndroid Build Coastguard Worker
11603*da0073e9SAndroid Build Coastguard Worker    def test_basic_advanced_combined(self, device="mps"):
11604*da0073e9SAndroid Build Coastguard Worker        # From the NumPy indexing example
11605*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 12, device=device).view(4, 3)
11606*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
11607*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
11608*da0073e9SAndroid Build Coastguard Worker
11609*da0073e9SAndroid Build Coastguard Worker        # Check that it is a copy
11610*da0073e9SAndroid Build Coastguard Worker        unmodified = x.clone()
11611*da0073e9SAndroid Build Coastguard Worker        x[1:2, [1, 2]].zero_()
11612*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, unmodified)
11613*da0073e9SAndroid Build Coastguard Worker
11614*da0073e9SAndroid Build Coastguard Worker        # But assignment should modify the original
11615*da0073e9SAndroid Build Coastguard Worker        unmodified = x.clone()
11616*da0073e9SAndroid Build Coastguard Worker        x[1:2, [1, 2]] = 0
11617*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(x, unmodified)
11618*da0073e9SAndroid Build Coastguard Worker
11619*da0073e9SAndroid Build Coastguard Worker    def test_int_assignment(self, device="mps"):
11620*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 4, device=device).view(2, 2)
11621*da0073e9SAndroid Build Coastguard Worker        x[1] = 5
11622*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
11623*da0073e9SAndroid Build Coastguard Worker
11624*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 4, device=device).view(2, 2)
11625*da0073e9SAndroid Build Coastguard Worker        x[1] = torch.arange(5, 7, device=device)
11626*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
11627*da0073e9SAndroid Build Coastguard Worker
11628*da0073e9SAndroid Build Coastguard Worker    def test_byte_tensor_assignment(self, device="mps"):
11629*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0., 16, device=device).view(4, 4)
11630*da0073e9SAndroid Build Coastguard Worker        b = torch.ByteTensor([True, False, True, False]).to(device)
11631*da0073e9SAndroid Build Coastguard Worker        value = torch.tensor([3., 4., 5., 6.], device=device)
11632*da0073e9SAndroid Build Coastguard Worker
11633*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
11634*da0073e9SAndroid Build Coastguard Worker            x[b] = value
11635*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(w), 1)
11636*da0073e9SAndroid Build Coastguard Worker
11637*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[0], value)
11638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[1], torch.arange(4., 8, device=device))
11639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[2], value)
11640*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[3], torch.arange(12., 16, device=device))
11641*da0073e9SAndroid Build Coastguard Worker
11642*da0073e9SAndroid Build Coastguard Worker    def test_variable_slicing(self, device="mps"):
11643*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 16, device=device).view(4, 4)
11644*da0073e9SAndroid Build Coastguard Worker        indices = torch.IntTensor([0, 1]).to(device)
11645*da0073e9SAndroid Build Coastguard Worker        i, j = indices
11646*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[i:j], x[0:1])
11647*da0073e9SAndroid Build Coastguard Worker
11648*da0073e9SAndroid Build Coastguard Worker    def test_ellipsis_tensor(self, device="mps"):
11649*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 9, device=device).view(3, 3)
11650*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 2], device=device)
11651*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[..., idx].tolist(), [[0, 2],
11652*da0073e9SAndroid Build Coastguard Worker                                                [3, 5],
11653*da0073e9SAndroid Build Coastguard Worker                                                [6, 8]])
11654*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
11655*da0073e9SAndroid Build Coastguard Worker                                                [6, 7, 8]])
11656*da0073e9SAndroid Build Coastguard Worker
11657*da0073e9SAndroid Build Coastguard Worker    def test_invalid_index(self, device="mps"):
11658*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 16, device=device).view(4, 4)
11659*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
11660*da0073e9SAndroid Build Coastguard Worker
11661*da0073e9SAndroid Build Coastguard Worker    def test_out_of_bound_index(self, device="mps"):
11662*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, 100, device=device).view(2, 5, 10)
11663*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
11664*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
11665*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
11666*da0073e9SAndroid Build Coastguard Worker                               lambda: x[0, 1, 15])
11667*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
11668*da0073e9SAndroid Build Coastguard Worker                               lambda: x[:, :, 12])
11669*da0073e9SAndroid Build Coastguard Worker
11670*da0073e9SAndroid Build Coastguard Worker    def test_zero_dim_index(self, device="mps"):
11671*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(10, device=device)
11672*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.item())
11673*da0073e9SAndroid Build Coastguard Worker
11674*da0073e9SAndroid Build Coastguard Worker        def runner():
11675*da0073e9SAndroid Build Coastguard Worker            print(x[0])
11676*da0073e9SAndroid Build Coastguard Worker            return x[0]
11677*da0073e9SAndroid Build Coastguard Worker
11678*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(IndexError, 'invalid index', runner)
11679*da0073e9SAndroid Build Coastguard Worker
11680*da0073e9SAndroid Build Coastguard Worker    def test_cpu_indices(self, device="mps"):
11681*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 1])
11682*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(2, device=device)
11683*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(10, device=device)
11684*da0073e9SAndroid Build Coastguard Worker        x[idx] = b  # index_put_
11685*da0073e9SAndroid Build Coastguard Worker        ref = torch.ones(10, device=device)
11686*da0073e9SAndroid Build Coastguard Worker        ref[:2] = 0
11687*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, ref, atol=0, rtol=0)
11688*da0073e9SAndroid Build Coastguard Worker        out = x[idx]  # index
11689*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
11690*da0073e9SAndroid Build Coastguard Worker
11691*da0073e9SAndroid Build Coastguard Worker    def test_nextafter(self, device="mps"):
11692*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float16, torch.float32]:
11693*da0073e9SAndroid Build Coastguard Worker            x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
11694*da0073e9SAndroid Build Coastguard Worker            y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
11695*da0073e9SAndroid Build Coastguard Worker            na = torch.nextafter(x, y)
11696*da0073e9SAndroid Build Coastguard Worker            na_cpu = torch.nextafter(x.cpu(), y.cpu())
11697*da0073e9SAndroid Build Coastguard Worker            na_ge_x_mps = na.cpu() > x.cpu()
11698*da0073e9SAndroid Build Coastguard Worker            # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
11699*da0073e9SAndroid Build Coastguard Worker            na_ge_x_cpu = na_cpu > x.cpu()
11700*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(na_ge_x_mps, na_ge_x_cpu)
11701*da0073e9SAndroid Build Coastguard Worker
11702*da0073e9SAndroid Build Coastguard Worker
11703*da0073e9SAndroid Build Coastguard Workerclass TestRNNMPS(TestCaseMPS):
11704*da0073e9SAndroid Build Coastguard Worker    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
11705*da0073e9SAndroid Build Coastguard Worker                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
11706*da0073e9SAndroid Build Coastguard Worker        rnn = nn.LSTM(
11707*da0073e9SAndroid Build Coastguard Worker            input_size=input_size,
11708*da0073e9SAndroid Build Coastguard Worker            hidden_size=hidden_size,
11709*da0073e9SAndroid Build Coastguard Worker            num_layers=num_layers,
11710*da0073e9SAndroid Build Coastguard Worker            bias=bias,
11711*da0073e9SAndroid Build Coastguard Worker            bidirectional=bidirectional,
11712*da0073e9SAndroid Build Coastguard Worker            batch_first=batch_first,
11713*da0073e9SAndroid Build Coastguard Worker            device="cpu"
11714*da0073e9SAndroid Build Coastguard Worker        )
11715*da0073e9SAndroid Build Coastguard Worker        bidirectional_mul = 2 if bidirectional else 1
11716*da0073e9SAndroid Build Coastguard Worker
11717*da0073e9SAndroid Build Coastguard Worker        if batch_first:
11718*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11719*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11720*da0073e9SAndroid Build Coastguard Worker                             requires_grad=backward)
11721*da0073e9SAndroid Build Coastguard Worker            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11722*da0073e9SAndroid Build Coastguard Worker                             requires_grad=backward)
11723*da0073e9SAndroid Build Coastguard Worker        else:
11724*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11725*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11726*da0073e9SAndroid Build Coastguard Worker                             requires_grad=backward)
11727*da0073e9SAndroid Build Coastguard Worker            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11728*da0073e9SAndroid Build Coastguard Worker                             requires_grad=backward)
11729*da0073e9SAndroid Build Coastguard Worker
11730*da0073e9SAndroid Build Coastguard Worker        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
11731*da0073e9SAndroid Build Coastguard Worker
11732*da0073e9SAndroid Build Coastguard Worker        rnn = rnn.to(device)
11733*da0073e9SAndroid Build Coastguard Worker        input = input.to(device)
11734*da0073e9SAndroid Build Coastguard Worker        hx = hx.to(device)
11735*da0073e9SAndroid Build Coastguard Worker        cx = cx.to(device)
11736*da0073e9SAndroid Build Coastguard Worker        output, (hn, cn) = rnn(input, (hx, cx))
11737*da0073e9SAndroid Build Coastguard Worker
11738*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_output, output)
11739*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_hn, hn)
11740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cpu_cn, cn)
11741*da0073e9SAndroid Build Coastguard Worker
11742*da0073e9SAndroid Build Coastguard Worker        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
11743*da0073e9SAndroid Build Coastguard Worker            rnn = rnn.to(device)
11744*da0073e9SAndroid Build Coastguard Worker            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
11745*da0073e9SAndroid Build Coastguard Worker
11746*da0073e9SAndroid Build Coastguard Worker            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
11747*da0073e9SAndroid Build Coastguard Worker            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"
11748*da0073e9SAndroid Build Coastguard Worker
11749*da0073e9SAndroid Build Coastguard Worker            f = 0
11750*da0073e9SAndroid Build Coastguard Worker            if output_grad_presented:
11751*da0073e9SAndroid Build Coastguard Worker                f = f + 3 * output.sum()
11752*da0073e9SAndroid Build Coastguard Worker            if states_grad_presented:
11753*da0073e9SAndroid Build Coastguard Worker                f = f + (hx_out * cx_out).sum()
11754*da0073e9SAndroid Build Coastguard Worker
11755*da0073e9SAndroid Build Coastguard Worker            param_names, params = zip(*rnn.named_parameters())
11756*da0073e9SAndroid Build Coastguard Worker            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
11757*da0073e9SAndroid Build Coastguard Worker
11758*da0073e9SAndroid Build Coastguard Worker            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
11759*da0073e9SAndroid Build Coastguard Worker            return output, param_grads, input_grad, hx_grad, cx_grad
11760*da0073e9SAndroid Build Coastguard Worker
11761*da0073e9SAndroid Build Coastguard Worker        if backward:
11762*da0073e9SAndroid Build Coastguard Worker            grad_cases = [
11763*da0073e9SAndroid Build Coastguard Worker                dict(output_grad_presented=True, states_grad_presented=True),
11764*da0073e9SAndroid Build Coastguard Worker                dict(output_grad_presented=False, states_grad_presented=True),
11765*da0073e9SAndroid Build Coastguard Worker                dict(output_grad_presented=True, states_grad_presented=False),
11766*da0073e9SAndroid Build Coastguard Worker            ]
11767*da0073e9SAndroid Build Coastguard Worker
11768*da0073e9SAndroid Build Coastguard Worker            for grad_case in grad_cases:
11769*da0073e9SAndroid Build Coastguard Worker                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
11770*da0073e9SAndroid Build Coastguard Worker                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
11771*da0073e9SAndroid Build Coastguard Worker                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
11772*da0073e9SAndroid Build Coastguard Worker                    get_backward_results(rnn, device, input, hx, cx, **grad_case)
11773*da0073e9SAndroid Build Coastguard Worker
11774*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_hx_grad, mps_hx_grad)
11775*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_cx_grad, mps_cx_grad)
11776*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_output, mps_output)
11777*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cpu_input_grad, mps_input_grad)
11778*da0073e9SAndroid Build Coastguard Worker                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
11779*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
11780*da0073e9SAndroid Build Coastguard Worker                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
11781*da0073e9SAndroid Build Coastguard Worker
11782*da0073e9SAndroid Build Coastguard Worker    LSTM_TEST_CASES = [
11783*da0073e9SAndroid Build Coastguard Worker        {},  # default
11784*da0073e9SAndroid Build Coastguard Worker        dict(batch_first=True),
11785*da0073e9SAndroid Build Coastguard Worker        dict(bias=False),
11786*da0073e9SAndroid Build Coastguard Worker        dict(bidirectional=True),
11787*da0073e9SAndroid Build Coastguard Worker        dict(batch_first=True, bias=False),
11788*da0073e9SAndroid Build Coastguard Worker        dict(bidirectional=True, bias=False),
11789*da0073e9SAndroid Build Coastguard Worker        dict(bidirectional=True, batch_first=True),
11790*da0073e9SAndroid Build Coastguard Worker        dict(bidirectional=True, batch_first=True, bias=False)
11791*da0073e9SAndroid Build Coastguard Worker    ]
11792*da0073e9SAndroid Build Coastguard Worker
11793*da0073e9SAndroid Build Coastguard Worker    def test_lstm_forward(self, device="mps", dtype=torch.float32):
11794*da0073e9SAndroid Build Coastguard Worker        for num_layers in [1, 2, 5]:
11795*da0073e9SAndroid Build Coastguard Worker            for test_options in self.LSTM_TEST_CASES:
11796*da0073e9SAndroid Build Coastguard Worker                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)
11797*da0073e9SAndroid Build Coastguard Worker
11798*da0073e9SAndroid Build Coastguard Worker    def test_lstm_backward(self, device="mps", dtype=torch.float32):
11799*da0073e9SAndroid Build Coastguard Worker        for num_layers in [1, 2, 5]:
11800*da0073e9SAndroid Build Coastguard Worker            for test_options in self.LSTM_TEST_CASES:
11801*da0073e9SAndroid Build Coastguard Worker                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
11802*da0073e9SAndroid Build Coastguard Worker
11803*da0073e9SAndroid Build Coastguard Worker    def test_RNN_cell_no_broadcasting(self):
11804*da0073e9SAndroid Build Coastguard Worker        def test(cell_module, input, hx, input_size, hidden_size):
11805*da0073e9SAndroid Build Coastguard Worker            cell = cell_module(input_size, hidden_size, device='mps')
11806*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: cell(input, hx))
11807*da0073e9SAndroid Build Coastguard Worker
11808*da0073e9SAndroid Build Coastguard Worker        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
11809*da0073e9SAndroid Build Coastguard Worker            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
11810*da0073e9SAndroid Build Coastguard Worker            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
11811*da0073e9SAndroid Build Coastguard Worker            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
11812*da0073e9SAndroid Build Coastguard Worker            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
11813*da0073e9SAndroid Build Coastguard Worker
11814*da0073e9SAndroid Build Coastguard Worker        hidden_size = 20
11815*da0073e9SAndroid Build Coastguard Worker        input_size = 10
11816*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, input_size, device='mps')
11817*da0073e9SAndroid Build Coastguard Worker        bad_hx = torch.randn(1, hidden_size, device='mps')
11818*da0073e9SAndroid Build Coastguard Worker        good_hx = torch.randn(3, hidden_size, device='mps')
11819*da0073e9SAndroid Build Coastguard Worker
11820*da0073e9SAndroid Build Coastguard Worker        # Test hidden/input batch size broadcasting
11821*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, bad_hx, good_hx, input_size, input)
11822*da0073e9SAndroid Build Coastguard Worker
11823*da0073e9SAndroid Build Coastguard Worker        # Test hx's hidden_size vs module's hidden_size broadcasting
11824*da0073e9SAndroid Build Coastguard Worker        bad_hx = torch.randn(3, 1)
11825*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, bad_hx, good_hx, input_size, input)
11826*da0073e9SAndroid Build Coastguard Worker
11827*da0073e9SAndroid Build Coastguard Worker        # Test input's input_size vs module's input_size broadcasting
11828*da0073e9SAndroid Build Coastguard Worker        bad_input = torch.randn(3, 1)
11829*da0073e9SAndroid Build Coastguard Worker        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
11830*da0073e9SAndroid Build Coastguard Worker
11831*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell(self):
11832*da0073e9SAndroid Build Coastguard Worker        # this is just a smoke test; these modules are implemented through
11833*da0073e9SAndroid Build Coastguard Worker        # autograd so no Jacobian test is needed
11834*da0073e9SAndroid Build Coastguard Worker        for bias in (True, False):
11835*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(3, 10, device='mps')
11836*da0073e9SAndroid Build Coastguard Worker            hx = torch.randn(3, 20, device='mps')
11837*da0073e9SAndroid Build Coastguard Worker            cx = torch.randn(3, 20, device='mps')
11838*da0073e9SAndroid Build Coastguard Worker            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
11839*da0073e9SAndroid Build Coastguard Worker            for _ in range(6):
11840*da0073e9SAndroid Build Coastguard Worker                hx, cx = lstm(input, (hx, cx))
11841*da0073e9SAndroid Build Coastguard Worker
11842*da0073e9SAndroid Build Coastguard Worker            (hx + cx).sum().backward()
11843*da0073e9SAndroid Build Coastguard Worker
11844*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell_forward_input_size(self):
11845*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 11, device='mps')
11846*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(3, 20, device='mps')
11847*da0073e9SAndroid Build Coastguard Worker        cx = torch.randn(3, 20, device='mps')
11848*da0073e9SAndroid Build Coastguard Worker        lstm = nn.LSTMCell(10, 20, device='mps')
11849*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11850*da0073e9SAndroid Build Coastguard Worker
11851*da0073e9SAndroid Build Coastguard Worker    def test_LSTM_cell_forward_hidden_size(self):
11852*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(3, 10, device='mps')
11853*da0073e9SAndroid Build Coastguard Worker        hx = torch.randn(3, 21, device='mps')
11854*da0073e9SAndroid Build Coastguard Worker        cx = torch.randn(3, 20, device='mps')
11855*da0073e9SAndroid Build Coastguard Worker        lstm = nn.LSTMCell(10, 20, device='mps')
11856*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11857*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
11858*da0073e9SAndroid Build Coastguard Worker
11859*da0073e9SAndroid Build Coastguard Worker
11860*da0073e9SAndroid Build Coastguard Workerclass TestFallbackWarning(TestCase):
11861*da0073e9SAndroid Build Coastguard Worker    # TODO: Remove once test_testing.py is running on MPS devices
11862*da0073e9SAndroid Build Coastguard Worker    def test_no_warning_on_import(self):
11863*da0073e9SAndroid Build Coastguard Worker        out = subprocess.check_output(
11864*da0073e9SAndroid Build Coastguard Worker            [sys.executable, "-W", "always", "-c", "import torch"],
11865*da0073e9SAndroid Build Coastguard Worker            stderr=subprocess.STDOUT,
11866*da0073e9SAndroid Build Coastguard Worker            # On Windows, opening the subprocess with the default CWD makes `import torch`
11867*da0073e9SAndroid Build Coastguard Worker            # fail, so just set CWD to this script's directory
11868*da0073e9SAndroid Build Coastguard Worker            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
11869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, "")
11870*da0073e9SAndroid Build Coastguard Worker
11871*da0073e9SAndroid Build Coastguard Worker    def _get_not_implemented_op(self):
11872*da0073e9SAndroid Build Coastguard Worker        # This can be changed once we actually implement 'lcm'
11873*da0073e9SAndroid Build Coastguard Worker        # Should return fn, args, kwargs, string_version
11874*da0073e9SAndroid Build Coastguard Worker        return (torch.lcm,
11875*da0073e9SAndroid Build Coastguard Worker                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
11876*da0073e9SAndroid Build Coastguard Worker                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")
11877*da0073e9SAndroid Build Coastguard Worker
11878*da0073e9SAndroid Build Coastguard Worker    def test_error_on_not_implemented(self):
11879*da0073e9SAndroid Build Coastguard Worker        fn, args, kwargs, _ = self._get_not_implemented_op()
11880*da0073e9SAndroid Build Coastguard Worker
11881*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
11882*da0073e9SAndroid Build Coastguard Worker            fn(*args, **kwargs)
11883*da0073e9SAndroid Build Coastguard Worker
11884*da0073e9SAndroid Build Coastguard Worker    def test_warn_on_not_implemented_with_fallback(self):
11885*da0073e9SAndroid Build Coastguard Worker        _, _, _, op = self._get_not_implemented_op()
11886*da0073e9SAndroid Build Coastguard Worker        script = f"""
11887*da0073e9SAndroid Build Coastguard Workerimport os
11888*da0073e9SAndroid Build Coastguard Worker# MUST happen before pytorch's import
11889*da0073e9SAndroid Build Coastguard Workeros.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
11890*da0073e9SAndroid Build Coastguard Workerimport warnings
11891*da0073e9SAndroid Build Coastguard Worker
11892*da0073e9SAndroid Build Coastguard Workerwith warnings.catch_warnings(record=True) as w:
11893*da0073e9SAndroid Build Coastguard Worker    import torch
11894*da0073e9SAndroid Build Coastguard Worker
11895*da0073e9SAndroid Build Coastguard Workerif len(w) > 0:
11896*da0073e9SAndroid Build Coastguard Worker    print(w)
11897*da0073e9SAndroid Build Coastguard Worker    exit(1)
11898*da0073e9SAndroid Build Coastguard Worker
11899*da0073e9SAndroid Build Coastguard Worker# This should run just fine and raise warning about perf
11900*da0073e9SAndroid Build Coastguard Workerwith warnings.catch_warnings(record=True) as w:
11901*da0073e9SAndroid Build Coastguard Worker    {op}
11902*da0073e9SAndroid Build Coastguard Worker
11903*da0073e9SAndroid Build Coastguard Workerif len(w) != 1:
11904*da0073e9SAndroid Build Coastguard Worker    print(w)
11905*da0073e9SAndroid Build Coastguard Worker    exit(2)
11906*da0073e9SAndroid Build Coastguard Worker"""
11907*da0073e9SAndroid Build Coastguard Worker        try:
11908*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(
11909*da0073e9SAndroid Build Coastguard Worker                [sys.executable, '-W', 'always', '-c', script],
11910*da0073e9SAndroid Build Coastguard Worker                stderr=subprocess.STDOUT,
11911*da0073e9SAndroid Build Coastguard Worker                # On Windows, opening the subprocess with the default CWD makes `import torch`
11912*da0073e9SAndroid Build Coastguard Worker                # fail, so just set CWD to this script's directory
11913*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),)
11914*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
11915*da0073e9SAndroid Build Coastguard Worker            if e.returncode == 1:
11916*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
11917*da0073e9SAndroid Build Coastguard Worker                                       e.output.decode("utf-8"))
11918*da0073e9SAndroid Build Coastguard Worker            elif e.returncode == 2:
11919*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(False, "There wasn't exactly one warning when running not implemented op with "
11920*da0073e9SAndroid Build Coastguard Worker                                f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}")
11921*da0073e9SAndroid Build Coastguard Worker            else:
11922*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
11923*da0073e9SAndroid Build Coastguard Worker                                       e.output.decode("utf-8"))
11924*da0073e9SAndroid Build Coastguard Worker
11925*da0073e9SAndroid Build Coastguard Workerclass TestNoRegression(TestCase):
11926*da0073e9SAndroid Build Coastguard Worker    def test_assert_close(self):
11927*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, device="mps")
11928*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1, device="mps")
11929*da0073e9SAndroid Build Coastguard Worker        inf = a / b
11930*da0073e9SAndroid Build Coastguard Worker        nan = b / b
11931*da0073e9SAndroid Build Coastguard Worker
11932*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
11933*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(a, inf)
11934*da0073e9SAndroid Build Coastguard Worker
11935*da0073e9SAndroid Build Coastguard Worker        # TODO: The NaN test is failing when all the tests in test_mps are run
11936*da0073e9SAndroid Build Coastguard Worker        # together but passes when run separately. There seems to be memory
11937*da0073e9SAndroid Build Coastguard Worker        # corruption which needs to be fixed for this test to be enabled.
11938*da0073e9SAndroid Build Coastguard Worker        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
11939*da0073e9SAndroid Build Coastguard Worker            # torch.testing.assert_close(a, nan)
11940*da0073e9SAndroid Build Coastguard Worker
11941*da0073e9SAndroid Build Coastguard Worker    def test_double_error(self):
11942*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11943*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(2, dtype=torch.float64, device="mps")
11944*da0073e9SAndroid Build Coastguard Worker
11945*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, device="mps")
11946*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11947*da0073e9SAndroid Build Coastguard Worker            a = a.double()
11948*da0073e9SAndroid Build Coastguard Worker
11949*da0073e9SAndroid Build Coastguard Worker    def test_legacy_constructor(self):
11950*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, device="mps")
11951*da0073e9SAndroid Build Coastguard Worker
11952*da0073e9SAndroid Build Coastguard Worker        b = a.new(1)
11953*da0073e9SAndroid Build Coastguard Worker
11954*da0073e9SAndroid Build Coastguard Worker    def test_serialization_map_location(self):
11955*da0073e9SAndroid Build Coastguard Worker
11956*da0073e9SAndroid Build Coastguard Worker        # Ensures that cpu Tensor can be loaded on mps
11957*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
11958*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2)
11959*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
11960*da0073e9SAndroid Build Coastguard Worker
11961*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
11962*da0073e9SAndroid Build Coastguard Worker            x2 = torch.load(f, map_location="mps")
11963*da0073e9SAndroid Build Coastguard Worker
11964*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x2)
11965*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x2.device.type, "mps")
11966*da0073e9SAndroid Build Coastguard Worker
11967*da0073e9SAndroid Build Coastguard Worker        # Ensures that mps Tensors can be loaded on mps
11968*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
11969*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2, device="mps")
11970*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
11971*da0073e9SAndroid Build Coastguard Worker
11972*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
11973*da0073e9SAndroid Build Coastguard Worker            x2 = torch.load(f)
11974*da0073e9SAndroid Build Coastguard Worker
11975*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x2)
11976*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x2.device.type, "mps")
11977*da0073e9SAndroid Build Coastguard Worker
11978*da0073e9SAndroid Build Coastguard Worker        # Ensures that mps Tensors can be loaded on cpu
11979*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
11980*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2, device="mps")
11981*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
11982*da0073e9SAndroid Build Coastguard Worker
11983*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
11984*da0073e9SAndroid Build Coastguard Worker            x2 = torch.load(f, map_location="cpu")
11985*da0073e9SAndroid Build Coastguard Worker
11986*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x2)
11987*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x2.device.type, "cpu")
11988*da0073e9SAndroid Build Coastguard Worker
11989*da0073e9SAndroid Build Coastguard Worker        # Ensures that `mps:0` Tensors can be loaded on mps
11990*da0073e9SAndroid Build Coastguard Worker        with tempfile.NamedTemporaryFile() as f:
11991*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2, device="mps:0")
11992*da0073e9SAndroid Build Coastguard Worker            torch.save(x, f)
11993*da0073e9SAndroid Build Coastguard Worker
11994*da0073e9SAndroid Build Coastguard Worker            f.seek(0)
11995*da0073e9SAndroid Build Coastguard Worker            x2 = torch.load(f, map_location="mps:0")
11996*da0073e9SAndroid Build Coastguard Worker
11997*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, x2)
11998*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x2.device.type, "mps")
11999*da0073e9SAndroid Build Coastguard Worker
12000*da0073e9SAndroid Build Coastguard Worker
12001*da0073e9SAndroid Build Coastguard WorkerMPS_DTYPES = get_all_dtypes()
12002*da0073e9SAndroid Build Coastguard Workerfor t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]:
12003*da0073e9SAndroid Build Coastguard Worker    del MPS_DTYPES[MPS_DTYPES.index(t)]
12004*da0073e9SAndroid Build Coastguard Worker
12005*da0073e9SAndroid Build Coastguard WorkerMPS_GRAD_DTYPES = [torch.float32, torch.float16]
12006*da0073e9SAndroid Build Coastguard Worker
12007*da0073e9SAndroid Build Coastguard Worker
12008*da0073e9SAndroid Build Coastguard Workerclass TestConsistency(TestCaseMPS):
12009*da0073e9SAndroid Build Coastguard Worker    # TODO: This is only used while some ops are being added.
12010*da0073e9SAndroid Build Coastguard Worker    # This list should contain all ops and dtypes eventually
12011*da0073e9SAndroid Build Coastguard Worker    # This can be generated automatically in the `new_mps_allowlist.txt` file
12012*da0073e9SAndroid Build Coastguard Worker    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
12013*da0073e9SAndroid Build Coastguard Worker    # You most likely do NOT want to modify this manually
12014*da0073e9SAndroid Build Coastguard Worker
12015*da0073e9SAndroid Build Coastguard Worker    FP16_LOW_PRECISION_LIST = {
12016*da0073e9SAndroid Build Coastguard Worker        'add', 'sub', 'div', 'addcdiv',
12017*da0073e9SAndroid Build Coastguard Worker        '__rdiv__', '__rmul__',
12018*da0073e9SAndroid Build Coastguard Worker        'nn.functional.huber_loss',
12019*da0073e9SAndroid Build Coastguard Worker        'true_divide', 'kron',
12020*da0073e9SAndroid Build Coastguard Worker        'gradient', 'var', 'std', 'std_mean', 'ldexp',
12021*da0073e9SAndroid Build Coastguard Worker        'linalg.vector_norm', 'lerp',
12022*da0073e9SAndroid Build Coastguard Worker        'addr', 'var_mean',
12023*da0073e9SAndroid Build Coastguard Worker        'var_mean_unbiased',
12024*da0073e9SAndroid Build Coastguard Worker        'acosh', 'asinh', 'asin',
12025*da0073e9SAndroid Build Coastguard Worker        'masked.std',
12026*da0073e9SAndroid Build Coastguard Worker        'nn.functional.normalize',
12027*da0073e9SAndroid Build Coastguard Worker        'nn.functional.triplet_margin_loss',
12028*da0073e9SAndroid Build Coastguard Worker        'nn.functional.triplet_margin_with_distance_loss',
12029*da0073e9SAndroid Build Coastguard Worker        'nn.functional.batch_norm',
12030*da0073e9SAndroid Build Coastguard Worker        'nn.functional.instance_norm',
12031*da0073e9SAndroid Build Coastguard Worker        'round', 'xlogy', 'addcmul',
12032*da0073e9SAndroid Build Coastguard Worker        'nn.functional.cross_entropy',
12033*da0073e9SAndroid Build Coastguard Worker        'nn.functional.binary_cross_entropy',
12034*da0073e9SAndroid Build Coastguard Worker        'nn.functional.nll_loss',
12035*da0073e9SAndroid Build Coastguard Worker        'nn.functional.max_pool2d',
12036*da0073e9SAndroid Build Coastguard Worker        'nn.functional.gelu',
12037*da0073e9SAndroid Build Coastguard Worker        'nn.functional.glu',
12038*da0073e9SAndroid Build Coastguard Worker        '_native_batch_norm_legit',
12039*da0073e9SAndroid Build Coastguard Worker        '_batch_norm_with_update',
12040*da0073e9SAndroid Build Coastguard Worker        'native_batch_norm',
12041*da0073e9SAndroid Build Coastguard Worker        'softmax',
12042*da0073e9SAndroid Build Coastguard Worker        '_softmax_backward_data',
12043*da0073e9SAndroid Build Coastguard Worker        'log_softmax',
12044*da0073e9SAndroid Build Coastguard Worker        'masked.softmax',
12045*da0073e9SAndroid Build Coastguard Worker        'masked.log_softmax',
12046*da0073e9SAndroid Build Coastguard Worker        'masked.softmin',
12047*da0073e9SAndroid Build Coastguard Worker        'nn.functional.kl_div',
12048*da0073e9SAndroid Build Coastguard Worker        'nn.functional.softmin',
12049*da0073e9SAndroid Build Coastguard Worker        'cross', 'linalg.cross',
12050*da0073e9SAndroid Build Coastguard Worker        'prod', 'masked.prod',
12051*da0073e9SAndroid Build Coastguard Worker        'nextafter',
12052*da0073e9SAndroid Build Coastguard Worker        'native_layer_norm',
12053*da0073e9SAndroid Build Coastguard Worker        'nn.functional.layer_norm',
12054*da0073e9SAndroid Build Coastguard Worker        'nn.functional.interpolate',
12055*da0073e9SAndroid Build Coastguard Worker        'nn.functional.upsample_bilinear',
12056*da0073e9SAndroid Build Coastguard Worker        'nn.functional.upsample_nearest',
12057*da0073e9SAndroid Build Coastguard Worker
12058*da0073e9SAndroid Build Coastguard Worker        # for macOS 12
12059*da0073e9SAndroid Build Coastguard Worker        'masked.normalize', 'masked.sum', 'masked.var',
12060*da0073e9SAndroid Build Coastguard Worker        'outer',
12061*da0073e9SAndroid Build Coastguard Worker        'sum_to_size', 'sum',
12062*da0073e9SAndroid Build Coastguard Worker        'mul',
12063*da0073e9SAndroid Build Coastguard Worker        'nansum', 'nanmean',
12064*da0073e9SAndroid Build Coastguard Worker        'norm',
12065*da0073e9SAndroid Build Coastguard Worker    }
12066*da0073e9SAndroid Build Coastguard Worker
12067*da0073e9SAndroid Build Coastguard Worker    FP32_LOW_PRECISION_LIST = {
12068*da0073e9SAndroid Build Coastguard Worker        # conv2d and conv_transpose2d results have a very small
12069*da0073e9SAndroid Build Coastguard Worker        # difference compared to CPU/CUDA, so we use lower precision on FP32
12070*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv2d',
12071*da0073e9SAndroid Build Coastguard Worker        'nn.functional.conv_transpose2d',
12072*da0073e9SAndroid Build Coastguard Worker        'matmul', '__rmatmul__',
12073*da0073e9SAndroid Build Coastguard Worker        'linalg.multi_dot',
12074*da0073e9SAndroid Build Coastguard Worker        'addbmm',
12075*da0073e9SAndroid Build Coastguard Worker    }
12076*da0073e9SAndroid Build Coastguard Worker
12077*da0073e9SAndroid Build Coastguard Worker    def _compute_tolerances(self, op, dtype):
12078*da0073e9SAndroid Build Coastguard Worker        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
12079*da0073e9SAndroid Build Coastguard Worker            return (1e-4, 3e-5)
12080*da0073e9SAndroid Build Coastguard Worker
12081*da0073e9SAndroid Build Coastguard Worker        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
12082*da0073e9SAndroid Build Coastguard Worker            return (1e-2, 1e-2)
12083*da0073e9SAndroid Build Coastguard Worker
12084*da0073e9SAndroid Build Coastguard Worker        if op.name in ['nn.functional.conv_transpose1d',
12085*da0073e9SAndroid Build Coastguard Worker                       'nn.functional.conv_transpose2d',
12086*da0073e9SAndroid Build Coastguard Worker                       'nn.functional.conv_transpose3d',
12087*da0073e9SAndroid Build Coastguard Worker                       '__rmatmul__', 'addbmm', 'addmv',
12088*da0073e9SAndroid Build Coastguard Worker                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
12089*da0073e9SAndroid Build Coastguard Worker            return (5e-2, 5e-2)
12090*da0073e9SAndroid Build Coastguard Worker        if op.name == "masked.mean":
12091*da0073e9SAndroid Build Coastguard Worker            return (7e-4, 2e-3)
12092*da0073e9SAndroid Build Coastguard Worker        if op.name == "native_layer_norm":
12093*da0073e9SAndroid Build Coastguard Worker            return (1e-4, 1.3e-5)
12094*da0073e9SAndroid Build Coastguard Worker        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
12095*da0073e9SAndroid Build Coastguard Worker            # The result of pow(9 , 8) is showing 43046716, whereas it should've been 43046721.
12096*da0073e9SAndroid Build Coastguard Worker            # fixed in macOS 13.3+
12097*da0073e9SAndroid Build Coastguard Worker            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
12098*da0073e9SAndroid Build Coastguard Worker        if op.name == "nn.functional.interpolate":
12099*da0073e9SAndroid Build Coastguard Worker            return (1e-3, 1e-4)
12100*da0073e9SAndroid Build Coastguard Worker        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
12101*da0073e9SAndroid Build Coastguard Worker            # TODO: Investigate why this is needed
12102*da0073e9SAndroid Build Coastguard Worker            # See https://github.com/pytorch/pytorch/issues/120237
12103*da0073e9SAndroid Build Coastguard Worker            return (3e-5, 3e-5)
12104*da0073e9SAndroid Build Coastguard Worker        return (None, None)
12105*da0073e9SAndroid Build Coastguard Worker
12106*da0073e9SAndroid Build Coastguard Worker    # Used for accept mode only
12107*da0073e9SAndroid Build Coastguard Worker    NEW_ALLOW_LIST = defaultdict(list)
12108*da0073e9SAndroid Build Coastguard Worker    NEW_ALLOW_LIST_GRAD = defaultdict(list)
12109*da0073e9SAndroid Build Coastguard Worker
12110*da0073e9SAndroid Build Coastguard Worker    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
12111*da0073e9SAndroid Build Coastguard Worker    def test_output_match(self, device, dtype, op):
12112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device, "cpu")
12113*da0073e9SAndroid Build Coastguard Worker
12114*da0073e9SAndroid Build Coastguard Worker        def get_samples():
12115*da0073e9SAndroid Build Coastguard Worker            return op.sample_inputs(
12116*da0073e9SAndroid Build Coastguard Worker                device,
12117*da0073e9SAndroid Build Coastguard Worker                dtype,
12118*da0073e9SAndroid Build Coastguard Worker                requires_grad=(dtype.is_floating_point or dtype.is_complex),
12119*da0073e9SAndroid Build Coastguard Worker                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
12120*da0073e9SAndroid Build Coastguard Worker                set_seed=False,
12121*da0073e9SAndroid Build Coastguard Worker            )
12122*da0073e9SAndroid Build Coastguard Worker        cpu_samples = get_samples()
12123*da0073e9SAndroid Build Coastguard Worker
12124*da0073e9SAndroid Build Coastguard Worker        for cpu_sample in cpu_samples:
12125*da0073e9SAndroid Build Coastguard Worker            #
12126*da0073e9SAndroid Build Coastguard Worker            # Forward check
12127*da0073e9SAndroid Build Coastguard Worker            #
12128*da0073e9SAndroid Build Coastguard Worker            mps_sample = cpu_sample.transform(
12129*da0073e9SAndroid Build Coastguard Worker                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
12130*da0073e9SAndroid Build Coastguard Worker
12131*da0073e9SAndroid Build Coastguard Worker            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
12132*da0073e9SAndroid Build Coastguard Worker            cpu_kwargs = cpu_sample.kwargs
12133*da0073e9SAndroid Build Coastguard Worker            mps_args = [mps_sample.input] + list(mps_sample.args)
12134*da0073e9SAndroid Build Coastguard Worker            mps_kwargs = mps_sample.kwargs
12135*da0073e9SAndroid Build Coastguard Worker
12136*da0073e9SAndroid Build Coastguard Worker            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
12137*da0073e9SAndroid Build Coastguard Worker            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
12138*da0073e9SAndroid Build Coastguard Worker                mps_args[1] = cpu_args[1]
12139*da0073e9SAndroid Build Coastguard Worker
12140*da0073e9SAndroid Build Coastguard Worker            cpu_out = op(*cpu_args, **cpu_kwargs)
12141*da0073e9SAndroid Build Coastguard Worker            mps_out = op(*mps_args, **mps_kwargs)
12142*da0073e9SAndroid Build Coastguard Worker
12143*da0073e9SAndroid Build Coastguard Worker            atol, rtol = self._compute_tolerances(op, dtype)
12144*da0073e9SAndroid Build Coastguard Worker            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
12145*da0073e9SAndroid Build Coastguard Worker                atol = 1.0
12146*da0073e9SAndroid Build Coastguard Worker                rtol = 0.0
12147*da0073e9SAndroid Build Coastguard Worker
12148*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
12149*da0073e9SAndroid Build Coastguard Worker
12150*da0073e9SAndroid Build Coastguard Worker
12151*da0073e9SAndroid Build Coastguard Worker    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
12152*da0073e9SAndroid Build Coastguard Worker    def test_output_grad_match(self, device, dtype, op):
12153*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device, "cpu")
12154*da0073e9SAndroid Build Coastguard Worker
12155*da0073e9SAndroid Build Coastguard Worker        def get_samples():
12156*da0073e9SAndroid Build Coastguard Worker            return op.sample_inputs(
12157*da0073e9SAndroid Build Coastguard Worker                device,
12158*da0073e9SAndroid Build Coastguard Worker                dtype,
12159*da0073e9SAndroid Build Coastguard Worker                requires_grad=(dtype.is_floating_point or dtype.is_complex),
12160*da0073e9SAndroid Build Coastguard Worker                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
12161*da0073e9SAndroid Build Coastguard Worker                set_seed=False,
12162*da0073e9SAndroid Build Coastguard Worker            )
12163*da0073e9SAndroid Build Coastguard Worker        cpu_samples = get_samples()
12164*da0073e9SAndroid Build Coastguard Worker
12165*da0073e9SAndroid Build Coastguard Worker        for cpu_sample in cpu_samples:
12166*da0073e9SAndroid Build Coastguard Worker            #
12167*da0073e9SAndroid Build Coastguard Worker            # Forward check
12168*da0073e9SAndroid Build Coastguard Worker            #
12169*da0073e9SAndroid Build Coastguard Worker            forward_failed = False
12170*da0073e9SAndroid Build Coastguard Worker            mps_sample = cpu_sample.transform(
12171*da0073e9SAndroid Build Coastguard Worker                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
12172*da0073e9SAndroid Build Coastguard Worker
12173*da0073e9SAndroid Build Coastguard Worker            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
12174*da0073e9SAndroid Build Coastguard Worker            cpu_kwargs = cpu_sample.kwargs
12175*da0073e9SAndroid Build Coastguard Worker            mps_args = [mps_sample.input] + list(mps_sample.args)
12176*da0073e9SAndroid Build Coastguard Worker            mps_kwargs = mps_sample.kwargs
12177*da0073e9SAndroid Build Coastguard Worker
12178*da0073e9SAndroid Build Coastguard Worker            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
12179*da0073e9SAndroid Build Coastguard Worker            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
12180*da0073e9SAndroid Build Coastguard Worker                mps_args[1] = cpu_args[1]
12181*da0073e9SAndroid Build Coastguard Worker
12182*da0073e9SAndroid Build Coastguard Worker            cpu_out = op(*cpu_args, **cpu_kwargs)
12183*da0073e9SAndroid Build Coastguard Worker            mps_out = op(*mps_args, **mps_kwargs)
12184*da0073e9SAndroid Build Coastguard Worker
12185*da0073e9SAndroid Build Coastguard Worker            if op.name == "unique" and cpu_kwargs["sorted"] is False:
12186*da0073e9SAndroid Build Coastguard Worker                continue
12187*da0073e9SAndroid Build Coastguard Worker
12188*da0073e9SAndroid Build Coastguard Worker            atol, rtol = self._compute_tolerances(op, dtype)
12189*da0073e9SAndroid Build Coastguard Worker            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
12190*da0073e9SAndroid Build Coastguard Worker                atol = 7e-4
12191*da0073e9SAndroid Build Coastguard Worker                rtol = 1.5e-3
12192*da0073e9SAndroid Build Coastguard Worker
12193*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
12194*da0073e9SAndroid Build Coastguard Worker
12195*da0073e9SAndroid Build Coastguard Worker            #
12196*da0073e9SAndroid Build Coastguard Worker            # Backward check
12197*da0073e9SAndroid Build Coastguard Worker            #
12198*da0073e9SAndroid Build Coastguard Worker            if forward_failed:
12199*da0073e9SAndroid Build Coastguard Worker                # We would've failed immediately anyway, but this error is clearer
12200*da0073e9SAndroid Build Coastguard Worker                # We error instead of continuing so that all_backward_pass would not be True
12201*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Forward pass already failed")
12202*da0073e9SAndroid Build Coastguard Worker
12203*da0073e9SAndroid Build Coastguard Worker            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
12204*da0073e9SAndroid Build Coastguard Worker            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)
12205*da0073e9SAndroid Build Coastguard Worker
12206*da0073e9SAndroid Build Coastguard Worker            def req_grad(t):
12207*da0073e9SAndroid Build Coastguard Worker                return isinstance(t, torch.Tensor) and t.requires_grad
12208*da0073e9SAndroid Build Coastguard Worker
12209*da0073e9SAndroid Build Coastguard Worker            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
12210*da0073e9SAndroid Build Coastguard Worker            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
12211*da0073e9SAndroid Build Coastguard Worker            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
12212*da0073e9SAndroid Build Coastguard Worker            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
12213*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
12214*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))
12215*da0073e9SAndroid Build Coastguard Worker
12216*da0073e9SAndroid Build Coastguard Worker            if len(diff_cpu_out) == 0:
12217*da0073e9SAndroid Build Coastguard Worker                continue
12218*da0073e9SAndroid Build Coastguard Worker            # rand_like does not work with certain dtypes, so cast to double and cast back
12219*da0073e9SAndroid Build Coastguard Worker            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
12220*da0073e9SAndroid Build Coastguard Worker            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)
12221*da0073e9SAndroid Build Coastguard Worker
12222*da0073e9SAndroid Build Coastguard Worker            # Compare computed gradients with cpu given random grad_output vector
12223*da0073e9SAndroid Build Coastguard Worker            # Sometimes when the derivative is 0, we just don't bother creating the graph
12224*da0073e9SAndroid Build Coastguard Worker            # allow_unused is needed in those cases.
12225*da0073e9SAndroid Build Coastguard Worker            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
12226*da0073e9SAndroid Build Coastguard Worker            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
12227*da0073e9SAndroid Build Coastguard Worker
12228*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
12229*da0073e9SAndroid Build Coastguard Worker
12230*da0073e9SAndroid Build Coastguard Worker
12231*da0073e9SAndroid Build Coastguard Workerclass TestErrorInputs(TestCase):
12232*da0073e9SAndroid Build Coastguard Worker    _ignore_not_implemented_error = True
12233*da0073e9SAndroid Build Coastguard Worker
12234*da0073e9SAndroid Build Coastguard Worker    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
12235*da0073e9SAndroid Build Coastguard Worker    def test_error_inputs(self, device, op):
12236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(device, "mps:0")
12237*da0073e9SAndroid Build Coastguard Worker
12238*da0073e9SAndroid Build Coastguard Worker        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
12239*da0073e9SAndroid Build Coastguard Worker        mps_samples = op.error_inputs(device, set_seed=False)
12240*da0073e9SAndroid Build Coastguard Worker
12241*da0073e9SAndroid Build Coastguard Worker        for mps_sample in mps_samples:
12242*da0073e9SAndroid Build Coastguard Worker            mps_sample_input = mps_sample.sample_input
12243*da0073e9SAndroid Build Coastguard Worker            error_type = mps_sample.error_type
12244*da0073e9SAndroid Build Coastguard Worker            error_regex = mps_sample.error_regex
12245*da0073e9SAndroid Build Coastguard Worker
12246*da0073e9SAndroid Build Coastguard Worker            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
12247*da0073e9SAndroid Build Coastguard Worker            mps_kwargs = mps_sample_input.kwargs
12248*da0073e9SAndroid Build Coastguard Worker
12249*da0073e9SAndroid Build Coastguard Worker            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
12250*da0073e9SAndroid Build Coastguard Worker            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
12251*da0073e9SAndroid Build Coastguard Worker                mps_args[1] = mps_args[1].cpu()
12252*da0073e9SAndroid Build Coastguard Worker
12253*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(error_type, error_regex):
12254*da0073e9SAndroid Build Coastguard Worker                op(*mps_args, **mps_kwargs)
12255*da0073e9SAndroid Build Coastguard Worker
12256*da0073e9SAndroid Build Coastguard Workerclass TestComplex(TestCase):
12257*da0073e9SAndroid Build Coastguard Worker    def test_tensor_scalar_binops(self):
12258*da0073e9SAndroid Build Coastguard Worker        # Regression test for https://github.com/pytorch/pytorch/issues/119088
12259*da0073e9SAndroid Build Coastguard Worker        def to_cpu(x):
12260*da0073e9SAndroid Build Coastguard Worker            return x.cpu() if isinstance(x, torch.Tensor) else x
12261*da0073e9SAndroid Build Coastguard Worker
12262*da0073e9SAndroid Build Coastguard Worker        # Allocate tensors on mps
12263*da0073e9SAndroid Build Coastguard Worker        with torch.device("mps"):
12264*da0073e9SAndroid Build Coastguard Worker            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
12265*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all(x.device.type == "mps" for x in inputs))
12266*da0073e9SAndroid Build Coastguard Worker        # Add scalars
12267*da0073e9SAndroid Build Coastguard Worker        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])
12268*da0073e9SAndroid Build Coastguard Worker
12269*da0073e9SAndroid Build Coastguard Worker        # Iterate over all permutations of types(int, float, complex, half) and ops (excluding div)
12270*da0073e9SAndroid Build Coastguard Worker        for x, y in itertools.product(inputs, inputs):
12271*da0073e9SAndroid Build Coastguard Worker            for op_name in ["__add__", "__sub__", "__mul__"]:
12272*da0073e9SAndroid Build Coastguard Worker                x_cpu, y_cpu = map(to_cpu, (x, y))
12273*da0073e9SAndroid Build Coastguard Worker                res = getattr(x, op_name)(y)
12274*da0073e9SAndroid Build Coastguard Worker                res_cpu = getattr(x_cpu, op_name)(y_cpu)
12275*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")
12276*da0073e9SAndroid Build Coastguard Worker
12277*da0073e9SAndroid Build Coastguard Worker
12278*da0073e9SAndroid Build Coastguard Worker# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
12279*da0073e9SAndroid Build Coastguard Worker@skipIfSlowGradcheckEnv
12280*da0073e9SAndroid Build Coastguard Workerclass TestCommon(TestCase):
12281*da0073e9SAndroid Build Coastguard Worker    exact_dtype = True
12282*da0073e9SAndroid Build Coastguard Worker
12283*da0073e9SAndroid Build Coastguard Worker    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
12284*da0073e9SAndroid Build Coastguard Worker    @classmethod
12285*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
12286*da0073e9SAndroid Build Coastguard Worker        super().tearDownClass()
12287*da0073e9SAndroid Build Coastguard Worker
12288*da0073e9SAndroid Build Coastguard Worker        if IS_CI:
12289*da0073e9SAndroid Build Coastguard Worker            err_msg = (
12290*da0073e9SAndroid Build Coastguard Worker                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
12291*da0073e9SAndroid Build Coastguard Worker                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
12292*da0073e9SAndroid Build Coastguard Worker            )
12293*da0073e9SAndroid Build Coastguard Worker            # Assure no opinfo entry has dynamic_dtypes
12294*da0073e9SAndroid Build Coastguard Worker            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
12295*da0073e9SAndroid Build Coastguard Worker            for op in filtered_ops:
12296*da0073e9SAndroid Build Coastguard Worker                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
12297*da0073e9SAndroid Build Coastguard Worker                err_msg += "\n" + fmt_str
12298*da0073e9SAndroid Build Coastguard Worker
12299*da0073e9SAndroid Build Coastguard Worker            assert len(filtered_ops) == 0, err_msg
12300*da0073e9SAndroid Build Coastguard Worker
12301*da0073e9SAndroid Build Coastguard Worker    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
12302*da0073e9SAndroid Build Coastguard Worker    # MPS still requires some fairly heavy special casing in the test framework.
12303*da0073e9SAndroid Build Coastguard Worker    # When MPS becomes more consistent, this can probably be merged with that test using
12304*da0073e9SAndroid Build Coastguard Worker    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
12305*da0073e9SAndroid Build Coastguard Worker    @suppress_warnings
12306*da0073e9SAndroid Build Coastguard Worker    # MPS only supports float32
12307*da0073e9SAndroid Build Coastguard Worker    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
12308*da0073e9SAndroid Build Coastguard Worker    def test_numpy_ref_mps(self, device, dtype, op):
12309*da0073e9SAndroid Build Coastguard Worker        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
12310*da0073e9SAndroid Build Coastguard Worker        # does not support float64 Tensors.
12311*da0073e9SAndroid Build Coastguard Worker        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
12312*da0073e9SAndroid Build Coastguard Worker        # get patched up and this workaround removed.
12313*da0073e9SAndroid Build Coastguard Worker        broken_on_ref_inputs = op.name in ('where',)
12314*da0073e9SAndroid Build Coastguard Worker
12315*da0073e9SAndroid Build Coastguard Worker        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
12316*da0073e9SAndroid Build Coastguard Worker        inputs = (
12317*da0073e9SAndroid Build Coastguard Worker            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
12318*da0073e9SAndroid Build Coastguard Worker            else op.sample_inputs(device, dtype, set_seed=False)
12319*da0073e9SAndroid Build Coastguard Worker        )
12320*da0073e9SAndroid Build Coastguard Worker        for sample_input in inputs:
12321*da0073e9SAndroid Build Coastguard Worker            self.compare_with_reference(op, op.ref, sample_input)
12322*da0073e9SAndroid Build Coastguard Worker
12323*da0073e9SAndroid Build Coastguard Worker    @dtypes(*get_all_dtypes())
12324*da0073e9SAndroid Build Coastguard Worker    def test_tensor_creation(self, device, dtype):
12325*da0073e9SAndroid Build Coastguard Worker        def ones(device):
12326*da0073e9SAndroid Build Coastguard Worker            return torch.ones((2, 2), dtype=dtype, device=device)
12327*da0073e9SAndroid Build Coastguard Worker        if dtype not in MPS_DTYPES + ([torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]):
12328*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(TypeError):
12329*da0073e9SAndroid Build Coastguard Worker                ones(device)
12330*da0073e9SAndroid Build Coastguard Worker        else:
12331*da0073e9SAndroid Build Coastguard Worker            mps_tensor = ones(device)
12332*da0073e9SAndroid Build Coastguard Worker            cpu_tensor = ones("cpu")
12333*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(mps_tensor.cpu(), cpu_tensor)
12334*da0073e9SAndroid Build Coastguard Worker
12335*da0073e9SAndroid Build Coastguard Worker
12336*da0073e9SAndroid Build Coastguard Worker# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
12337*da0073e9SAndroid Build Coastguard Worker# This requires mps to be properly registered in the device generic test framework which is not the
12338*da0073e9SAndroid Build Coastguard Worker# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
12339*da0073e9SAndroid Build Coastguard Worker# to achieve this.
12340*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
12341*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
12342*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
12343*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
12344*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestMPS)
12345*da0073e9SAndroid Build Coastguard Worker
12346*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
12347*da0073e9SAndroid Build Coastguard Worker    run_tests()
12348