xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/MLModel_Prewarm.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2// MLModel+Prewarm.mm
3//
4// Copyright © 2024 Apple Inc. All rights reserved.
5//
6// Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8#import <MLModel_Prewarm.h>
9
10#import <algorithm>
11
/// Helpers for building placeholder multi-arrays used when prewarming a model.
@interface MLMultiArray (Prewarm)

/// Creates a multi-array with the requested shape and element type whose backing bytes are all zero.
///
/// @param shape The dimensions of the array to allocate.
/// @param dataType The element type of the array.
/// @param outError Receives the allocation error if the array cannot be created.
/// @return A zero-filled multi-array, or nil if allocation failed.
+ (nullable MLMultiArray *)zeroedMultiArrayWithShape:(NSArray<NSNumber *> *)shape
                                            dataType:(MLMultiArrayDataType)dataType
                                               error:(NSError * __autoreleasing *)outError;

@end
19
20
@implementation MLMultiArray (Prewarm)

/// Allocates an `MLMultiArray` of `shape` / `dataType` and overwrites every byte of its
/// backing storage with zero.
///
/// @param shape The dimensions of the array to allocate.
/// @param dataType The element type of the array.
/// @param error Receives the error from `-initWithShape:dataType:error:` when allocation fails.
/// @return The zero-filled array, or nil if allocation failed.
+ (MLMultiArray *)zeroedMultiArrayWithShape:(NSArray<NSNumber *> *)shape
                                   dataType:(MLMultiArrayDataType)dataType
                                      error:(NSError * __autoreleasing *)error {
    MLMultiArray *result = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:error];
    if (result == nil) {
        // The initializer has already reported the failure through `error`.
        return nil;
    }

    // Clear the entire buffer byte-by-byte; the stride layout is irrelevant when every byte is zeroed.
    [result getMutableBytesWithHandler:^(void *bytes, NSInteger byteCount, NSArray<NSNumber *> * __unused strides) {
        std::fill_n(static_cast<uint8_t *>(bytes), byteCount, uint8_t{0});
    }];

    return result;
}

@end
41
namespace {

/// Builds a feature provider that supplies a zero-filled value for every input of `model`.
///
/// Only `MLFeatureTypeMultiArray` inputs are supported. Each one is created with the shape and
/// data type reported by the input's `multiArrayConstraint`. Any other input type makes the
/// whole call fail.
///
/// @param model The model whose input descriptions are enumerated.
/// @param error On failure, receives an error describing why the inputs could not be built.
/// @return A feature provider keyed by input name, or nil on failure.
id<MLFeatureProvider> _Nullable get_zeroed_inputs(MLModel *model, NSError * __autoreleasing *error) {
    NSDictionary<NSString *, MLFeatureDescription *> *input_descs = model.modelDescription.inputDescriptionsByName;
    NSMutableDictionary<NSString *, MLFeatureValue *> *inputs =
        [NSMutableDictionary dictionaryWithCapacity:input_descs.count];

    for (MLFeatureDescription *feature_desc in input_descs.allValues) {
        switch (feature_desc.type) {
            case MLFeatureTypeMultiArray: {
                MLMultiArrayConstraint *constraint = feature_desc.multiArrayConstraint;
                MLMultiArray *array = [MLMultiArray zeroedMultiArrayWithShape:constraint.shape
                                                                     dataType:constraint.dataType
                                                                        error:error];
                // On allocation failure `error` has already been populated by the factory.
                MLFeatureValue *feature = (array != nil) ? [MLFeatureValue featureValueWithMultiArray:array] : nil;
                if (!feature) {
                    return nil;
                }
                inputs[feature_desc.name] = feature;
                break;
            }

            default: {
                // Report why building the inputs failed, so callers do not receive a failure
                // with an unset error.
                if (error) {
                    NSString *description =
                        [NSString stringWithFormat:@"Cannot prewarm model: input \"%@\" has unsupported feature type %ld; only multi-array inputs are supported.",
                                                   feature_desc.name,
                                                   (long)feature_desc.type];
                    *error = [NSError errorWithDomain:MLModelErrorDomain
                                                 code:MLModelErrorFeatureType
                                             userInfo:@{NSLocalizedDescriptionKey : description}];
                }
                return nil;
            }
        }
    }

    return [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:error];
}

} //namespace
71
@implementation MLModel (Prewarm)

/// Performs a throwaway prediction on zero-filled inputs; the outputs are discarded and only
/// their presence is checked.
///
/// A stateful prediction with `state` is used only when all of the following hold:
/// - `state` is non-nil.
/// - The build defines `MODEL_STATE_IS_SUPPORTED`.
/// - The OS provides `-predictionFromFeatures:usingState:error:`.
///
/// Otherwise the stateless `-predictionFromFeatures:error:` is used.
///
/// @param state An `MLState` to predict with, or nil for a stateless prediction.
/// @param error On failure, receives the error reported while building the inputs or predicting.
/// @return YES if the prediction produced outputs, NO otherwise.
- (BOOL)prewarmUsingState:(nullable id)state error:(NSError * __autoreleasing *)error {
    // Objects autoreleased inside the pool below are released when the pool drains. That
    // includes any NSError written through an `__autoreleasing` out-parameter. Forwarding the
    // caller's `error` pointer into the pool would therefore hand the caller a dangling object.
    // Instead, errors are retained in a strong local and published once the pool has drained.
    NSError *localError = nil;
    BOOL succeeded = NO;

    @autoreleasepool {
        id<MLFeatureProvider> inputs = ::get_zeroed_inputs(self, &localError);
        if (inputs != nil) {
            id<MLFeatureProvider> outputs = nil;
            BOOL didPredict = NO;
            if (state != nil) {
#if MODEL_STATE_IS_SUPPORTED
                if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
                    outputs = [self predictionFromFeatures:inputs usingState:(MLState *)state error:&localError];
                    didPredict = YES;
                }
#endif
            }

            // Fall back to a stateless prediction when no stateful prediction was attempted.
            if (!didPredict) {
                outputs = [self predictionFromFeatures:inputs error:&localError];
            }
            succeeded = (outputs != nil);
        }
    }

    if (!succeeded && error != NULL) {
        *error = localError;
    }
    return succeeded;
}

@end
99