#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h>

#if TARGET_OS_IPHONE
#import <UIKit/UIKit.h>
#endif

@implementation PTMCoreMLCompiler

// Directory in which compiled models are cached. An empty (or nil) or
// unwritable value makes +cacheDirectory fall back to NSTemporaryDirectory().
static NSString *gCacheDirectory = @"";
// Path extension of the compiled model produced by Core ML.
static NSString *gCompiledModelExtension = @"mlmodelc";
// Path extension of the sidecar file recording the OS version a model was
// compiled on (only written on iOS).
static NSString *gVersionExtension = @"version";

/// Sets the directory used to cache compiled models.
/// @param dir UTF-8 encoded path. If it is empty, not valid UTF-8, or not
///            writable, +cacheDirectory replaces it with the temporary directory.
+ (void)setCacheDirectory:(const std::string&)dir {
  gCacheDirectory = [NSString stringWithCString:dir.c_str() encoding:NSUTF8StringEncoding];
}

/// Returns the cache directory, resetting it to NSTemporaryDirectory() when
/// the configured value is unset or not writable.
+ (nonnull NSString *)cacheDirectory {
  BOOL isSet = gCacheDirectory.length != 0;
  BOOL isWriteable = isSet && [[NSFileManager defaultManager] isWritableFileAtPath:gCacheDirectory];
  if (!isSet || !isWriteable) {
    // set the default directory to tmp
    gCacheDirectory = NSTemporaryDirectory();
  }
  return gCacheDirectory;
}

/// Compiles serialized Core ML model specs and stores the compiled model in
/// the cache directory.
///
/// An already-cached compiled model is reused only if the OS version recorded
/// next to it matches the current OS version. That version is only tracked on
/// iOS; elsewhere the cache is always discarded and the model recompiled.
/// @param modelSpecs Serialized model protobuf (binary bytes, not text).
/// @param modelID Identifier used to name the cached files.
/// @return YES if a compiled model is present in the cache afterwards.
+ (BOOL)compileModel:(const std::string&)modelSpecs modelID:(const std::string&)modelID {
  NSString *modelName = [NSString stringWithCString:modelID.c_str() encoding:NSUTF8StringEncoding];
  if (modelName.length == 0) {
    // Empty or non-UTF-8 IDs cannot be turned into cache file names.
    return NO;
  }
  NSString *modelPath = [NSTemporaryDirectory() stringByAppendingPathComponent:modelName];
  NSURL *compiledURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  BOOL compiledModelIsCached = [[NSFileManager defaultManager] fileExistsAtPath:compiledURL.path];

#if TARGET_OS_IPHONE
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  NSString *compilationOS = [NSString stringWithContentsOfFile:versionURL.path
                                                      encoding:NSUTF8StringEncoding
                                                         error:nil];
  NSString *currentOS = [UIDevice currentDevice].systemVersion;
  BOOL wasCachedOnThisOS = [currentOS isEqualToString:compilationOS];
#else
  // No OS version is recorded on this platform, so the cache is never trusted.
  BOOL wasCachedOnThisOS = NO;
#endif

  if (compiledModelIsCached && wasCachedOnThisOS) {
    return YES;
  }

  if (!wasCachedOnThisOS) {
    [PTMCoreMLCompiler _cleanupCachedModel:modelName];
  }

  BOOL writeSuccess = [PTMCoreMLCompiler _writeModelSpecs:modelSpecs toPath:modelPath];
  if (!writeSuccess) {
    return NO;
  }

  return [PTMCoreMLCompiler _compileModel:modelName atPath:modelPath];
}

