xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSDevice.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2023 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6#include "MPSDevice.h"
7#include <executorch/runtime/platform/assert.h>
8#include <memory>
9#include <mutex>
10
11namespace executorch {
12namespace backends {
13namespace mps {
14namespace delegate {
15
16using executorch::runtime::Error;
17
// Storage for the process-wide MPSDevice; filled in lazily by MPSDevice::getInstance().
static std::unique_ptr<MPSDevice> mps_device = nullptr;
// One-time-initialization flag passed to std::call_once in MPSDevice::getInstance().
static std::once_flag mpsdev_init{};
20
// Chooses the Metal Shading Language version used when compiling kernel sources.
//
// Requirements noted by the original authors: MPS Advanced Indexing needs at least
// Metal 2.0 (argument buffers and function constants); the host_name attribute needs
// Metal 2.2 and ulong needs Metal 2.3 (supported on macOS 11+).
//
// @param device       Metal device the library will be compiled for.
// @param macOS13Plus  Caller-supplied flag; when true (and the SDK defines __MAC_13_0)
//                     Metal 3.0 is selected instead of the 2.3 baseline.
// @return MTLLanguageVersion3_0 or MTLLanguageVersion2_3.
// Aborts (via ET_CHECK_MSG) if the device does not report MTLGPUFamilyMac2 support.
// NOTE(review): this family check is not conditioned on TARGET_OS_IPHONE — confirm
// that iOS devices are expected to pass it.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
  // Validate the device up front; selecting a version has no side effects, so
  // checking first does not change what the function returns.
  ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");

#if defined(__MAC_13_0)
  if (macOS13Plus) {
    return MTLLanguageVersion3_0;
  }
#endif

  // Baseline version for SDKs/OS releases older than macOS 13.
  return MTLLanguageVersion2_3;
}
34
// Drops the reference to the Metal device acquired in the constructor. This file
// manages Objective-C object lifetimes manually (explicit retain/release).
MPSDevice::~MPSDevice() {
  id<MTLDevice> ownedDevice = _mtl_device;
  _mtl_device = nil;
  [ownedDevice release];
}
39
// Selects the Metal device used by the MPS backend and takes ownership of a
// reference to it (released in ~MPSDevice; this file uses manual retain/release).
//
// - TARGET_OS_IPHONE builds use MTLCreateSystemDefaultDevice(), which returns an
//   owned (+1) reference.
// - Other builds pick the first device that does not report isLowPower, so that
//   integrated Intel GPUs are excluded.
//
// Aborts (via ET_CHECK) if no suitable device was found.
MPSDevice::MPSDevice(): _mtl_device(nil) {
  @autoreleasepool {
#if TARGET_OS_IPHONE
    _mtl_device = MTLCreateSystemDefaultDevice();
#else
    // MTLCopyAllDevices follows the Copy rule and returns an owned (+1) array.
    // Autorelease it so the enclosing pool frees it once selection is done;
    // without this, the array leaked.
    NSArray<id<MTLDevice>>* devices = [MTLCopyAllDevices() autorelease];
    for (id<MTLDevice> device in devices) {
      if (![device isLowPower]) { // exclude Intel GPUs
        // Retain: the array (and its references) go away when the pool drains.
        _mtl_device = [device retain];
        break;
      }
    }
#endif
  }
  // MPS TODO: Replace with `ET_CHECK_OR_RETURN_ERROR` and propagate back the error.
  ET_CHECK(_mtl_device != nil);
}
58
// Returns the lazily-created, process-wide MPSDevice.
// std::call_once ensures the constructor runs exactly once, even if several threads
// make the first call concurrently. The returned pointer is owned by `mps_device`.
MPSDevice* MPSDevice::getInstance() {
  std::call_once(mpsdev_init, []() { mps_device.reset(new MPSDevice()); });
  return mps_device.get();
}
65
// Reports whether the running OS is at least the release named by `version`.
//
// Detection does not parse an OS version. Instead, each release is detected by
// probing for a selector associated with it (per the flag names below): on
// MPSGraph for most releases, and on MTLCompileOptions for 13.3.
//
// Every probe is stored in a function-local static, so it runs once, on the first
// call. Initialization of such statics is thread-safe since C++11.
//
// @param version  Minimum OS release to test for.
// @return true if the corresponding probe succeeded; false for unknown values.
bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
  // Messaging a nil class (MPSGraph unavailable) yields NO, so every MPSGraph probe
  // is false in that case.
  static id mpsGraphClass = NSClassFromString(@"MPSGraph");

  static const bool macos_13_0_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)];
  static const bool macos_13_1_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)];
  static const bool macos_13_2_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)];
  // The probe instance is created and released locally. Previously it was an
  // autoreleased object stored in a static, which dangled after the pool drained.
  static const bool macos_13_3_plus = [] {
    MTLCompileOptions* probeOptions = [[MTLCompileOptions alloc] init];
    const bool supported = [probeOptions respondsToSelector:@selector(maxTotalThreadsPerThreadgroup)];
    [probeOptions release];
    return supported;
  }();
  static const bool macos_14_0_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(conjugateWithTensor:name:)];
  static const bool macos_15_0_plus =
      [mpsGraphClass instancesRespondToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)];

  switch (version) {
    case MacOSVersion::MACOS_VER_13_0_PLUS:
      return macos_13_0_plus;
    case MacOSVersion::MACOS_VER_13_1_PLUS:
      return macos_13_1_plus;
    case MacOSVersion::MACOS_VER_13_2_PLUS:
      return macos_13_2_plus;
    case MacOSVersion::MACOS_VER_13_3_PLUS:
      return macos_13_3_plus;
    case MacOSVersion::MACOS_VER_14_0_PLUS:
      return macos_14_0_plus;
    case MacOSVersion::MACOS_VER_15_0_PLUS:
      return macos_15_0_plus;
    default:
      return false;
  }
}
99
// Returns the Metal source text for `libraryType`.
// Only LibraryType::INDEXING_KERNELS is handled, and its source is still the
// placeholder "TODO". Any other value aborts via ET_CHECK_MSG.
const char* getLibraryCString(LibraryType libraryType) {
  if (libraryType != LibraryType::INDEXING_KERNELS) {
    ET_CHECK_MSG(false, "Unhandled library type!");
  }
  return "TODO";
}
108
// Compiles the Metal source for `libraryType` and stores the resulting library in
// `_m_library_cache`.
//
// The returned library is an owned (+1) reference from newLibraryWithSource:, and
// that reference is handed to the cache.
//
// @param libraryType  Which kernel library to build (see getLibraryCString).
// @return Error::Ok on success; Error::Internal if Metal fails to compile the source.
// May abort via getMetalLanguageVersion if the device lacks MTLGPUFamilyMac2 support.
Error
MPSDevice::compileLibrary(LibraryType libraryType) {
  NSError* error = nil;

  // `new` returns an owned (+1) object. It is released explicitly below because
  // this file uses manual retain/release; previously it leaked.
  MTLCompileOptions* options = [MTLCompileOptions new];
  [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
  [options setFastMathEnabled:YES];

  NSString* source = [NSString stringWithCString:getLibraryCString(libraryType)
                                        encoding:NSASCIIStringEncoding];
  id<MTLLibrary> lib = [_mtl_device newLibraryWithSource:source
                                                 options:options
                                                   error:&error];
  [options release];

  // Guard against a nil NSError so that %s never receives a NULL pointer.
  ET_CHECK_OR_RETURN_ERROR(
    lib != nil,
    Internal,
    "Failed to create indexing library, error: %s",
    error != nil ? [[error description] UTF8String] : "unknown error"
  );

  _m_library_cache[libraryType] = lib;
  return Error::Ok;
}
131
// Ensures the library for `libraryType` is compiled and cached. Intended to also
// build the pipeline state for `kernelName`, which is not implemented yet.
//
// @param libraryType  Library that contains the kernel.
// @param kernelName   Name of the kernel function.
// @return Error::Ok, or Error::Internal if compiling the library failed.
Error
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
  if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
    // Cast explicitly: the enum must match the %d conversion when passed through varargs.
    ET_LOG(Debug, "Compiling library type: %d", static_cast<int>(libraryType));
    const Error err = compileLibrary(libraryType);
    ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok,
      Internal,
      "An error occurred while compiling library %d", static_cast<int>(libraryType)
    );
  }
  if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
    ET_LOG(Debug, "Compiling kernel: %s", kernelName);
    // TODO: build and cache the pipeline state for `kernelName`.
    // Do not simply call compilePSO(libraryType, kernelName) here: with the cache
    // still missing the entry, that self-call would recurse without end.
  }
  return Error::Ok;
}
150
// Free-function convenience wrapper: forwards to MPSDevice::isMacOS13Plus on the
// shared device instance (creating it on first use).
bool is_macos_13_or_newer(MacOSVersion version) {
  MPSDevice* sharedDevice = MPSDevice::getInstance();
  return sharedDevice->isMacOS13Plus(version);
}
154
155} // namespace delegate
156} // namespace mps
157} // namespace backends
158} // namespace executorch
159