//
// MLModel+Prewarm.mm
//
// Copyright © 2024 Apple Inc. All rights reserved.
//
// Please refer to the license found in the LICENSE file in the root directory of the source tree.

#import <MLModel_Prewarm.h>

#import <algorithm>

@interface MLMultiArray (Prewarm)

/// Creates a multi-array with the given shape and data type whose backing bytes are all zero.
///
/// @param shape The shape of the array to allocate.
/// @param dataType The element type of the array.
/// @param error On failure, receives the error produced by
///        `-[MLMultiArray initWithShape:dataType:error:]`.
/// @return The zero-filled array, or nil if the array could not be created.
+ (nullable MLMultiArray *)zeroedMultiArrayWithShape:(NSArray<NSNumber *> *)shape
                                            dataType:(MLMultiArrayDataType)dataType
                                               error:(NSError * __autoreleasing *)error;

@end


@implementation MLMultiArray (Prewarm)

+ (MLMultiArray *)zeroedMultiArrayWithShape:(NSArray<NSNumber *> *)shape
                                   dataType:(MLMultiArrayDataType)dataType
                                      error:(NSError * __autoreleasing *)error {
    MLMultiArray *multiArray = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:error];
    if (!multiArray) {
        return nil;
    }

    // Do not rely on the freshly allocated buffer's contents; clear every byte explicitly so the
    // array holds zeros regardless of data type.
    [multiArray getMutableBytesWithHandler:^(void *mutableBytes,
                                             NSInteger size,
                                             NSArray<NSNumber *> * __unused strides) {
        uint8_t *start = reinterpret_cast<uint8_t *>(mutableBytes);
        std::fill(start, start + size, uint8_t(0));
    }];

    return multiArray;
}

@end

namespace {

/// Builds a feature provider that supplies an all-zero value for every input of `model`.
///
/// Only multi-array inputs are supported. Each one is created with the shape and data type taken
/// from the input's `multiArrayConstraint`.
///
/// @param model The model whose input descriptions are used.
/// @param error On failure, receives an error describing why the inputs could not be built.
/// @return The feature provider, or nil on failure (with `*error` set when `error` is non-NULL).
id<MLFeatureProvider> _Nullable get_zeroed_inputs(MLModel *model, NSError * __autoreleasing *error) {
    NSMutableDictionary<NSString *, MLFeatureValue *> *inputs = [NSMutableDictionary dictionary];
    for (MLFeatureDescription *feature_desc in model.modelDescription.inputDescriptionsByName.allValues) {
        switch (feature_desc.type) {
            case MLFeatureTypeMultiArray: {
                MLMultiArrayConstraint *constraint = feature_desc.multiArrayConstraint;
                MLMultiArray *array = [MLMultiArray zeroedMultiArrayWithShape:constraint.shape
                                                                     dataType:constraint.dataType
                                                                        error:error];
                if (!array) {
                    return nil;
                }
                inputs[feature_desc.name] = [MLFeatureValue featureValueWithMultiArray:array];
                break;
            }

            default: {
                // Populate the error so callers that check the nil/NO result learn why prewarming
                // could not proceed, instead of receiving a failure with no explanation.
                if (error) {
                    NSString *description =
                        [NSString stringWithFormat:@"Cannot prewarm model: input '%@' has unsupported "
                                                   @"feature type %ld; only multi-array inputs are supported.",
                                                   feature_desc.name,
                                                   (long)feature_desc.type];
                    *error = [NSError errorWithDomain:MLModelErrorDomain
                                                 code:MLModelErrorFeatureType
                                             userInfo:@{NSLocalizedDescriptionKey : description}];
                }
                return nil;
            }
        }
    }

    return [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:error];
}

} // namespace

@implementation MLModel (Prewarm)

/// Runs one prediction on all-zero inputs to prewarm the model.
///
/// When `state` is non-nil and stateful prediction is available (built with
/// `MODEL_STATE_IS_SUPPORTED` and running on macOS 15 / iOS 18 / tvOS 18 / watchOS 11 or later),
/// the prediction uses `state`. Otherwise a stateless prediction is run.
///
/// @param state Optional model state, cast to `MLState` when stateful prediction is used.
/// @param error On failure, receives the error that caused prewarming to fail.
/// @return YES if the prediction produced outputs, NO otherwise.
- (BOOL)prewarmUsingState:(nullable id)state error:(NSError * __autoreleasing *)error {
    // Errors created inside the autorelease pool below are autoreleased into that pool. Writing
    // them directly through the caller's `error` would leave a dangling pointer once the pool
    // drains. Collect them in a strong local that outlives the pool instead; ARC's
    // pass-by-writeback retains the error before the drain.
    NSError *localError = nil;
    BOOL succeeded = NO;

    @autoreleasepool {
        id<MLFeatureProvider> inputs = ::get_zeroed_inputs(self, &localError);
        if (inputs) {
            id<MLFeatureProvider> outputs = nil;
            BOOL usedState = NO;
#if MODEL_STATE_IS_SUPPORTED
            if (state != nil) {
                if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
                    outputs = [self predictionFromFeatures:inputs
                                                usingState:(MLState *)state
                                                     error:&localError];
                    usedState = YES;
                }
            }
#endif
            if (!usedState) {
                outputs = [self predictionFromFeatures:inputs error:&localError];
            }
            succeeded = (outputs != nil);
        }
    }

    // Hand the error to the caller only after the inner pool has drained, and only on failure.
    if (!succeeded && error) {
        *error = localError;
    }
    return succeeded;
}

@end