xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h>
2
3#import <CoreML/CoreML.h>
4
@implementation PTMCoreMLExecutor {
  // Model input names in positional order: the n-th tensor produced by
  // -setInputs: is bound to _featureNames[n]. Owned copy (see init).
  NSArray<NSString *> *_featureNames;
  // Reused across calls; cleared and refilled by -setInputs:.
  PTMCoreMLFeatureProvider *_inputProvider;
}

/// Creates an executor whose input tensors are bound, in order, to `featureNames`.
/// @param featureNames Model input names, in the order tensors will be supplied
///        to -setInputs:. The array is copied.
- (instancetype)initWithFeatureNames:(NSArray<NSString *> *)featureNames {
  self = [super init];
  if (self) {
    // Defensive copy: if the caller passed a mutable array and mutated it later,
    // the positional lookup in -setInputs: would drift away from the name set
    // the provider was built with below.
    _featureNames = [featureNames copy];
    NSSet<NSString *> *featureNamesSet = [NSSet setWithArray:_featureNames];
    _inputProvider = [[PTMCoreMLFeatureProvider alloc] initWithFeatureNames:featureNamesSet];
  }
  return self;
}

/// Replaces the tensors held by the input provider with `inputs`.
///
/// Inputs are flattened one level: a tuple contributes each of its elements in
/// order, any other value contributes itself. Each resulting value is converted
/// with `toTensor()` (nested tuples are not unpacked). The n-th tensor is bound
/// to the n-th feature name given to -initWithFeatureNames:.
/// Raises NSRangeException if there are more tensors than feature names.
- (void)setInputs:(c10::impl::GenericList)inputs {
  [_inputProvider clearInputTensors];

  NSUInteger featureIndex = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    at::IValue val = inputs.get(i);
    if (val.isTuple()) {
      for (const auto& element : val.toTupleRef().elements()) {
        [self setInputTensor:element.toTensor() forFeatureAtIndex:featureIndex];
        ++featureIndex;
      }
    } else {
      [self setInputTensor:val.toTensor() forFeatureAtIndex:featureIndex];
      ++featureIndex;
    }
  }
}

/// Binds `tensor` to the feature name at `index` on the input provider.
/// Raises NSRangeException with an explanatory message (rather than the generic
/// NSArray bounds message) when more tensors were supplied than feature names.
- (void)setInputTensor:(const at::Tensor &)tensor forFeatureAtIndex:(NSUInteger)index {
  if (index >= _featureNames.count) {
    [NSException raise:NSRangeException
                format:@"PTMCoreMLExecutor received more input tensors than the %lu "
                       @"feature names it was initialized with",
                       (unsigned long)_featureNames.count];
  }
  [_inputProvider setInputTensor:tensor forFeatureName:_featureNames[index]];
}

/// Runs prediction with `self.model` (presumably an MLModel declared in the
/// header) on the tensors most recently passed to -setInputs:.
/// @param error Forwarded unchanged to -predictionFromFeatures:error:.
/// @return That call's result. If `model` is nil the message send returns nil
///         without writing `*error`.
- (id<MLFeatureProvider>)forward:(NSError **)error {
  return [self.model predictionFromFeatures:_inputProvider error:error];
}

@end
43