1//
2//  ETCoreMLDefaultModelExecutor.m
3//  executorchcoreml_tests
4//
5//  Created by Gyan Sinha on 2/25/24.
6//
7
8#import <ETCoreMLAsset.h>
9#import <ETCoreMLDefaultModelExecutor.h>
10#import <ETCoreMLLogging.h>
11#import <ETCoreMLModel.h>
12
@implementation ETCoreMLDefaultModelExecutor

/// Creates an executor bound to a single Core ML model wrapper.
///
/// @param model The model whose predictions this executor will run.
/// @return The initialized executor, or nil if `-[NSObject init]` fails.
- (instancetype)initWithModel:(ETCoreMLModel *)model {
    if ((self = [super init])) {
        _model = model;
    }
    return self;
}

/// Runs a prediction on `self.model` and returns the outputs as multi-arrays.
///
/// The returned array is ordered exactly like `self.model.orderedOutputNames`.
/// `loggingOptions` and `eventLogger` are accepted for interface compatibility but are not
/// used by this executor.
///
/// @param inputs The feature provider passed to the model's prediction call.
/// @param predictionOptions Options forwarded to the prediction call. When
///        `ignoreOutputBackings` is set, its `outputBackings` are replaced with an empty
///        dictionary first. NOTE: this mutates the caller-supplied object in place.
/// @param error Forwarded to the model's prediction call. It is also set (code
///        `ETCoreMLErrorBrokenModel`) when an expected output is not a multi-array.
/// @return One `MLMultiArray` per ordered output name, or nil on failure.
- (nullable NSArray<MLMultiArray *> *)executeModelWithInputs:(id<MLFeatureProvider>)inputs
                                           predictionOptions:(MLPredictionOptions *)predictionOptions
                                              loggingOptions:(const executorchcoreml::ModelLoggingOptions& __unused)loggingOptions
                                                 eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable __unused)eventLogger
                                                       error:(NSError * __autoreleasing *)error {
    // Dropping the backings lets Core ML allocate its own output buffers.
    if (self.ignoreOutputBackings) {
        predictionOptions.outputBackings = @{};
    }

    id<MLFeatureProvider> predictedFeatures = [self.model predictionFromFeatures:inputs
                                                                         options:predictionOptions
                                                                           error:error];
    // The prediction call has already populated `error` on failure.
    if (predictedFeatures == nil) {
        return nil;
    }

    NSOrderedSet<NSString *> *outputNames = self.model.orderedOutputNames;
    NSMutableArray<MLMultiArray *> *multiArrays = [[NSMutableArray alloc] initWithCapacity:outputNames.count];
    for (NSString *name in outputNames) {
        // A missing feature value yields nil here too (message to nil), which is treated as broken.
        MLMultiArray *multiArray = [predictedFeatures featureValueForName:name].multiArrayValue;
        if (multiArray == nil) {
            ETCoreMLLogErrorAndSetNSError(error,
                                          ETCoreMLErrorBrokenModel,
                                          "%@: Model is broken, expected multiarray for output=%@.",
                                          NSStringFromClass(self.class),
                                          name);
            return nil;
        }
        [multiArrays addObject:multiArray];
    }

    return multiArrays;
}

@end
60