//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include "MPSDevice.h"
#include <executorch/runtime/platform/assert.h>
#include <memory>
#include <mutex>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

using executorch::runtime::Error;

// Process-wide singleton, created lazily by MPSDevice::getInstance().
static std::unique_ptr<MPSDevice> mps_device;
static std::once_flag mpsdev_init;

// Picks the Metal Shading Language version used to compile in-source kernels.
//
// @param device      The Metal device the library will be compiled for.
// @param macOS13Plus Whether the MACOS_VER_13_0_PLUS feature probe succeeded.
// @return MTLLanguageVersion3_0 when macOS13Plus is set and the SDK defines
//         __MAC_13_0, otherwise MTLLanguageVersion2_3.
// Aborts (ET_CHECK_MSG) if the device does not report MTLGPUFamilyMac2.
// NOTE(review): this check presumably fails on iOS GPUs, which would make
// compileLibrary() abort there -- confirm callers only reach it on macOS.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+)
  MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
#if defined(__MAC_13_0)
  if (macOS13Plus) {
    languageVersion = MTLLanguageVersion3_0;
  }
#endif

  ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
  return languageVersion;
}

// Releases the device reference retained by the constructor.
// NOTE(review): libraries stored in _m_library_cache come from
// `newLibraryWithSource:` (+1) and are not released here -- confirm whether the
// cache's value type manages that ownership.
MPSDevice::~MPSDevice() {
  [_mtl_device release];
  _mtl_device = nil;
}

// Selects the Metal device used by the delegate and retains it.
// iOS: the system default device. macOS: the first device that is not
// low-power. Aborts (ET_CHECK) if no suitable device is found.
MPSDevice::MPSDevice() : _mtl_device(nil) {
  @autoreleasepool {
#if TARGET_OS_IPHONE
    // Create rule: the returned device is already +1; released in ~MPSDevice.
    _mtl_device = MTLCreateSystemDefaultDevice();
#else
    // Copy rule: the array itself is returned +1 and must be released below;
    // the autorelease pool does not reclaim it.
    NSArray<id<MTLDevice>>* devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
      if (![device isLowPower]) { // exclude Intel GPUs
        // Take our own reference before the array is released.
        _mtl_device = [device retain];
        break;
      }
    }
    [devices release];
#endif
  }
  // MPS TODO: Replace with `ET_CHECK_OR_RETURN_ERROR` and propagate back the error.
  ET_CHECK(_mtl_device != nil);
}

// Returns the lazily-created process-wide MPSDevice; creation runs exactly once
// via std::call_once.
MPSDevice* MPSDevice::getInstance() {
  std::call_once(mpsdev_init, [] {
    mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
  });
  return mps_device.get();
}

// Reports whether the running OS provides the API level named by `version`.
// Detection probes for selectors introduced in each release (MPSGraph methods,
// plus one MTLCompileOptions property for 13.3). Every probe is a function-local
// static, so it is evaluated once, on the first call.
// Unknown enum values return false.
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
  // Messaging nil (MPSGraph unavailable) yields NO, so all probes report false.
  static Class mpsGraphClass = NSClassFromString(@"MPSGraph");

  static const bool macos_13_0_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)];
  static const bool macos_13_1_plus = [mpsGraphClass
      instancesRespondToSelector:@selector(sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)];
  static const bool macos_13_2_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)];
  static const bool macos_13_3_plus = [] {
    // Probe on a short-lived instance that is released right away, rather than
    // keeping an autoreleased object in a static that would later dangle.
    MTLCompileOptions* probeOptions = [[MTLCompileOptions alloc] init];
    const bool responds = [probeOptions respondsToSelector:@selector(maxTotalThreadsPerThreadgroup)];
    [probeOptions release];
    return responds;
  }();
  static const bool macos_14_0_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(conjugateWithTensor:name:)];
  static const bool macos_15_0_plus = [mpsGraphClass
      instancesRespondToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)];

  switch (version) {
    case MacOSVersion::MACOS_VER_13_0_PLUS:
      return macos_13_0_plus;
    case MacOSVersion::MACOS_VER_13_1_PLUS:
      return macos_13_1_plus;
    case MacOSVersion::MACOS_VER_13_2_PLUS:
      return macos_13_2_plus;
    case MacOSVersion::MACOS_VER_13_3_PLUS:
      return macos_13_3_plus;
    case MacOSVersion::MACOS_VER_14_0_PLUS:
      return macos_14_0_plus;
    case MacOSVersion::MACOS_VER_15_0_PLUS:
      return macos_15_0_plus;
    default:
      return false;
  }
}

// Returns the Metal source text for `libraryType`. The INDEXING_KERNELS source
// is currently a "TODO" placeholder. Aborts on an unhandled library type.
const char* getLibraryCString(LibraryType libraryType) {
  switch (libraryType) {
    case LibraryType::INDEXING_KERNELS:
      return "TODO";
    default:
      ET_CHECK_MSG(false, "Unhandled library type!");
  }
  // Unreachable: ET_CHECK_MSG(false, ...) aborts. Present so that every control
  // path visibly returns a value.
  return nullptr;
}

// Compiles the Metal library for `libraryType` from source and stores it in
// _m_library_cache.
// @return Error::Ok on success, Error::Internal if Metal rejects the source.
Error MPSDevice::compileLibrary(LibraryType libraryType) {
  NSError* error = nil;
  // `new` hands back a +1 reference under manual retain/release; it is released
  // as soon as the library has been built, before any early return.
  MTLCompileOptions* options = [MTLCompileOptions new];
  [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
  [options setFastMathEnabled:YES];
  NSString* source = [NSString stringWithCString:getLibraryCString(libraryType)
                                        encoding:NSASCIIStringEncoding];
  id<MTLLibrary> lib = [_mtl_device newLibraryWithSource:source
                                                 options:options
                                                   error:&error];
  [options release];

  ET_CHECK_OR_RETURN_ERROR(
      lib != nil,
      Internal,
      "Failed to create indexing library, error: %s",
      // Never hand NULL to %s if Metal failed without populating `error`.
      error != nil ? [[error description] UTF8String] : "unknown error");

  // `newLibraryWithSource:` returns a +1 reference, which the cache now holds.
  _m_library_cache[libraryType] = lib;
  return Error::Ok;
}

// Ensures the library for `libraryType` is compiled, then looks up
// `kernelName` in _m_pso_cache.
// Pipeline-state compilation itself is not implemented yet: a cache miss is
// only logged.
// @return Error::Ok, or Error::Internal if the library failed to compile.
Error MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
  Error err = Error::Ok;
  if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
    // Cast: a scoped enum must not be passed to a %d varargs format unconverted.
    ET_LOG(Debug, "Compiling library type: %d", static_cast<int>(libraryType));
    err = compileLibrary(libraryType);
    ET_CHECK_OR_RETURN_ERROR(
        err == Error::Ok,
        Internal,
        "An error occurred while compiling library %d",
        static_cast<int>(libraryType));
  }
  if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
    ET_LOG(Debug, "Compiling kernel: %s", kernelName);
    // MPS TODO: build the pipeline state for `kernelName` from the cached
    // library and insert it into _m_pso_cache. (Calling compilePSO() here, as a
    // previous placeholder suggested, would recurse forever.)
  }
  return err;
}

// Free-function convenience wrapper around the singleton's OS-version probe.
bool is_macos_13_or_newer(MacOSVersion version) {
  return MPSDevice::getInstance()->isMacOS13Plus(version);
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch