// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TensorIterator.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
  __block std::optional<std::exception_ptr> block_exception;
  dispatch_sync(queue, ^() {
    try {
      block();
    } catch (...) {
      block_exception = std::current_exception();
    }
  });
  if (block_exception) {
    std::rethrow_exception(*block_exception);
  }
}

/**
 * Computes the distance from the lowest to the highest element offset in the given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t) {
  size_t rc = 1;
  if (t.numel() == 0) {
    return 0;
  }
  for (const auto i : c10::irange(t.dim())) {
    assert(t.size(i) > 0);
    rc += (t.size(i) - 1) * t.stride(i);
  }
  return rc;
}

void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
  mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}

static inline void checkSupportsComplex() {
  TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
}

static inline void checkSupportsBFloat16() {
  TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
                   "MPS bfloat16 type is only supported on MacOS 14.0 or newer.");
}

MPSDataType getMPSDataType(ScalarType scalar_type) {
  switch (scalar_type) {
    case ScalarType::Float:
      return MPSDataTypeFloat32;
    case ScalarType::Half:
      return MPSDataTypeFloat16;
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return MPSDataTypeBFloat16;
    case ScalarType::Int:
      return MPSDataTypeInt32;
    case ScalarType::Long:
      return MPSDataTypeInt64;
    case ScalarType::Short:
      return MPSDataTypeInt16;
    case ScalarType::Char:
      return MPSDataTypeInt8;
    case ScalarType::Byte:
      return MPSDataTypeUInt8;
    case ScalarType::Bool:
      return MPSDataTypeBool;
    case ScalarType::Double:
      TORCH_CHECK_TYPE(false,
                       "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
                       "Please use float32 instead.")
    case ScalarType::ComplexHalf:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat16;
    case ScalarType::ComplexFloat:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat32;
    default:
      TORCH_CHECK_TYPE(
          false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
  }
}

// #issue 104398441: sortWithTensor and argsortWithTensor only support
// Int32, Half and Float32 types. These utilities help cast to these
// types.
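// Inputs that are not already Int32, Float32 or Float16 (or Int64 when includesInt64 is true)
// are cast: floating-point dtypes to Float32, everything else to Int32.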
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
                               MPSGraphTensor* inputTensor,
                               const Tensor& input,
                               bool includesInt64) {
  MPSDataType dataType = getMPSDataType(input.scalar_type());
  bool condition =
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
  if (includesInt64) {
    condition = condition && (dataType != MPSDataTypeInt64);
  }
  if (condition) {
    dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
    return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
  }
  return inputTensor;
}

// #issue 104398441: sortWithTensor and argsortWithTensor only support
// Int32, Half and Float32 types. These utilities help cast back from these
// types to the tensor's original dtype.
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
                                 MPSGraphTensor* inputTensor,
                                 const Tensor& input,
                                 bool includesInt64) {
  MPSDataType dataType = getMPSDataType(input.scalar_type());
  bool condition =
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
  if (includesInt64) {
    condition = condition && (dataType != MPSDataTypeInt64);
  }
  if (condition) {
    inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
  }
  return inputTensor;
}

MPSDataType getMPSScalarType(ScalarType scalar_type) {
  switch (scalar_type) {
    // This is an intentional fallthrough supporting Double for Scalar
    // types as they are cast to Float32 currently.
    case ScalarType::Double:
    case ScalarType::Float:
      return MPSDataTypeFloat32;
    case ScalarType::Half:
      return MPSDataTypeFloat16;
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return MPSDataTypeBFloat16;
    case ScalarType::Int:
      return MPSDataTypeInt32;
    case ScalarType::Long:
      return MPSDataTypeInt64;
    case ScalarType::Short:
      return MPSDataTypeInt16;
    case ScalarType::Char:
      return MPSDataTypeInt8;
    case ScalarType::Byte:
      return MPSDataTypeUInt8;
    case ScalarType::Bool:
      return MPSDataTypeBool;
    case ScalarType::ComplexHalf:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat16;
    // This is an intentional fallthrough supporting ComplexDouble for Scalar
    // types as they are cast to Complex64 currently.
    case ScalarType::ComplexDouble:
    case ScalarType::ComplexFloat:
      checkSupportsComplex();
      return MPSDataTypeComplexFloat32;
    default:
      TORCH_CHECK_TYPE(
          false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.")
  }
}

// use short_name to avoid getting extra long cached graph keys with ops such as cat_out(), etc.
std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
  switch (scalar_type) {
    case ScalarType::Double:
    case ScalarType::Float:
      return short_name ? "f32" : "Float32";
    case ScalarType::Half:
      return short_name ? "f16" : "Float16";
    case ScalarType::BFloat16:
      return short_name ? "bf16" : "BFloat16";
    case ScalarType::Int:
      return short_name ? "i32" : "Int32";
    case ScalarType::Long:
      return short_name ? "i64" : "Int64";
    case ScalarType::Short:
      return short_name ? "i16" : "Int16";
    case ScalarType::Char:
      return short_name ? "i8" : "Int8";
    case ScalarType::Byte:
      return short_name ? "u8" : "UInt8";
    case ScalarType::Bool:
      return short_name ? "b8" : "Bool";
    case ScalarType::ComplexHalf:
      return short_name ? "c16" : "ComplexFloat16";
    case ScalarType::ComplexFloat:
      return short_name ? "c32" : "ComplexFloat32";
    default:
      return "Undefined";
  }
}

// Maps a ScalarType to the corresponding Metal (MSL) type name,
// e.g. Float -> "float", Byte -> "uchar".
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
  switch (scalar_type) {
    case ScalarType::Float:
      return "float";
    case ScalarType::Half:
      return "half";
    case ScalarType::BFloat16:
      checkSupportsBFloat16();
      return "bfloat";
    case ScalarType::Int:
      return "int";
    case ScalarType::Long:
      return "long";
    case ScalarType::Short:
      return "short";
    case ScalarType::Char:
      return "char";
    case ScalarType::Byte:
      return "uchar";
    case ScalarType::Bool:
      return "bool";
    default:
      TORCH_CHECK(false, "Undefined type ", scalar_type);
      return "Undefined";
  }
}

static NSArray<NSNumber*>* getTensorAxes(int64_t ndim) {
  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
  for (const auto i : c10::irange(ndim)) {
    axes[i] = [NSNumber numberWithInteger:i];
  }
  return axes;
}

NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
  return getTensorAxes(t.dim());
}

static NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes) {
  return getTensorAxes(sizes.size());
}

NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim) {
  if (dim.has_value() && !dim.value().empty()) {
    IntArrayRef dimValues = dim.value();
    int ndim = dimValues.size();
    auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
    for (const auto i : c10::irange(ndim)) {
      axes[i] = [NSNumber numberWithInteger:dimValues[i]];
    }

    return axes;
  }

  return getTensorAxes(sizes);
}

std::string getMPSShapeString(MPSShape* shape) {
  std::string str;
  for (NSNumber* elem in shape) {
    str += std::to_string(elem.unsignedLongValue) + ",";
  }
  return str;
}

std::string getArrayRefString(const IntArrayRef s) {
  std::stringstream ss;
  std::copy(s.begin(), s.end(), std::ostream_iterator<int>(ss, ","));
  return ss.str();
}

std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
  std::string str;
  // The key format per tensor would look like ":Float32[1,1,1,10]:"
  for (const Tensor& tensor : tensors) {
    str += ":";
    if (tensor.defined()) {
      str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
      // if tensor is a scalar
      if (tensor.dim() == 0) {
        str += "Scalar";
      } else {
        if (exclude_shape) {
          str += "[-1]";
        } else {
          str +=
              std::string([[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","].UTF8String);
        }
      }
      str += "]";
    } else {
      str += "Undefined";
    }
  }
  return str;
}

Tensor getTensorView(const Tensor& t, MPSShape* shape) {
  std::vector<int64_t> res;
  res.reserve([shape count]);
  for (NSNumber* elem in shape) {
    res.push_back(elem.longLongValue);
  }
  IntArrayRef r = IntArrayRef(res);
  return t.view(res);
}

MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) {
  return getMPSShape(t.sizes(), memory_format);
}

MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) {
  if (memory_format == MemoryFormat::ChannelsLast) {
    // For channels-last tensors, report the shape in NHWC order: [N, C, H, W] -> [N, H, W, C].
    TORCH_INTERNAL_ASSERT(sizes.size() == 4, "ChannelsLast memory format must have 4 dimensions!");
    const NSUInteger N = sizes[0];
    const NSUInteger C = sizes[1];
    const NSUInteger H = sizes[2];
    const NSUInteger W = sizes[3];
    return @[ @(N), @(H), @(W), @(C) ];
  }
  const int sz = sizes.size();
  const int sz_ = (sz > 0) ? sz : 1;

  std::vector<NSNumber*> numbers(sz_);

  for (int i = 0; i < sz_; i++) {
    NSInteger sz_i = (i < sz) ? sizes[i] : 1;
    NSNumber* number = [NSNumber numberWithInteger:sz_i];
    numbers[i] = number;
  }
  return [NSArray arrayWithObjects:numbers.data() count:numbers.size()];
}

void printTensorNDArray(const Tensor& t) {
  if (!t.is_mps())
    return;
  if (t.numel() == 0)
    return;
  // Get shape and data type
  auto selfShape = getMPSShape(t);
  auto selfDType = getMPSDataType(t.scalar_type());

  // Initialize data
  id<MTLBuffer> selfBuf = getMTLBufferStorage(t);
  MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf
                                                                       shape:selfShape
                                                                    dataType:selfDType] autorelease];
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wobjc-method-access")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access")
#endif
  [tdata printNDArray];
  C10_CLANG_DIAGNOSTIC_POP()
}

MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) {
  id<MTLBuffer> buffer = getMTLBufferStorage(tensor);
  MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer
                                                                                    shape:shape
                                                                                 dataType:mpsType] autorelease];

  return [tmpGraphTensorData mpsndarray];
}

static std::vector<int64_t> getSortedStrides(const IntArrayRef& s) {
  std::vector<int64_t> idx(s.size());
  iota(idx.begin(), idx.end(), 0);
  sort(idx.begin(), idx.end(), [&s](size_t i1, size_t i2) { return s[i1] > s[i2]; });

  return idx;
}

static std::vector<int64_t> inversePermutation(const std::vector<int64_t>& permuteOrder) {
  auto size = permuteOrder.size();
  std::vector<int64_t> inversePerm(permuteOrder.size());

  for (int i = 0; i < size; i++) {
    inversePerm[permuteOrder[i]] = i;
  }
  return inversePerm;
}

static MPSNDArray* permuteNDArray(MPSNDArray* inArray, const std::vector<int64_t>& permuteOrder_) {
  auto permuteOrder = inversePermutation(permuteOrder_);
  NSUInteger srcRank = [inArray numberOfDimensions];
  if (srcRank != permuteOrder.size()) {
    TORCH_INTERNAL_ASSERT(false);
    return nil;
  }
  std::vector<NSUInteger> dimensionOrder(srcRank);
  std::iota(std::begin(dimensionOrder), std::end(dimensionOrder), 0);
  MPSNDArrayDescriptor* desc = [inArray descriptor];

  for (int64_t i = srcRank - 1; i >= 0; i--) {
    NSUInteger axis = permuteOrder[i];
    auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis);
    NSUInteger axis1 = srcRank - i - 1;
    NSUInteger axis2 = dimensionOrder.end() - axisIter - 1;
    iter_swap(dimensionOrder.begin() + i, axisIter);
    if (axis1 != axis2) {
      [desc transposeDimension:axis1 withDimension:axis2];
    }
  }
  C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wnonnull")
  C10_CLANG_DIAGNOSTIC_IGNORE("-Wnonnull")
#endif
  MPSNDArray* result = [inArray arrayViewWithCommandBuffer:nil descriptor:desc aliasing:MPSAliasingStrategyShallAlias];
  C10_CLANG_DIAGNOSTIC_POP()

  TORCH_INTERNAL_ASSERT(result != nil);
  return result;
}

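// Wraps the tensor's MTLBuffer in an MPSNDArray with the given sizes, honoring the tensor's
// storage offset; when strides are supplied, a strided array view over that buffer is returned.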
MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes, MPSShape* strides) {
  id<MTLBuffer> srcBuf = getMTLBufferStorage(t);

  MPSDataType mpsDataType = getMPSDataType(t.scalar_type());
  MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes];
  srcTensorDesc.preferPackedRows = YES;
  MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
                                                        offset:t.storage_offset() * t.element_size()
                                                    descriptor:srcTensorDesc] autorelease];
  if (strides != nil) {
    srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
  }
  return srcNDArray;
}

MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes, const IntArrayRef& strides) {
  return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
}

static MPSNDArray* getStridedMPSNDArray(const at::Tensor& src, MPSNDArray* srcNDArray) {
  auto strides = src.strides();
  auto sizes = src.sizes();
  auto nStrides = strides.size();
  auto nonZeroStrides = src.strides();
  int64_t crtNonZeroStride = 1;
  bool hasZeroStrides = false;
  auto sortedStridesIndices = getSortedStrides(nonZeroStrides);

  NSMutableArray<NSNumber*>* sortedStridesShape = [NSMutableArray arrayWithCapacity:nStrides];
  NSMutableArray<NSNumber*>* sortedMPSShape = [NSMutableArray arrayWithCapacity:nStrides];
  for (const auto i : c10::irange(nStrides)) {
    sortedStridesShape[i] = [NSNumber numberWithInteger:nonZeroStrides[sortedStridesIndices[i]]];
    sortedMPSShape[i] = [NSNumber numberWithInteger:sizes[sortedStridesIndices[i]]];
  }
  MPSShape* originalSortedMPSShape = sortedMPSShape;
  MPSShape* originalSortedStridesShape = sortedStridesShape;
  bool hasNonZeroStrides = nStrides == 0 ? false : nonZeroStrides[sortedStridesIndices[nStrides - 1]] != 1;
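  // If the smallest stride is not 1, append a unit dimension so the strided array view can be
  // created; the array is reshaped back to the original sorted shape further below.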
  if (hasNonZeroStrides) {
    originalSortedMPSShape = [sortedMPSShape copy];
    originalSortedStridesShape = [sortedStridesShape copy];
    [sortedStridesShape addObject:[NSNumber numberWithInteger:1]];
    [sortedMPSShape addObject:[NSNumber numberWithInteger:1]];
  }
  if (nStrides == 0) {
    originalSortedMPSShape = getMPSShape(src);
    originalSortedStridesShape = getMPSShape(src.strides());
  }

  srcNDArray = [srcNDArray arrayViewWithShape:sortedMPSShape strides:sortedStridesShape];
  if (hasNonZeroStrides) {
    MPSNDArrayIdentity* identity =
        [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
    srcNDArray = [identity reshapeWithCommandBuffer:nil
                                        sourceArray:srcNDArray
                                              shape:originalSortedMPSShape
                                   destinationArray:nil];
  }
  TORCH_INTERNAL_ASSERT(srcNDArray);

  srcNDArray = permuteNDArray(srcNDArray, sortedStridesIndices);
  TORCH_INTERNAL_ASSERT(srcNDArray);

  return srcNDArray;
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray) {
  _placeholder = mpsGraphTensor;
  _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:mpsNDArray] autorelease];
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
                         const Tensor& src,
                         MPSShape* mpsShape_,
                         bool gatherTensorData,
                         MPSDataType dataType,
                         bool useMPSStridedAPI)
    : _tensor(src) {
  TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
  // extract the pointer to MTLBuffer from the Tensor's storage
  id<MTLBuffer> srcBuf = getMTLBufferStorage(src);

  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  // Use the gather kernel to resolve strides on macOS < 15.0.
  // Starting with macOS 15.0, MPS supports native strides directly in the kernels.
  if (!is_macOS_15_0_or_newer || !useMPSStridedAPI) {
    if ((!src.is_contiguous() || src.storage_offset()) && gatherTensorData) {
      Tensor emptyShell = Tensor();
      // use "_tensor" from Placeholder to retain view's output during its usage in other ops
      _tensor = gatherViewTensor(src, emptyShell);
      if (!_tensor.has_storage()) {
        // if we cannot gather, we make the tensor contiguous implicitly, and keep
        // it in placeholder to be able to retrieve it when we return from constructor
        _tensor = src.clone(MemoryFormat::Contiguous);
      }
      srcBuf = getMTLBufferStorage(_tensor);
    }
  }

  // tensor.numel() could be zero, but the tensor is valid as long as the buffer size is non-zero.
  // if the buffer size is zero here, it's not a user error. It could be a missing check for
  // tensor.numel() == 0 in our internal implementations of ops.
  TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!");
  if (dataType == MPSDataTypeInvalid) {
    const auto scalar_type = _tensor.scalar_type();
    dataType = _tensor.dim() == 0 ? getMPSScalarType(scalar_type) : getMPSDataType(scalar_type);
  }

  // Tensor is contiguous and has no storage offset.
  // Wrap it directly inside MPSGraphTensorData
  if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) {
    _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
                                                      shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor)
                                                   dataType:dataType] autorelease];
  } else {
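    // Non-contiguous (or storage-offset) tensor with the macOS 15+ strided API available:
    // wrap the entire storage in a flat MPSNDArray, then build a strided (and, via
    // getStridedMPSNDArray, permuted) view that matches the tensor's sizes and strides.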
    IntArrayRef view_shape;
    if (mpsShape_) {
      _tensor = getTensorView(src, mpsShape_);
    }

    MPSShape* mpsShape = getMPSShape(_tensor);
    MPSShape* mpsStrides = getMPSShape(_tensor.strides());

    auto storage_numel = src.storage().nbytes() / src.element_size();
    MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType
                                                                                 shape:@[ @(storage_numel) ]];
    srcTensorDesc.preferPackedRows = YES;
    MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf
                                                          offset:src.storage_offset() * src.element_size()
                                                      descriptor:srcTensorDesc] autorelease];
    TORCH_INTERNAL_ASSERT(srcNDArray);
    if (src.dim() != 0) {
      srcNDArray = getStridedMPSNDArray(_tensor, srcNDArray);
    } else {
      bool needsReshape = false;
      NSMutableArray* mpsExpandedShape = nil;
      NSMutableArray* mpsExpandedStrides = nil;

      if (src.dim() > 0 && src.stride(-1) != 1) {
        needsReshape = true;
        mpsExpandedShape = [NSMutableArray arrayWithArray:mpsShape];
        mpsExpandedStrides = [NSMutableArray arrayWithArray:mpsStrides];
        [mpsExpandedShape addObject:@1];
        [mpsExpandedStrides addObject:@1];
      }
      srcNDArray = [srcNDArray arrayViewWithShape:needsReshape ? mpsExpandedShape : getMPSShape(src)
                                          strides:needsReshape ? mpsExpandedStrides : getMPSShape(src.strides())];
      TORCH_INTERNAL_ASSERT(srcNDArray);

      if (needsReshape) {
        MPSNDArrayIdentity* identity =
            [[[MPSNDArrayIdentity alloc] initWithDevice:MPSDevice::getInstance()->device()] autorelease];
        srcNDArray = [identity reshapeWithCommandBuffer:nil sourceArray:srcNDArray shape:mpsShape destinationArray:nil];
      }
      TORCH_INTERNAL_ASSERT(srcNDArray);
    }
    _value = [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcNDArray] autorelease];
  }

  TORCH_INTERNAL_ASSERT(_value);
  _placeholder = mpsGraphTensor;
}

MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) {
  auto mpsShape = getMPSShape(tensor);
  auto dataType = getMPSDataType(tensor.scalar_type());

  MPSGraphTensorData* result = nil;
  if (tensor.numel() > 0) {
    id<MTLBuffer> buf = getMTLBufferStorage(tensor);
    result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buf shape:mpsShape dataType:dataType] autorelease];
  } else {
    // create empty NDArray
    MPSNDArrayDescriptor* desc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsShape];
    MPSNDArray* emptyArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device() descriptor:desc] autorelease];
    result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:emptyArray] autorelease];
  }
  TORCH_INTERNAL_ASSERT(result);
  return result;
}

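// Packs a Scalar into an MPSScalar tagged with its dtype and byte size. Double and ComplexDouble
// values are narrowed to float / complex<float>, since MPS has no 64-bit floating-point support.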
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) {
  switch (type) {
    case ScalarType::Double:
    case ScalarType::Float:
      return {.value.f = scalar.to<float>(), .size = sizeof(float), .type = type};
    case ScalarType::Half:
      return {.value.h = scalar.to<at::Half>(), .size = sizeof(short), .type = type};
    case ScalarType::BFloat16:
      return {.value.bf16 = scalar.to<at::BFloat16>(), .size = sizeof(short), .type = type};
    case ScalarType::Long:
      return {.value.i = scalar.to<int64_t>(), .size = sizeof(int64_t), .type = type};
    case ScalarType::Int:
      return {.value.i = scalar.to<int32_t>(), .size = sizeof(int32_t), .type = type};
    case ScalarType::Short:
      return {.value.i = scalar.to<int16_t>(), .size = sizeof(int16_t), .type = type};
    case ScalarType::Char:
      return {.value.i = scalar.to<int8_t>(), .size = sizeof(int8_t), .type = type};
    case ScalarType::Byte:
      return {.value.i = scalar.to<uint8_t>(), .size = sizeof(uint8_t), .type = type};
    case ScalarType::Bool:
      return {.value.b = scalar.to<bool>(), .size = sizeof(bool), .type = type};
    case ScalarType::ComplexHalf:
      return {.value.ch = scalar.to<c10::complex<at::Half>>(), .size = sizeof(int32_t), .type = type};
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      return {.value.cf = scalar.to<c10::complex<float>>(), .size = sizeof(int64_t), .type = type};
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported scalar type '", type, "' on MPS backend.");
  }
}

MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar) {
  MPSGraphTensorData* result = nullptr;
  // Scalar pools are only supported on devices with unified memory
  if (mpsStream->device().hasUnifiedMemory) {
    scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size);
    result = [[[MPSGraphTensorData alloc] initWithMTLBuffer:scalar.getMTLBuffer()
                                                      shape:@[ @1 ]
                                                   dataType:getMPSScalarType(scalar.type)] autorelease];
  } else {
    MPSNDArrayDescriptor* tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:getMPSScalarType(scalar.type)
                                                                              shape:@[ @1 ]];
    MPSNDArray* tensorNDArray = [[[MPSNDArray alloc] initWithDevice:mpsStream->device()
                                                         descriptor:tensorDesc] autorelease];
    [tensorNDArray writeBytes:&scalar.value strideBytes:nil];
    result = [[[MPSGraphTensorData alloc] initWithMPSNDArray:tensorNDArray] autorelease];
  }
  return result;
}

void resize_tensor(Tensor* output) {
  output->resize_(output->sizes());
}

Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
  // Copied and modified from aten/src/ATen/ScalarOps.h
  // as MPS doesn't support float64 tensors.
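  // Floating-point scalars are wrapped as Float, booleans as Bool, complex values as
  // ComplexDouble and integral values as Long.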
  Tensor tensor;
  if (scalar.isFloatingPoint()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kFloat));
  } else if (scalar.isBoolean()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
  } else if (scalar.isComplex()) {
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
  } else {
    TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
    tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));
  }
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}

MPSGraph* make_mps_graph() {
  MPSGraph* mpsGraph = [[MPSGraph new] autorelease];
  return mpsGraph;
}

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
  return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
  return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {
  return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) {
  return [mpsGraph placeholderWithShape:@[ @1 ] dataType:dataType name:nil];
}

MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar) {
  return [mpsGraph placeholderWithShape:@[ @1 ] dataType:getMPSScalarType(scalar.type()) name:nil];
}

// this is meant to suppress the availability warning on castTensor;
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
  if ([tensor dataType] == toType) {
    return tensor;
  }
  return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
}

MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
  return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
}

MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
  TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!");
  return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil]
                         dimension:2
                     withDimension:1
                              name:nil];
}

string get_mem_format_string(c10::MemoryFormat memory_format) {
  string mem_format_key;
  switch (memory_format) {
    case at::MemoryFormat::Contiguous:
      mem_format_key = "Contiguous";
      break;
    case at::MemoryFormat::ChannelsLast:
      mem_format_key = "ChannelsLast";
      break;
    default:
      TORCH_CHECK(false, "Invalid memory format ", memory_format);
  }

  return mem_format_key;
}

MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;

void MPSGraphCache::profileCachedGraph(const CacheEntry& cacheEntry) const {
  auto& profiler = getMPSProfiler();
  if (profiler.isOperationProfilingEnabled()) {
    std::string graphKey = cacheEntry.key_;
    // for interval-based signpost tracing, we begin the interval here to be able
    // to measure the time it takes to compile the graphs (if the graph is newly created),
    // and also the time potentially spent on gather/scatter of the graph's input tensors
    profiler.beginProfileKernel(cacheEntry.cachedGraph_->graph(), graphKey, true);
  }
}

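// Registered with the MPS allocator (see REGISTER_MPS_ALLOCATOR_CALLBACK below) so the graph
// cache can be notified of allocator events; the callback body is currently a no-op.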
class MPSGraphCacheCallback : public IMpsAllocatorCallback {
 public:
  MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) {}

  void executeMPSAllocatorCallback(void* ptr, EventType event) override {}

 private:
  MPSGraphCache* graph_cache;
};

REGISTER_MPS_ALLOCATOR_CALLBACK("mps_graph_cache_callback", MPSGraphCacheCallback);

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index) {
  constexpr uint32_t nOffsets = 3;
  uint32_t numThreads = iter.numel();
  const uint32_t nDim = iter.ndim();
  const IntArrayRef& iterShape = iter.shape();
  std::vector<uint32_t> iterShapeData(iterShape.size());
  std::vector<std::array<uint32_t, nOffsets>> strides(nDim);
  TORCH_INTERNAL_ASSERT(iter.ntensors() >= nOffsets);
  TORCH_CHECK(use_64bit_index || iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");

  for (const auto i : c10::irange(iterShape.size())) {
    iterShapeData[i] = static_cast<uint32_t>(iterShape[i]);
  }

  for (const auto i : c10::irange(nDim)) {
    for (const auto offset : c10::irange(nOffsets)) {
      strides[i][offset] = static_cast<uint32_t>(iter.strides(offset)[i]);
    }
  }

  id<MTLComputePipelineState> kernelDataOffsetsPSO = MPSDevice::getInstance()->metalIndexingPSO(
      use_64bit_index ? "kernel_index_offsets_64" : "kernel_index_offsets_32");
  const auto elementSize = use_64bit_index ? sizeof(simd_ulong3) : sizeof(simd_uint3);
  id<MTLBuffer> kernelDataOffsets = (id<MTLBuffer>)getIMPSAllocator()->allocate(numThreads * elementSize).get();

  [commandEncoder setComputePipelineState:kernelDataOffsetsPSO];
  [commandEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
  [commandEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
  [commandEncoder setBytes:iterShapeData.data() length:sizeof(uint32_t) * iterShape.size() atIndex:2];
  [commandEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];

  mtl_dispatch1DJob(commandEncoder, kernelDataOffsetsPSO, numThreads);

  return kernelDataOffsets;
}

id<MTLLibrary> MetalShaderLibrary::getLibrary() {
  if (C10_UNLIKELY(!library)) {
    TORCH_INTERNAL_ASSERT(nparams == 0);
    library = compileLibrary(shaderSource);
  }
  return library;
}

id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::string>& params) {
  TORCH_INTERNAL_ASSERT(nparams == params.size());
  std::string key = "";
  for (auto p : params) {
    key += ":" + p;
  }
  auto lib = libMap[key];
  if (lib) {
    return lib;
  }
  auto it = params.begin();
  switch (nparams) {
    case 1:
      lib = compileLibrary(fmt::format(shaderSource, *it));
      break;
    case 2: {
      auto& first = *it++;
      auto& second = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second));
      break;
    }
    case 3: {
      auto& first = *it++;
      auto& second = *it++;
      auto& third = *it;
      lib = compileLibrary(fmt::format(shaderSource, first, second, third));
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unsupported number of parameters ", nparams);
  }
  return libMap[key] = lib;
}

id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
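  // Fast math is enabled only when the PYTORCH_MPS_FAST_MATH environment variable is set to a
  // non-zero value; the Metal language version is chosen based on the macOS version.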
  static const char* fast_math = std::getenv("PYTORCH_MPS_FAST_MATH");
  NSError* error = nil;
  MTLCompileOptions* options = compile_options;
  if (!options) {
    options = [[MTLCompileOptions new] autorelease];
    [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
                                                                                        : MTLLanguageVersion2_3];
    [options setFastMathEnabled:(!fast_math || std::stoi(fast_math) == 0) ? NO : YES];
  }

  const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
  auto device = MPSDevice::getInstance()->device();
  library = [device newLibraryWithSource:str options:options error:&error];
  TORCH_CHECK(library, "Failed to create Metal library, error: ", [[error description] UTF8String]);
  return library;
}

std::pair<id<MTLComputePipelineState>, id<MTLFunction>> MetalShaderLibrary::getLibraryPipelineState(
    id<MTLLibrary> lib,
    const std::string& fname) {
  const auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
  auto found_cpl = cplMap.find(key);
  if (found_cpl != cplMap.end()) {
    return found_cpl->second;
  }

  NSError* error = nil;
  id<MTLFunction> func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
  TORCH_CHECK(func, "Failed to create function state object for: ", fname);
  auto cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
  TORCH_CHECK(cpl, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);

  cplMap[key] = std::make_pair(cpl, func);
  return cplMap[key];
}

} // namespace at::native::mps