/// Loads a previously compiled model from the cache.
/// @param modelID Identifier the model was compiled under.
/// @param backend "cpuAndGPU" or "all" select those compute units; any other
///                value selects CPU only (iOS 12 / macOS 10.14 and later).
/// @param allowLowPrecision Whether low-precision accumulation on GPU is allowed.
/// @param error Receives the failure reason when nil is returned; may be NULL.
/// @return The loaded model, or nil on failure. On failure the cached files
///         for this model are removed.
+ (nullable MLModel*)loadModel:(const std::string)modelID backend:(const std::string)backend allowLowPrecision:(BOOL)allowLowPrecision error:(NSError**)error {
  NSString *modelName = [NSString stringWithCString:modelID.c_str() encoding:NSUTF8StringEncoding];
  if (modelName.length == 0) {
    // Empty or non-UTF-8 IDs cannot name a cached model.
    if (error) {
      *error = [NSError errorWithDomain:NSCocoaErrorDomain
                                   code:NSFileReadInvalidFileNameError
                               userInfo:nil];
    }
    return nil;
  }
  NSURL *modelURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];

  MLModel *model;
  if (@available(iOS 12.0, macOS 10.14, *)) {
    MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
    MLComputeUnits computeUnits = MLComputeUnitsCPUOnly;
    if (backend == "cpuAndGPU") {
      computeUnits = MLComputeUnitsCPUAndGPU;
    } else if (backend == "all") {
      computeUnits = MLComputeUnitsAll;
    }
    config.computeUnits = computeUnits;
    config.allowLowPrecisionAccumulationOnGPU = allowLowPrecision;
    model = [MLModel modelWithContentsOfURL:modelURL configuration:config error:error];
  } else {
    model = [MLModel modelWithContentsOfURL:modelURL error:error];
  }

  // Decide failure from the return value, not *error: the caller may pass
  // NULL, and a stale *error must not turn a successful load into a failure.
  if (!model) {
    // Evict the unusable cache entry so the next compile starts fresh.
    [PTMCoreMLCompiler _cleanupCachedModel:modelName];
    return nil;
  }

  return model;
}

/// Writes the serialized model specs to modelPath.
/// @return YES if the file was written.
+ (BOOL)_writeModelSpecs:(const std::string&)modelSpecs toPath:(NSString *)modelPath {
  // Note that the serialized protobuf binary contains bytes not text.
  // https://developers.google.com/protocol-buffers/docs/pythontutorial#parsing-and-serialization
  NSData* data = [NSData dataWithBytes:modelSpecs.data() length:modelSpecs.length()];
  return [data writeToFile:modelPath atomically:YES];
}

/// Compiles the specs at modelPath and moves the result into the cache.
/// The specs file at modelPath is deleted whether or not compilation succeeds.
/// On iOS, also records the current OS version next to the compiled model.
/// @return YES if the compiled model is now in the cache.
+ (BOOL)_compileModel:(NSString *)modelName atPath:(NSString *)modelPath {
  NSFileManager *fileManager = [NSFileManager defaultManager];
  NSError *error = nil;
  NSURL *modelURL = [NSURL fileURLWithPath:modelPath];
  NSURL *temporaryURL = [MLModel compileModelAtURL:modelURL error:&error];

  // The original specs are no longer needed; remove them to save cache space.
  [fileManager removeItemAtPath:modelPath error:nil];

  if (!temporaryURL) {
    return NO; // Model could not be compiled
  }

  NSURL *compiledURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  if (![compiledURL isEqual:temporaryURL]) {
    [fileManager removeItemAtURL:compiledURL error:nil];
    if (![fileManager moveItemAtURL:temporaryURL toURL:compiledURL error:&error]) {
      // Don't leave the orphaned compiled output behind.
      [fileManager removeItemAtURL:temporaryURL error:nil];
      return NO; // Model could not be saved in cache
    }
  }

#if TARGET_OS_IPHONE
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  NSString *currentOSVer = [UIDevice currentDevice].systemVersion;
  [currentOSVer writeToFile:versionURL.path atomically:YES encoding:NSUTF8StringEncoding error:NULL];
#endif

  return YES;
}

/// Removes the cached compiled model and its version file (best effort).
+ (void)_cleanupCachedModel:(NSString *)modelName {
  NSURL *modelURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gCompiledModelExtension];
  NSURL *versionURL = [PTMCoreMLCompiler _cacheURLForModel:modelName extension:gVersionExtension];
  [[NSFileManager defaultManager] removeItemAtPath:modelURL.path error:nil];
  [[NSFileManager defaultManager] removeItemAtPath:versionURL.path error:nil];
}

/// Builds "<cacheDirectory>/<modelID>.<pathExtension>" as a file URL.
+ (NSURL *)_cacheURLForModel:(NSString *)modelID extension:(NSString *)pathExtension {
  NSString *filename = [modelID stringByAppendingPathExtension:pathExtension];
  NSString *filePath = [[PTMCoreMLCompiler cacheDirectory] stringByAppendingPathComponent:filename];
  return [NSURL fileURLWithPath:filePath isDirectory:NO];
}

@end