xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h>
2
3#if TARGET_OS_IPHONE
4#import <UIKit/UIKit.h>
5#endif
6
@implementation PTMCoreMLCompiler

// Directory that holds compiled models and their OS-version stamps.
// An empty string means "not configured"; +cacheDirectory then falls back to tmp.
static NSString *gCacheDirectory = @"";
// Extension of a compiled model stored in the cache.
static NSString * const gCompiledModelExtension = @"mlmodelc";
// Extension of the sidecar file recording the OS version a model was compiled on.
static NSString * const gVersionExtension = @"version";

// Sets the directory used to cache compiled models.
// @param dir UTF-8 encoded path. It is validated lazily by +cacheDirectory.
+ (void)setCacheDirectory:(const std::string&)dir {
  // stringWithCString:encoding: returns nil when the bytes are not valid UTF-8;
  // keep the "unset" sentinel in that case so the global is never nil.
  NSString *directory = [NSString stringWithCString:dir.c_str() encoding:NSUTF8StringEncoding];
  gCacheDirectory = directory ?: @"";
}

// Returns the directory used to cache compiled models.
// If no directory was configured, or the configured one is not writable, the
// stored directory is replaced by NSTemporaryDirectory() and that is returned.
+ (nonnull NSString *)cacheDirectory {
  BOOL isUsable = gCacheDirectory.length != 0 &&
      [[NSFileManager defaultManager] isWritableFileAtPath:gCacheDirectory];
  if (!isUsable) {
    // set the default directory to tmp
    gCacheDirectory = NSTemporaryDirectory();
  }
  return gCacheDirectory;
}

// Compiles the serialized model specification and stores the result in the cache.
// @param modelSpecs Serialized model bytes (binary, not text).
// @param modelID    UTF-8 identifier used as the file name of the cached model.
// @return YES if a compiled model for `modelID` is available in the cache
//         (either reused or freshly compiled), NO otherwise.
+ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID {
  NSString *modelName = [NSString stringWithCString:modelID.c_str() encoding:NSUTF8StringEncoding];
  if (modelName.length == 0) {
    // A nil name (invalid UTF-8) would raise in the path helpers below, and an
    // empty one would make the "model path" the temporary directory itself.
    return NO;
  }

  NSString *modelPath = [NSTemporaryDirectory() stringByAppendingPathComponent:modelName];
  NSURL *compiledURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  BOOL compiledModelIsCached = [[NSFileManager defaultManager] fileExistsAtPath:compiledURL.path];

#if TARGET_OS_IPHONE
  // A cached model is only reused when it was compiled on the OS version that is
  // currently running. A missing stamp reads as nil, which never matches.
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  NSString *compilationOS = [NSString stringWithContentsOfFile:versionURL.path encoding:NSUTF8StringEncoding error:nil];
  NSString *currentOS = [UIDevice currentDevice].systemVersion;
  BOOL wasCachedOnThisOS = [currentOS isEqualToString:compilationOS];
#else
  // No version stamp is written on this platform, so a cached model is never
  // reused and the model is recompiled on every call.
  BOOL wasCachedOnThisOS = NO;
#endif

  if (compiledModelIsCached && wasCachedOnThisOS) {
    return YES;
  }

  if (!wasCachedOnThisOS) {
    [PTMCoreMLCompiler _cleanupCachedModel:modelName];
  }

  if (![PTMCoreMLCompiler _writeModelSpecs:modelSpecs toPath:modelPath]) {
    return NO;
  }

  return [PTMCoreMLCompiler _compileModel:modelName atPath:modelPath];
}

