xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ReduceOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
namespace executorch {
namespace backends {
namespace mps {
namespace delegate {


/**
 * Lowers a serialized MPSMean node into an MPSGraph mean reduction.
 *
 * The reduction axes are read from the node's `dims` vector (its first
 * `num_dims` entries). Negative axes are wrapped by adding the input rank,
 * because MPSGraph wants non-negative axes. When `keep_dims` is false, the
 * reduced axes are squeezed out of the result of `meanOfTensor:axes:name:`.
 * The resulting tensor is registered under the node's `output_id`.
 *
 * @param nodePtr Serialized graph node; this function reads it through
 *        `mpsnode_union_as_MPSMean()`.
 * @return Error::Ok on success. Returns Error::InvalidArgument when the
 *         serialized dims vector is missing or shorter than `num_dims`, or
 *         when an axis lies outside [0, rank) after wrapping (checked only
 *         when the input has a non-zero rank).
 */
Error
MPSGraphBuilder::mpsMeanOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSMean();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());

  // Number of dimensions of the input; 0 for scalars, and also when the
  // shape is nil (messaging nil yields 0).
  const NSInteger inputRank = static_cast<NSInteger>([inputTensor.shape count]);

  const auto* dims = graphNode->dims();
  const int64_t numDims = graphNode->num_dims();

  // num_dims and the dims vector are serialized as separate fields. Refuse to
  // index past the end of the vector if they disagree.
  if (numDims > 0 &&
      (dims == nullptr || static_cast<int64_t>(dims->size()) < numDims)) {
    ET_LOG(
      Error, "%s: expected %lld reduction dims but the serialized vector holds %lld",
      __FUNCTION__,
      static_cast<long long>(numDims),
      static_cast<long long>(dims == nullptr ? 0 : dims->size())
    );
    return Error::InvalidArgument;
  }

  NSMutableArray<NSNumber*>* axes = [NSMutableArray array];
  for (int64_t i = 0; i < numDims; i++) {
    const int32_t requestedDim = dims->Get(i);

    // MPSGraph wants negative axes to be converted to positive.
    const NSInteger axis = requestedDim < 0 ? requestedDim + inputRank : requestedDim;

    // With a known rank, an axis outside [0, rank) would produce an invalid
    // graph, so report it instead. Rank-0 inputs have nothing to validate
    // against, and their axes are forwarded unchanged.
    if (inputRank > 0 && (axis < 0 || axis >= inputRank)) {
      ET_LOG(
        Error, "%s: dim %d is out of range for an input of rank %ld",
        __FUNCTION__,
        requestedDim,
        static_cast<long>(inputRank)
      );
      return Error::InvalidArgument;
    }

    [axes addObject:@(axis)];
  }

  // Order the axes from highest index to lowest, regardless of how they were
  // serialized. The order does not affect the mean. Descending order also stays
  // correct for the squeeze below even if axes were removed one at a time.
  [axes sortUsingComparator:^NSComparisonResult(NSNumber* lhs, NSNumber* rhs) {
    return [rhs compare:lhs];
  }];

  MPSGraphTensor* meanTensor = [_mpsGraph meanOfTensor:inputTensor
                                                  axes:axes
                                                  name:@"Mean"];
  if (!graphNode->keep_dims()) {
    // Remove the reduced axes, which the mean op leaves in the result.
    meanTensor = [_mpsGraph squeezeTensor:meanTensor
                                     axes:axes
                                     name:@"Mean/squeezed"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = meanTensor;
  return Error::Ok;
}


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch
60