xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/ShapeOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
14
// Lowers an MPSPermute node: reorders the axes of the input tensor according
// to the serialized `perm` vector (via the permuteTensor helper) and records
// the result under the node's output id.
//
// @param nodePtr Serialized graph node; its union must hold an MPSPermute.
// @return Error::Ok on success; an Internal error when `perm` is missing or
//         its length disagrees with `num_dims`.
Error
MPSGraphBuilder::mpsPermuteOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSPermute();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  const auto* perm = graphNode->perm();
  const int64_t numDims = static_cast<int64_t>(graphNode->num_dims());
  // `num_dims` and `perm` are serialized as separate fields; make sure the
  // indexed reads below cannot run past the end of `perm`.
  ET_CHECK_OR_RETURN_ERROR(
    perm != nullptr && static_cast<int64_t>(perm->size()) == numDims,
    Internal,
    "permute: num_dims does not match the length of perm.");

  NSMutableArray<NSNumber*>* permutation =
    [NSMutableArray arrayWithCapacity:perm->size()];
  for (int64_t i = 0; i < numDims; i++) {
    [permutation addObject:@(perm->Get(i))];
  }
  MPSGraphTensor* outputTensor = permuteTensor(
    _mpsGraph, getMPSGraphTensor(graphNode->input1_id()), permutation
  );
  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;

  return Error::Ok;
}
36
// Lowers an MPSView node: reshapes the input tensor to the node's serialized
// target shape and stores the reshaped tensor under the output id.
Error
MPSGraphBuilder::mpsViewOp(NodePtr nodePtr) {
  auto viewNode = nodePtr->mpsnode_union_as_MPSView();
  const auto inputId = viewNode->input1_id();
  const auto outputId = viewNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* source = getMPSGraphTensor(inputId);
  MPSShape* targetShape = getMPSShape(viewNode->shape());
  MPSGraphTensor* reshaped = [_mpsGraph reshapeTensor:source
                                            withShape:targetShape
                                                 name:@"view_copy"];
  _idToMPSGraphTensor[outputId] = reshaped;
  return Error::Ok;
}
52
// Lowers an MPSExpand node (torch `expand`) to an MPSGraph broadcast.
//
// The serialized target shape is aligned with the input shape from the
// trailing dimension. It may have extra leading dimensions, and it uses -1 for
// dimensions that keep the input's size.
//
// @param nodePtr Serialized graph node; its union must hold an MPSExpand.
// @return Error::Ok on success; an Internal error when the target shape has
//         fewer dimensions than the input or uses -1 for a new leading
//         dimension.
Error
MPSGraphBuilder::mpsExpandOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSExpand();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSShape* inputShape = inputTensor.shape;
  const auto* expandShape = graphNode->shape();

  const NSInteger inputRank = static_cast<NSInteger>(inputShape.count);
  const NSInteger outputRank =
    expandShape != nullptr ? static_cast<NSInteger>(expandShape->size()) : 0;
  ET_CHECK_OR_RETURN_ERROR(
    outputRank >= inputRank, Internal,
    "expand: target shape has fewer dimensions than the input.");

  // Number of new dimensions prepended in front of the input's dimensions;
  // target dim i corresponds to input dim (i - leadingDims).
  const NSInteger leadingDims = outputRank - inputRank;
  NSMutableArray<NSNumber*>* shape = [NSMutableArray arrayWithCapacity:outputRank];
  for (NSInteger i = 0; i < outputRank; i++) {
    const int64_t expandDimVal = expandShape->Get(i);
    if (expandDimVal == -1) {
      // In torch, -1 is passed for dimensions which are to stay the same
      // size; a new leading dimension has no input size to keep.
      ET_CHECK_OR_RETURN_ERROR(
        i >= leadingDims, Internal,
        "expand: -1 is not allowed for a new leading dimension.");
      [shape addObject:inputShape[i - leadingDims]];
    } else {
      [shape addObject:@(expandDimVal)];
    }
  }

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph broadcastTensor:inputTensor
                       toShape:shape
                          name:@"expand_copy"];

  return Error::Ok;
}
81
// Lowers an MPSCat node: concatenates the node's input tensors along `dim`
// and stores the result under the output id.
//
// Inputs for which getMPSGraphTensor returns nil are skipped.
// NOTE(review): presumably those correspond to empty tensors that torch's cat
// ignores — confirm against getMPSGraphTensor.
//
// @param nodePtr Serialized graph node; its union must hold an MPSCat.
// @return Error::Ok on success; an Internal error when the node has no input
//         id list or none of its inputs resolve to a graph tensor.
Error
MPSGraphBuilder::mpsCatOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSCat();
  ET_LOG(
    Debug, "%s: %d",
    __FUNCTION__, graphNode->output_id()
  );

  const auto* inputIds = graphNode->input_ids();
  ET_CHECK_OR_RETURN_ERROR(
    inputIds != nullptr, Internal, "cat: node has no input ids.");

  NSMutableArray<MPSGraphTensor*>* inputTensors =
    [NSMutableArray arrayWithCapacity:inputIds->size()];
  for (auto id : *inputIds) {
    MPSGraphTensor* catTensor = getMPSGraphTensor(id);
    if (catTensor != nil) {
      [inputTensors addObject:catTensor];
    }
  }
  // concatTensors needs at least one tensor to concatenate.
  ET_CHECK_OR_RETURN_ERROR(
    inputTensors.count > 0, Internal,
    "cat: none of the input tensors could be resolved.");

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph concatTensors:inputTensors
                   dimension:graphNode->dim()
                        name:@"cat"];

  return Error::Ok;
}
103
// Lowers an MPSSqueeze node: removes the serialized `dims` axes from the input
// tensor with MPSGraph's multi-axis squeeze and records the result.
Error
MPSGraphBuilder::mpsSqueezeOp(NodePtr nodePtr) {
  auto squeezeNode = nodePtr->mpsnode_union_as_MPSSqueeze();
  ET_LOG(Debug, "%s: %d", __FUNCTION__, squeezeNode->output_id());

  MPSGraphTensor* source = getMPSGraphTensor(squeezeNode->input1_id());
  NSArray<NSNumber*>* squeezeAxes = getMPSShape(squeezeNode->dims());
  MPSGraphTensor* squeezed = [_mpsGraph squeezeTensor:source
                                                 axes:squeezeAxes
                                                 name:@"squeeze"];
  _idToMPSGraphTensor[squeezeNode->output_id()] = squeezed;
  return Error::Ok;
}
119
// Lowers an MPSUnsqueeze node: inserts a size-1 axis at `dim` in the input
// tensor (MPSGraph expandDims) and records the result.
Error
MPSGraphBuilder::mpsUnsqueezeOp(NodePtr nodePtr) {
  auto unsqueezeNode = nodePtr->mpsnode_union_as_MPSUnsqueeze();
  const auto inputId = unsqueezeNode->input1_id();
  const auto outputId = unsqueezeNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* expanded =
    [_mpsGraph expandDimsOfTensor:getMPSGraphTensor(inputId)
                             axis:unsqueezeNode->dim()
                             name:@"unsqueeze"];
  _idToMPSGraphTensor[outputId] = expanded;
  return Error::Ok;
}
135
// Lowers an MPSSelect node: picks entry `index` along axis `dim`. This is done
// as a length-1 slice followed by squeezing that same axis away.
Error
MPSGraphBuilder::mpsSelectOp(NodePtr nodePtr) {
  auto selectNode = nodePtr->mpsnode_union_as_MPSSelect();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, selectNode->input1_id(), selectNode->output_id()
  );

  const auto axis = selectNode->dim();
  MPSGraphTensor* source = getMPSGraphTensor(selectNode->input1_id());
  MPSGraphTensor* oneWide = [_mpsGraph sliceTensor:source
                                         dimension:axis
                                             start:selectNode->index()
                                            length:1
                                              name:@"slice"];
  MPSGraphTensor* selected = [_mpsGraph squeezeTensor:oneWide
                                                 axis:axis
                                                 name:@"slice/squeezed"];
  _idToMPSGraphTensor[selectNode->output_id()] = selected;
  return Error::Ok;
}
156
// Lowers an MPSPixelShuffle node to MPSGraph's depth-to-space op (pixel
// shuffle ordering). The last three input dims are used as depth, height and
// width, in that order.
//
// @param nodePtr Serialized graph node; its union must hold an MPSPixelShuffle.
// @return Error::Ok on success; an Internal error when the input has fewer
//         than 3 dims, the upscale factor is not positive, or the depth dim is
//         not divisible by upscale_factor^2.
Error
MPSGraphBuilder::mpsPixelShuffleOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSPixelShuffle();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  const NSInteger ndims = static_cast<NSInteger>(inputTensor.shape.count);
  MPSGraphTensor* outputTensor = nil;
  const int32_t upscaleFactor = graphNode->upscale_factor();

  ET_CHECK_OR_RETURN_ERROR(
    ndims >= 3, Internal,  "pixel_shuffle requires tensor with at least 3 dimensions.");
  // A zero factor would divide by zero in the divisibility check below, and a
  // negative factor is not a meaningful block size.
  ET_CHECK_OR_RETURN_ERROR(
    upscaleFactor > 0, Internal, "pixel_shuffle upscale factor must be positive.");
  if (upscaleFactor == 1) {
    // TODO: move this to AOT
    outputTensor = inputTensor;
  } else {
    // Square in 64-bit so large factors cannot overflow int32.
    const int64_t blockArea = static_cast<int64_t>(upscaleFactor) * upscaleFactor;
    ET_CHECK_OR_RETURN_ERROR(
      inputTensor.shape[ndims - 3].longLongValue % blockArea == 0,
      Internal,
      "pixel_shuffle channels must be divisible by upscale factor squared.");

    outputTensor = [_mpsGraph depthToSpace2DTensor:inputTensor
                                         widthAxis:ndims - 1
                                        heightAxis:ndims - 2
                                         depthAxis:ndims - 3
                                         blockSize:upscaleFactor
                              usePixelShuffleOrder:YES
                                              name:@"pixel_shuffle"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;
  return Error::Ok;
}
193
// Lowers an MPSSlice node to an MPSGraph strided slice. Every axis is taken
// whole with stride 1, except axis `dim`, which uses the node's start, end
// and step.
//
// @param nodePtr Serialized graph node; its union must hold an MPSSlice.
// @return Error::Ok on success; an Internal error when `dim` (after wrapping
//         a negative value by the input rank) is outside [0, rank).
Error
MPSGraphBuilder::mpsSliceOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSlice();
  ET_LOG(
    Debug, "%s %d: %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSShape* inputShape = inputTensor.shape;
  const NSInteger rank = static_cast<NSInteger>(inputShape.count);

  int64_t dim = graphNode->dim();
  // Accept negative dims counted from the end; otherwise indexing the
  // per-axis arrays below with an invalid index raises NSRangeException.
  if (dim < 0) {
    dim += rank;
  }
  ET_CHECK_OR_RETURN_ERROR(
    dim >= 0 && dim < rank, Internal,
    "slice: dim is out of range for the input rank.");

  // Define input arrays as required by MPSGraph API; every axis other than
  // `dim` is taken whole with stride 1.
  NSMutableArray<NSNumber*>* start_arr = [NSMutableArray arrayWithCapacity:rank];
  NSMutableArray<NSNumber*>* end_arr = [NSMutableArray arrayWithCapacity:rank];
  NSMutableArray<NSNumber*>* step_arr = [NSMutableArray arrayWithCapacity:rank];
  for (NSInteger i = 0; i < rank; i++) {
    [start_arr addObject:@0];
    [end_arr addObject:inputShape[i]];
    [step_arr addObject:@1];
  }

  start_arr[dim] = @(graphNode->start());
  end_arr[dim] = @(graphNode->end());
  step_arr[dim] = @(graphNode->step());

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph sliceTensor:inputTensor
                    starts:start_arr
                      ends:end_arr
                   strides:step_arr
                      name:@"strided_slice"];
  return Error::Ok;
}
228
// Lowers an MPSSplitWithSizes node: splits the input along `dim` into chunks
// of the serialized `split_sizes`. The i-th output id is bound to the i-th
// chunk.
//
// @param nodePtr Serialized graph node; its union must hold an
//        MPSSplitWithSizes.
// @return Error::Ok on success; an Internal error when there are more output
//         ids than chunks produced by MPSGraph.
Error
MPSGraphBuilder::mpsSplitWithSizesOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSplitWithSizes();
  ET_LOG(
    Debug, "%s: %d -> len(output)=%d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_ids()->size()
  );

  NSArray<MPSGraphTensor*>* mpsGraphResults =
    [_mpsGraph splitTensor:getMPSGraphTensor(graphNode->input1_id())
                splitSizes:getMPSShape(graphNode->split_sizes())
                      axis:graphNode->dim()
                      name:@"split_size"];

  const auto* outputIds = graphNode->output_ids();
  // Each output id is bound positionally to a split result; indexing past the
  // end of the NSArray would raise NSRangeException.
  ET_CHECK_OR_RETURN_ERROR(
    outputIds->size() <= mpsGraphResults.count,
    Internal,
    "split_with_sizes: more output ids than split results.");

  NSUInteger crtIdx = 0;
  for (auto outId : *outputIds) {
    _idToMPSGraphTensor[outId] = mpsGraphResults[crtIdx++];
  }

  return Error::Ok;
}
252
// Lowers an MPSCast node: converts the input tensor to the MPS data type that
// corresponds to the node's serialized dtype, using the castMPSTensor helper.
Error
MPSGraphBuilder::mpsCastOp(NodePtr nodePtr) {
  auto castNode = nodePtr->mpsnode_union_as_MPSCast();
  const auto inputId = castNode->input1_id();
  const auto outputId = castNode->output_id();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, inputId, outputId);

  MPSGraphTensor* source = getMPSGraphTensor(inputId);
  const auto targetType = getMPSDataType(castNode->dtype());
  _idToMPSGraphTensor[outputId] = castMPSTensor(_mpsGraph, source, targetType);
  return Error::Ok;
}
267
268} // namespace delegate
269} // namespace mps
270} // namespace backends
271} // namespace executorch
272