// Loads the compiled model for `modelID` from the cache.
// @param modelID           UTF-8 identifier the model was compiled under.
// @param backend           "cpuAndGPU" or "all" select those compute units; any
//                          other value selects CPU only. Ignored on OS versions
//                          without MLModelConfiguration.
// @param allowLowPrecision Forwarded to allowLowPrecisionAccumulationOnGPU.
// @param error             Optional out-parameter describing the failure.
// @return The loaded model, or nil on failure. When loading fails the cached
//         files for `modelID` are removed.
+ (nullable MLModel*)loadModel:(const std::string)modelID backend:(const std::string)backend allowLowPrecision:(BOOL)allowLowPrecision error:(NSError**)error {
  NSString *modelName = [NSString stringWithCString:modelID.c_str() encoding:NSUTF8StringEncoding];
  if (modelName.length == 0) {
    // Same precondition as +compileModel:modelID: - the path helpers cannot
    // build a cache URL from a nil or empty name.
    if (error) {
      *error = [NSError errorWithDomain:NSCocoaErrorDomain code:NSFileReadInvalidFileNameError userInfo:nil];
    }
    return nil;
  }

  NSURL *modelURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];

  MLModel *model = nil;
  if (@available(iOS 12.0, macOS 10.14, *)) {
    MLComputeUnits computeUnits = MLComputeUnitsCPUOnly;
    if (backend == "cpuAndGPU") {
      computeUnits = MLComputeUnitsCPUAndGPU;
    } else if (backend == "all") {
      computeUnits = MLComputeUnitsAll;
    }
    MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
    config.computeUnits = computeUnits;
    config.allowLowPrecisionAccumulationOnGPU = allowLowPrecision;
    model = [MLModel modelWithContentsOfURL:modelURL configuration:config error:error];
  } else {
    model = [MLModel modelWithContentsOfURL:modelURL error:error];
  }

  // Judge success by the return value, not by the error pointer: the caller may
  // pass NULL for `error`, or hand in a variable that already holds an older error.
  if (model == nil) {
    [PTMCoreMLCompiler _cleanupCachedModel:modelName];
    return nil;
  }

  return model;
}

// Writes the raw model bytes to `modelPath`.
// @return YES if the file was written.
+ (BOOL)_writeModelSpecs:(const std::string&)modelSpecs toPath:(NSString *)modelPath {
  // Note that the serialized protobuf binary contains bytes not text.
  // https://developers.google.com/protocol-buffers/docs/pythontutorial#parsing-and-serialization
  NSData *specData = [NSData dataWithBytes:modelSpecs.data() length:modelSpecs.length()];
  return [specData writeToFile:modelPath atomically:YES];
}

// Compiles the specification stored at `modelPath` and moves the result into the
// cache under `modelName`. The specification file is deleted in every case.
// @return YES if the compiled model is in the cache afterwards.
+ (BOOL)_compileModel:(NSString *)modelName atPath:(NSString *)modelPath {
  NSFileManager *fileManager = [NSFileManager defaultManager];
  NSURL *modelURL = [NSURL fileURLWithPath:modelPath];
  NSURL *temporaryURL = [MLModel compileModelAtURL:modelURL error:nil];

  // After the compiled model has been created, the original specs can be cleared to save cache space.
  [fileManager removeItemAtPath:modelPath error:nil];

  if (temporaryURL == nil) {
    return NO; // Model could not be compiled
  }

  NSURL *compiledURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  if (![compiledURL isEqual:temporaryURL]) {
    // moveItemAtURL: does not overwrite, so clear any previous entry first.
    [fileManager removeItemAtURL:compiledURL error:nil];
    if (![fileManager moveItemAtURL:temporaryURL toURL:compiledURL error:nil]) {
      // Do not leave the compiler output behind when it cannot be cached.
      [fileManager removeItemAtURL:temporaryURL error:nil];
      return NO; // Model could not be saved in cache
    }
  }

#if TARGET_OS_IPHONE
  // Record the OS version the model was compiled on. The write is best effort:
  // without a matching stamp +compileModel:modelID: simply recompiles.
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  NSString *currentOSVer = [UIDevice currentDevice].systemVersion;
  [currentOSVer writeToFile:versionURL.path atomically:YES encoding:NSUTF8StringEncoding error:NULL];
#endif

  return YES;
}

// Removes the compiled model and its version stamp from the cache.
// Missing files are ignored.
+ (void)_cleanupCachedModel:(NSString *)modelName {
  NSFileManager *fileManager = [NSFileManager defaultManager];
  NSURL *modelURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  [fileManager removeItemAtPath:modelURL.path error:nil];
  [fileManager removeItemAtPath:versionURL.path error:nil];
}

// Builds the file URL `<cacheDirectory>/<modelID>.<pathExtension>`.
+ (NSURL *)_cacheURLForModel:(NSString *)modelID extension:(NSString *)pathExtension {
  NSString *filename = [modelID stringByAppendingPathExtension:pathExtension];
  NSString *filePath = [[PTMCoreMLCompiler cacheDirectory] stringByAppendingPathComponent:filename];
  return [NSURL fileURLWithPath:filePath isDirectory:NO];
}

@end
138