//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorIterator.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
  __block std::optional<std::exception_ptr> block_exception;
  dispatch_sync(queue, ^() {
    try {
      block();
    } catch (...) {
      block_exception = std::current_exception();
    }
  });
  if (block_exception) {
    std::rethrow_exception(*block_exception);
  }
}
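// Example usage of dispatch_sync_with_rethrow (a sketch; `queue` stands for
// any serial dispatch queue): a C++ exception thrown inside the block is
// captured and rethrown on the caller's thread instead of crashing the queue.
//   dispatch_sync_with_rethrow(queue, ^() {
//     TORCH_CHECK(false, "propagated back to the caller");
//   });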

/**
 * Computes the distance from the lowest to the highest element offset in the given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t) {
  size_t rc = 1;
  if (t.numel() == 0) {
    return 0;
  }
  for (const auto i : c10::irange(t.dim())) {
    assert(t.size(i) > 0);
    rc += (t.size(i) - 1) * t.stride(i);
  }
  return rc;
}
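// Example for compute_storage_numel_distance: with sizes [2, 3] and
// strides [3, 1], the span is 1 + (2 - 1) * 3 + (3 - 1) * 1 = 6 elements.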

void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
  mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}

static inline void checkSupportsComplex() {
  TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on macOS 14.0 or newer.");
}

static inline void checkSupportsBFloat16() {
  TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
                   "MPS bfloat16 type is only supported on macOS 14.0 or newer.");
}

MPSDataType getMPSDataType(ScalarType scalar_type) {
  switch (scalar_type) {
    case ScalarType::Float:
      return MPSDataTypeFloat32;
    case ScalarType::Half:
      return MPSDataTypeFloat16;
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return MPSDataTypeBFloat16;
    case ScalarType::Int:
      return MPSDataTypeInt32;
    case ScalarType::Long:
      return MPSDataTypeInt64;
    case ScalarType::Short:
      return MPSDataTypeInt16;
    case ScalarType::Char:
      return MPSDataTypeInt8;
    case ScalarType::Byte:
      return MPSDataTypeUInt8;
    case ScalarType::Bool:
      return MPSDataTypeBool;
    case ScalarType::Double:
      TORCH_CHECK_TYPE(false,
                       "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
                       "Please use float32 instead.")
    case ScalarType::ComplexHalf:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat16;
    case ScalarType::ComplexFloat:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat32;
    default:
      TORCH_CHECK_TYPE(
          false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
  }
}

// #issue 104398441: sortWithTensor and argsortWithTensor only support
// Int32, Half and Float32 types. These utilities help cast tensors to
// the supported types.
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
                               MPSGraphTensor* inputTensor,
                               const Tensor& input,
                               bool includesInt64) {
  MPSDataType dataType = getMPSDataType(input.scalar_type());
  bool condition =
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
  if (includesInt64) {
    condition = condition && (dataType != MPSDataTypeInt64);
  }
  if (condition) {
    dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
    return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
  }
  return inputTensor;
}

// #issue 104398441: sortWithTensor and argsortWithTensor only support
// Int32, Half and Float32 types. These utilities help cast tensors back
// from the supported types.
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
                                 MPSGraphTensor* inputTensor,
                                 const Tensor& input,
                                 bool includesInt64) {
  MPSDataType dataType = getMPSDataType(input.scalar_type());
  bool condition =
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
  if (includesInt64) {
    condition = condition && (dataType != MPSDataTypeInt64);
  }
  if (condition) {
    inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
  }
  return inputTensor;
}

MPSDataType getMPSScalarType(ScalarType scalar_type) {
  switch (scalar_type) {
    // This is an intentional fallthrough supporting Double for Scalar
    // types, as they are currently cast to Float32.
    case ScalarType::Double:
    case ScalarType::Float:
      return MPSDataTypeFloat32;
    case ScalarType::Half:
      return MPSDataTypeFloat16;
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return MPSDataTypeBFloat16;
    case ScalarType::Int:
      return MPSDataTypeInt32;
    case ScalarType::Long:
      return MPSDataTypeInt64;
    case ScalarType::Short:
      return MPSDataTypeInt16;
    case ScalarType::Char:
      return MPSDataTypeInt8;
    case ScalarType::Byte:
      return MPSDataTypeUInt8;
    case ScalarType::Bool:
      return MPSDataTypeBool;
    case ScalarType::ComplexHalf:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat16;
    // This is an intentional fallthrough supporting ComplexDouble for Scalar
    // types, as they are currently cast to Complex64.
    case ScalarType::ComplexDouble:
    case ScalarType::ComplexFloat:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat32;
    default:
      TORCH_CHECK_TYPE(
          false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
  }
}

// use short_name to avoid getting extra-long cached graph keys with ops such as cat_out(), etc.
std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
  switch (scalar_type) {
    case ScalarType::Double:
    case ScalarType::Float:
      return short_name ? "f32" : "Float32";
    case ScalarType::Half:
      return short_name ? "f16" : "Float16";
    case ScalarType::BFloat16:
      return short_name ? "bf16" : "BFloat16";
    case ScalarType::Int:
      return short_name ? "i32" : "Int32";
    case ScalarType::Long:
      return short_name ? "i64" : "Int64";
    case ScalarType::Short:
      return short_name ? "i16" : "Int16";
    case ScalarType::Char:
      return short_name ? "i8" : "Int8";
    case ScalarType::Byte:
      return short_name ? "u8" : "UInt8";
    case ScalarType::Bool:
      return short_name ? "b8" : "Bool";
    case ScalarType::ComplexHalf:
      return short_name ? "c16" : "ComplexFloat16";
    case ScalarType::ComplexFloat:
      return short_name ? "c32" : "ComplexFloat32";
    default:
      return "Undefined";
  }
}
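// For example, getMPSTypeString(ScalarType::Float, /*short_name=*/true) yields
// "f32", while the long form yields "Float32".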

std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
  switch (scalar_type) {
    case ScalarType::Float:
      return "float";
    case ScalarType::Half:
      return "half";
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return "bfloat";
    case ScalarType::Int:
      return "int";
    case ScalarType::Long:
      return "long";
    case ScalarType::Short:
      return "short";
    case ScalarType::Char:
      return "char";
    case ScalarType::Byte:
      return "uchar";
    case ScalarType::Bool:
      return "bool";
    default:
      TORCH_CHECK(false, "Undefined type ", scalar_type);
      return "Undefined";
  }
}

static NSArray<NSNumber*>* getTensorAxes(int64_t ndim) {
  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
  for (const auto i : c10::irange(ndim)) {
    axes[i] = [NSNumber numberWithInteger:i];
  }
  return axes;
}
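// For ndim == 3 this produces @[ @0, @1, @2 ], i.e. an array of all axis indices.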

NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
  return getTensorAxes(t.dim());
}

static NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes) {
  return getTensorAxes(sizes.size());
}

NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim) {
  if (dim.has_value() && !dim.value().empty()) {
    IntArrayRef dimValues = dim.value();
    int ndim = dimValues.size();
    auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
    for (const auto i : c10::irange(ndim)) {
      axes[i] = [NSNumber numberWithInteger:dimValues[i]];
    }

    return axes;
  }

  return getTensorAxes(sizes);
}

std::string getMPSShapeString(MPSShape* shape) {
  std::string str;
  for (NSNumber* elem in shape) {
    str += std::to_string(elem.unsignedLongValue) + ",";
  }
  return str;
}

std::string getArrayRefString(const IntArrayRef s) {
  std::stringstream ss;
  std::copy(s.begin(), s.end(), std::ostream_iterator<int>(ss, ","));
  return ss.str();
}
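// For example, getArrayRefString({1, 2, 3}) renders as "1,2,3,"; the
// ostream_iterator emits its "," delimiter after every element.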

std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
  std::string str;
  // The key format per tensor looks like ":Float32[1,1,1,10]"
  for (const Tensor& tensor : tensors) {
    str += ":";
    if (tensor.defined()) {
      str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
      // if tensor is a scalar
      if (tensor.dim() == 0) {
        str += "Scalar";
      } else {
        if (exclude_shape) {
          str += "[-1]";
        } else {
          str +=
              std::string([[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","].UTF8String);
        }
      }
      str += "]";
    } else {
      str += "Undefined";
    }
  }
  return str;
}

Tensor getTensorView(const Tensor& t, MPSShape* shape) {
  std::vector<int64_t> res;
  res.reserve([shape count]);
  for (NSNumber* elem in shape) {
    res.push_back(elem.longLongValue);
  }
  return t.view(res);
}

MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) {
  return getMPSShape(t.sizes(), memory_format);
}

MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
  if (memory_format == MemoryFormat::ChannelsLast) {
    TORCH_INTERNAL_ASSERT(sizes.size() == 4, "ChannelsLast memory format must have 4 dimensions!");
    const NSUInteger N = sizes[0];
    const NSUInteger C = sizes[1];
    const NSUInteger H = sizes[2];
    const NSUInteger W = sizes[3];
    return @[ @(N), @(H), @(W), @(C) ];
  }
  const int sz = sizes.size();
  const int sz_ = (sz > 0) ? sz : 1;

  std::vector<NSNumber*> numbers(sz_);

  for (int i = 0; i < sz_; i++) {
    NSInteger sz_i = (i < sz) ? sizes[i] : 1;
    NSNumber* number = [NSNumber numberWithInteger:sz_i];
    numbers[i] = number;
  }
  return [NSArray arrayWithObjects:numbers.data() count:numbers.size()];
}
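// Examples: ChannelsLast sizes [N, C, H, W] map to @[ @(N), @(H), @(W), @(C) ],
// and empty (0-dim) sizes are padded to the rank-1 shape @[ @1 ].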

void printTensorNDArray(const Tensor& t) {
  if (!t.is_mps())
    return;
  if (t.numel() == 0)
    return;
  // Get shape and data type
  auto selfShape = getMPSShape(t);
  auto selfDType = getMPSDataType(t.scalar_type());

  // Initialize data
  id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
  MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape
                                                                    dataType:selfDType] autorelease];
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access")
#endif
  [tdata printNDArray];
  C10_CLANG_DIAGNOSTIC_POP()
}

MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) {
  id<MTLBuffer> buffer = getMTLBufferStorage(tensor);
  MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
                                                                                    shape:shape
                                                                                 dataType:mpsType] autorelease];

  return [tmpGraphTensorData mpsndarray];
}

static std::vector<int64_t> getSortedStrides(const IntArrayRef& s) {
  std::vector<int64_t> idx(s.size());
  iota(idx.begin(), idx.end(), 0);
  sort(idx.begin(), idx.end(), [&s](size_t i1, size_t i2) { return s[i1] > s[i2]; });

  return idx;
}
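// Example: for strides {1, 12, 4}, getSortedStrides returns {1, 2, 0}, the
// dimension indices ordered from largest to smallest stride.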

static std::vector<int64_t> inversePermutation(const std::vector<int64_t>& permuteOrder) {
  auto size = permuteOrder.size();
  std::vector<int64_t> inversePerm(permuteOrder.size());

  for (size_t i = 0; i < size; i++) {
    inversePerm[permuteOrder[i]] = i;
  }
  return inversePerm;
}
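// Example: inversePermutation({2, 0, 1}) == {1, 2, 0}; applying a permutation
// followed by its inverse restores the original ordering.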

static MPSNDArray* permuteNDArray(MPSNDArray* inArray, const std::vector<int64_t>& permuteOrder_) {
  auto permuteOrder = inversePermutation(permuteOrder_);
  NSUInteger srcRank = [inArray numberOfDimensions];
  if (srcRank != permuteOrder.size()) {
    TORCH_INTERNAL_ASSERT(false);
    return nil;
  }
  std::vector<NSUInteger> dimensionOrder(srcRank);
  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
  MPSNDArrayDescriptor* desc = [inArray descriptor];

  for (int64_t i = srcRank - 1; i >= 0; i--) {
    NSUInteger axis = permuteOrder[i];
    auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis);
    NSUInteger axis1 = srcRank - i - 1;
    NSUInteger axis2 = dimensionOrder.end() - axisIter - 1;
    iter_swap(dimensionOrder.begin() + i, axisIter);
    if (axis1 != axis2) {
      [desc transposeDimension:axis1 withDimension:axis2];
    }
  }
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wnonnull")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wnonnull")
#endif
  MPSNDArray* result = [inArray arrayViewWithCommandBuffer:nil descriptor:desc aliasing:MPSAliasingStrategyShallAlias];
  C10_CLANG_DIAGNOSTIC_POP()

  TORCH_INTERNAL_ASSERT(result != nil);
  return result;
}

MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes, MPSShape* strides) {
  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);

  MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
  srcTensorDesc.preferPackedRows = YES;
  MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
                                                        offset:t.storage_offset() * t.element_size()
                                                    descriptor:srcTensorDesc] autorelease];
  if (strides != nil) {
    srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
  }
  return srcNDArray;
}

MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes, const IntArrayRef& strides) {
  return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
}

static MPSNDArray* getStridedMPSNDArray(const at::Tensor& src, MPSNDArray* srcNDArray) {
  auto strides = src.strides();
  auto sizes = src.sizes();
  auto nStrides = strides.size();
  auto nonZeroStrides = src.strides();
  auto sortedStridesIndices = getSortedStrides(nonZeroStrides);

  NSMutableArray<NSNumber*>* sortedStridesShape = [NSMutableArray arrayWithCapacity:nStrides];
  NSMutableArray<NSNumber*>* sortedMPSShape = [NSMutableArray arrayWithCapacity:nStrides];
  for (const auto i : c10::irange(nStrides)) {
    sortedStridesShape[i] = [NSNumber numberWithInteger:nonZeroStrides[sortedStridesIndices[i]]];
    sortedMPSShape[i] = [NSNumber numberWithInteger:sizes[sortedStridesIndices[i]]];
  }
  MPSShape* originalSortedMPSShape = sortedMPSShape;
  MPSShape* originalSortedStridesShape = sortedStridesShape;
  // true when the smallest stride is not 1; in that case a trailing unit
  // dimension is appended and later removed again via the identity reshape below
  bool hasNonZeroStrides = nStrides == 0 ? false : nonZeroStrides[sortedStridesIndices[nStrides - 1]] != 1;
  if (hasNonZeroStrides) {
    originalSortedMPSShape = [sortedMPSShape copy];
    originalSortedStridesShape = [sortedStridesShape copy];
    [sortedStridesShape addObject:[NSNumber numberWithInteger:1]];
    [sortedMPSShape addObject:[NSNumber numberWithInteger:1]];
  }
  if (nStrides == 0) {
    originalSortedMPSShape = getMPSShape(src);
    originalSortedStridesShape = getMPSShape(src.strides());
  }

  srcNDArray = [srcNDArray arrayViewWithShape:sortedMPSShape strides:sortedStridesShape];
  if (hasNonZeroStrides) {
    MPSNDArrayIdentity* identity =
        [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
    srcNDArray = [identity reshapeWithCommandBuffer:nil
                                        sourceArray:srcNDArray
                                              shape:originalSortedMPSShape
                                   destinationArray:nil];
  }
  TORCH_INTERNAL_ASSERT(srcNDArray);

  srcNDArray = permuteNDArray(srcNDArray, sortedStridesIndices);
  TORCH_INTERNAL_ASSERT(srcNDArray);

  return srcNDArray;
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray) {
  _placeholder = mpsGraphTensor;
  _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:mpsNDArray] autorelease];
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
                         const Tensor& src,
                         MPSShape* mpsShape_,
                         bool gatherTensorData,
                         MPSDataType dataType,
                         bool useMPSStridedAPI)
    : _tensor(src) {
  TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
  // extract the pointer to MTLBuffer from the Tensor's storage
  id<MTLBuffer> srcBuf = getMTLBufferStorage(src);

  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  // Use the gather kernel to solve strides for macOS < 15.0
  // Starting with macOS 15.0, MPS supports native strides directly in the kernels
  if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
    if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
      Tensor emptyShell = Tensor();
      // use "_tensor" from Placeholder to retain the view's output during its usage in other ops
      _tensor = gatherViewTensor(src, emptyShell);
      if (!_tensor.has_storage()) {
        // if we cannot gather, we make the tensor contiguous implicitly, and keep
        // it in the placeholder to be able to retrieve it when we return from the constructor
        _tensor = src.clone(MemoryFormat::Contiguous);
      }
      srcBuf = getMTLBufferStorage(_tensor);
    }
  }

  // tensor.numel() could be zero, but the tensor is valid as long as the buffer size is non-zero.
  // if the buffer size is zero here, it's not a user error. It could be a missing check for
  // tensor.numel() == 0 in our internal implementations of ops.
  TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
  if (dataType == MPSDataTypeInvalid) {
    const auto scalar_type = _tensor.scalar_type();
    dataType = _tensor.dim() == 0 ? getMPSScalarType(scalar_type) : getMPSDataType(scalar_type);
  }

  // Tensor is contiguous and has no storage offset.
  // Wrap it directly inside MPSGraphTensorData
  if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
    _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
                                                      shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
                                                   dataType:dataType] autorelease];
  } else {
    if (mpsShape_) {
      _tensor = getTensorView(src, mpsShape_);
    }

    MPSShape* mpsShape = getMPSShape(_tensor);
    MPSShape* mpsStrides = getMPSShape(_tensor.strides());

    auto storage_numel = src.storage().nbytes() / src.element_size();
    MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
                                                                                 shape:@[ @(storage_numel) ]];
    srcTensorDesc.preferPackedRows = YES;
    MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
                                                          offset:src.storage_offset() * src.element_size()
                                                      descriptor:srcTensorDesc] autorelease];
    TORCH_INTERNAL_ASSERT(srcNDArray);
    if (src.dim() != 0) {
      srcNDArray = getStridedMPSNDArray(_tensor, srcNDArray);
    } else {
      bool needsReshape = false;
      NSMutableArray* mpsExpandedShape = nil;
      NSMutableArray* mpsExpandedStrides = nil;

      // note: src.dim() == 0 in this branch, so this condition can never hold
      // and needsReshape stays false; the reshape path below is kept for safety
      if (src.dim() > 0 && src.stride(-1) != 1) {
        needsReshape = true;
        mpsExpandedShape = [NSMutableArray arrayWithArray:mpsShape];
        mpsExpandedStrides = [NSMutableArray arrayWithArray:mpsStrides];
        [mpsExpandedShape addObject:@1];
        [mpsExpandedStrides addObject:@1];
      }
      srcNDArray = [srcNDArray arrayViewWithShape:needsReshape ? mpsExpandedShape : getMPSShape(src)
                                          strides:needsReshape ? mpsExpandedStrides : getMPSShape(src.strides())];
      TORCH_INTERNAL_ASSERT(srcNDArray);

      if (needsReshape) {
        MPSNDArrayIdentity* identity =
            [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
        srcNDArray = [identity reshapeWithCommandBuffer:nil sourceArray:srcNDArray shape:mpsShape destinationArray:nil];
      }
      TORCH_INTERNAL_ASSERT(srcNDArray);
    }
    _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcNDArray] autorelease];
  }

  TORCH_INTERNAL_ASSERT(_value);
  _placeholder = mpsGraphTensor;
}

MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) {
  auto mpsShape = getMPSShape(tensor);
  auto dataType = getMPSDataType(tensor.scalar_type());

  MPSGraphTensorData* result = nil;
  if (tensor.numel() > 0) {
    id<MTLBuffer> buf = getMTLBufferStorage(tensor);
    result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease];
  } else {
    // create empty NDArray
    MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape];
    MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease];
    result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease];
  }
  TORCH_INTERNAL_ASSERT(result);
  return result;
}

MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
  switch (type) {
    case ScalarType::Double:
    case ScalarType::Float:
      return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
    case ScalarType::Half:
      return {.value.h = scalar.to<at::Half>(), .size = sizeof(short), .type = type};
    case ScalarType::BFloat16:
      return {.value.bf16 = scalar.to<at::BFloat16>(), .size = sizeof(short), .type = type};
    case ScalarType::Long:
      return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
    case ScalarType::Int:
      return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
    case ScalarType::Short:
      return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
    case ScalarType::Char:
      return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
    case ScalarType::Byte:
      return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
    case ScalarType::Bool:
      return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
    case ScalarType::ComplexHalf:
      return {.value.ch = scalar.to<c10::complex<at::Half>>(), .size = sizeof(int32_t), .type = type};
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      return {.value.cf = scalar.to<c10::complex<float>>(), .size = sizeof(int64_t), .type = type};
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
  }
}

MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar) {
  MPSGraphTensorData* result = nullptr;
  // Scalar pools are only supported on devices with unified memory
  if (mpsStream->device().hasUnifiedMemory) {
    scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size);
    result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:scalar.getMTLBuffer()
                                                      shape:@[ @1 ]
                                                   dataType:getMPSScalarType(scalar.type)] autorelease];
  } else {
    MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type)
                                                                              shape:@[ @1 ]];
    MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device()
                                                         descriptor:tensorDesc] autorelease];
    [tensorNDArray writeBytes:&scalar.value strideBytes:nil];
    result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease];
  }
  return result;
}

void resize_tensor(Tensor* output) {
  output->resize_(output->sizes());
}

Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
  // Copied and modified from aten/src/ATen/ScalarOps.h,
  // as MPS doesn't support float64 tensors.
  Tensor tensor;
  if (scalar.isFloatingPoint()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kFloat));
  } else if (scalar.isBoolean()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
  } else if (scalar.isComplex()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
  } else {
    TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));
  }
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}
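// For example, a floating-point scalar becomes a 0-dim kFloat (not kDouble)
// tensor on the given device, marked as a wrapped number so that type
// promotion treats it like a scalar rather than a full tensor.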

MPSGraph* make_mps_graph() {
  MPSGraph* mpsGraph = [[MPSGraph new] autorelease];
  return mpsGraph;
}

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
  return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
  return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {
  return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
  return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) {
  return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil];
}

// this is meant to suppress the availability warning on castTensor
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
  if ([tensor dataType] == toType) {
    return tensor;
  }
  return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
}

MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
  return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
}

MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
  TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!");
  return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil]
                         dimension:2
                     withDimension:1
                              name:nil];
}
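// The two transposes compose as [N, H, W, C] -> swap(3, 2) -> [N, H, C, W]
// -> swap(2, 1) -> [N, C, H, W].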

std::string get_mem_format_string(c10::MemoryFormat memory_format) {
  std::string mem_format_key;
  switch (memory_format) {
    case at::MemoryFormat::Contiguous:
      mem_format_key = "Contiguous";
      break;
    case at::MemoryFormat::ChannelsLast:
      mem_format_key = "ChannelsLast";
      break;
    default:
      TORCH_CHECK(false, "Invalid memory format ", memory_format);
  }

  return mem_format_key;
}

MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;

void MPSGraphCache::profileCachedGraph(const CacheEntry& cacheEntry) const {
  auto& profiler = getMPSProfiler();
  if (profiler.isOperationProfilingEnabled()) {
    std::string graphKey = cacheEntry.key_;
    // for interval-based signpost tracing, we begin the interval here to be able
    // to measure the time it takes to compile the graph (if it is newly created),
    // and also the time potentially spent on gather/scatter of the graph's input tensors
    profiler.beginProfileKernel(cacheEntry.cachedGraph_->graph(), graphKey, true);
  }
}

class MPSGraphCacheCallback : public IMpsAllocatorCallback {
 public:
  MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {}

  void executeMPSAllocatorCallback(void* ptr, EventType event) override {}

 private:
  MPSGraphCache* graph_cache;
};

REGISTER_MPS_ALLOCATOR_CALLBACK("mps_graph_cache_callback", MPSGraphCacheCallback);

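// Builds a per-element buffer of byte offsets for the first three tensors of
// the iterator (typically one output and two inputs) by dispatching a Metal
// indexing kernel; each of the numel() threads writes one simd_uint3 (or
// simd_ulong3, for 64-bit indexing) entry.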
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index) {
  constexpr uint32_t nOffsets = 3;
  uint32_t numThreads = iter.numel();
  const uint32_t nDim = iter.ndim();
  const IntArrayRef& iterShape = iter.shape();
  std::vector<uint32_t> iterShapeData(iterShape.size());
  std::vector<std::array<uint32_t, nOffsets>> strides(nDim);
  TORCH_INTERNAL_ASSERT(iter.ntensors() >= nOffsets);
  TORCH_CHECK(use_64bit_index || iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");

  for (const auto i : c10::irange(iterShape.size())) {
    iterShapeData[i] = static_cast<uint32_t>(iterShape[i]);
  }

  for (const auto i : c10::irange(nDim)) {
    for (const auto offset : c10::irange(nOffsets)) {
      strides[i][offset] = static_cast<uint32_t>(iter.strides(offset)[i]);
    }
  }

  id<MTLComputePipelineState> kernelDataOffsetsPSO = MPSDevice::getInstance()->metalIndexingPSO(
      use_64bit_index ? "kernel_index_offsets_64" : "kernel_index_offsets_32");
  const auto elementSize = use_64bit_index ? sizeof(simd_ulong3) : sizeof(simd_uint3);
  id<MTLBuffer> kernelDataOffsets = (id<MTLBuffer>)getIMPSAllocator()->allocate(numThreads * elementSize).get();

  [commandEncoder setComputePipelineState:kernelDataOffsetsPSO];
  [commandEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
  [commandEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
  [commandEncoder setBytes:iterShapeData.data() length:sizeof(uint32_t) * iterShape.size() atIndex:2];
  [commandEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];

  mtl_dispatch1DJob(commandEncoder, kernelDataOffsetsPSO, numThreads);

  return kernelDataOffsets;
}

id<MTLLibrary> MetalShaderLibrary::getLibrary() {
  if (C10_UNLIKELY(!library)) {
    TORCH_INTERNAL_ASSERT(nparams == 0);
    library = compileLibrary(shaderSource);
  }
  return library;
}

id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::string>& params) {
  TORCH_INTERNAL_ASSERT(nparams == params.size());
  std::string key = "";
  for (auto p : params) {
    key += ":" + p;
  }
  auto lib = libMap[key];
  if (lib) {
    return lib;
  }
  auto it = params.begin();
  switch (nparams) {
    case 1:
      lib = compileLibrary(fmt::format(shaderSource, *it));
      break;
    case 2: {
      auto& first = *it++;
      auto& second = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second));
      break;
    }
    case 3: {
      auto& first = *it++;
      auto& second = *it++;
      auto& third = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second, third));
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported number of parameters ", nparams);
  }
  return libMap[key] = lib;
}
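// Sketch of the parameterized-shader flow (the source string and "float"
// parameter below are hypothetical): given
//   shaderSource = "kernel void add_{0}(device {0}* x ...) {{ ... }}"
// a call to getLibrary({"float"}) substitutes "float" for every {0} via
// fmt::format and caches the compiled library in libMap under the key ":float".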

id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
  static const char* fast_math = std::getenv("PYTORCH_MPS_FAST_MATH");
  NSError* error = nil;
  MTLCompileOptions* options = compile_options;
  if (!options) {
    options = [[MTLCompileOptions new] autorelease];
    [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
                                                                                        : MTLLanguageVersion2_3];
    [options setFastMathEnabled:(!fast_math || std::stoi(fast_math) == 0) ? NO : YES];
  }

  const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
  auto device = MPSDevice::getInstance()->device();
  library = [device newLibraryWithSource:str options:options error:&error];
  TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
  return library;
}

std::pair<id<MTLComputePipelineState>, id<MTLFunction>> MetalShaderLibrary::getLibraryPipelineState(
    id<MTLLibrary> lib,
    const std::string& fname) {
  const auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
  auto found_cpl = cplMap.find(key);
  if (found_cpl != cplMap.end()) {
    return found_cpl->second;
  }

  NSError* error = nil;
  id<MTLFunction> func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func, "Failed to create function state object for: ", fname);
  auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
  TORCH_CHECK(cpl, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);

  cplMap[key] = std::make_pair(cpl, func);
  return cplMap[key];
}

} // namespace at::native::mps