xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/IndexingOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
14
// Builds a gather of `inputTensor` along axis `dim`, picking the slices named by
// `indexTensor`. Indices of any data type other than Int32 are first converted
// to Int32; the gather is emitted with zero batch dimensions and is named
// "indexSelect".
MPSGraphTensor* indexSelect(
  MPSGraphTensor* inputTensor,
  int64_t dim,
  MPSGraphTensor* indexTensor,
  MPSGraph* mpsGraph) {
  const BOOL needsInt32Cast = (indexTensor.dataType != MPSDataTypeInt32);
  MPSGraphTensor* int32Indices = needsInt32Cast
      ? [mpsGraph castTensor:indexTensor
                      toType:MPSDataTypeInt32
                        name:@"castTensor"]
      : indexTensor;

  return [mpsGraph gatherWithUpdatesTensor:inputTensor
                             indicesTensor:int32Indices
                                      axis:dim
                           batchDimensions:0
                                      name:@"indexSelect"];
}
34
// Lowers an index_select node: the output is input1 gathered along
// graphNode->dim() using the tensor referenced by index_id() as indices.
//
// The work is delegated to the file-local indexSelect() helper, so the Int32
// index cast and the gather are built exactly as for the other indexing
// lowerings in this file rather than being duplicated here.
Error
MPSGraphBuilder::mpsIndexSelectOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexSelect();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* indexTensor = getMPSGraphTensor(graphNode->index_id());

  _idToMPSGraphTensor[graphNode->output_id()] =
    indexSelect(inputTensor, graphNode->dim(), indexTensor, _mpsGraph);
  return Error::Ok;
}
62
// Lowers aten.embedding: output = weight[indices]. Rows of input1 (the weight
// table) are gathered along axis 0 using input2 (the indices) via indexSelect().
//
// padding_idx is deliberately ignored here. In PyTorch it only prevents
// gradients from reaching that weight row. The forward pass of aten::embedding
// is a plain index_select and still returns weight[padding_idx] unchanged.
// Masking those positions to zero would diverge from eager PyTorch whenever
// that row is non-zero (e.g. pretrained tables), and would cost an extra
// compare + select in the graph.
Error
MPSGraphBuilder::mpsEmbeddingOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSEmbedding();
  ET_LOG(
    Debug, "%s: (%d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->input2_id());

  _idToMPSGraphTensor[graphNode->output_id()] =
    indexSelect(weightTensor, 0, indicesTensor, _mpsGraph);

  return Error::Ok;
}
110
// Lowers aten.index.Tensor.
//
// indices_id() holds one entry per leading dimension of input1. Entries equal
// to -1 are skipped; presumably these are None / full-slice positions (confirm
// against the serializer). Only a single index tensor is supported. It becomes
// a gather along that entry's position, built with the shared indexSelect()
// helper so non-Int32 indices are cast to Int32 just like index_select.
//
// The Metal-kernel path and multiple index tensors are not implemented. They
// are reported as errors to the caller instead of aborting the process.
Error
MPSGraphBuilder::mpsIndexTensorOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexTensor();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  if (_metal_kernel) {
    ET_LOG(Error, "%s: Metal kernel path not yet implemented", __FUNCTION__);
    return Error::NotImplemented;
  }

  // Find the (last) populated index slot; its position is the gather axis.
  const auto* indicesIds = graphNode->indices_id();
  const int numIndices = static_cast<int>(indicesIds->size());
  int validIndices = 0;
  int axis = -1;
  int indexId = -1;
  for (int i = 0; i < numIndices; i++) {
    const int32_t id = indicesIds->Get(i);
    if (id == -1) {
      continue;
    }
    validIndices++;
    axis = i;
    indexId = id;
  }
  ET_LOG(Debug, "index.Tensor with %d indices (axis = %d)", validIndices, axis);

  if (validIndices == 0) {
    ET_LOG(Error, "%s: expected at least one index tensor", __FUNCTION__);
    return Error::InvalidArgument;
  }
  if (validIndices > 1) {
    ET_LOG(Error, "%s: multiple index tensors not yet implemented", __FUNCTION__);
    return Error::NotImplemented;
  }

  _idToMPSGraphTensor[graphNode->output_id()] = indexSelect(
    getMPSGraphTensor(graphNode->input1_id()),
    axis,
    getMPSGraphTensor(indexId),
    _mpsGraph
  );
  return Error::Ok;
}
152
// Lowers aten.index_put with a single index tensor.
//
// indices_id() holds one entry per leading dimension of input1. Entries equal
// to -1 are skipped; presumably these are None / full-slice positions (confirm
// against the serializer).
//
// With exactly one index tensor at position `axis`, the values tensor
// (values_id) is scattered into input1 along `axis` with
// MPSGraphScatterModeSet. When values_shape() is non-empty, values is first
// broadcast to that shape.
//
// The Metal-kernel path and multiple index tensors are not implemented. They
// are reported as errors to the caller instead of aborting the process.
Error
MPSGraphBuilder::mpsIndexPutOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexPut();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  if (_metal_kernel) {
    ET_LOG(Error, "%s: Metal kernel path not yet implemented", __FUNCTION__);
    return Error::NotImplemented;
  }

  // Find the (last) populated index slot; its position is the scatter axis.
  const auto* indicesIds = graphNode->indices_id();
  const int numIndices = static_cast<int>(indicesIds->size());
  int validIndices = 0;
  int axis = -1;
  int indexId = -1;
  for (int i = 0; i < numIndices; i++) {
    const int32_t id = indicesIds->Get(i);
    if (id == -1) {
      continue;
    }
    validIndices++;
    axis = i;
    indexId = id;
  }
  ET_LOG(Debug, "index_put with %d indices (axis = %d)", validIndices, axis);

  if (validIndices == 0) {
    ET_LOG(Error, "%s: expected at least one index tensor", __FUNCTION__);
    return Error::InvalidArgument;
  }
  if (validIndices > 1) {
    ET_LOG(Error, "%s: multiple index tensors not yet implemented", __FUNCTION__);
    return Error::NotImplemented;
  }

  MPSGraphTensor* dataTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->values_id());
  MPSGraphTensor* indicesTensor = getMPSGraphTensor(indexId);
  if (graphNode->values_shape()->size() != 0) {
    updatesTensor = [_mpsGraph broadcastTensor:updatesTensor
                                       toShape:getMPSShape(graphNode->values_shape())
                                          name:nil];
  }

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph scatterWithDataTensor:dataTensor
                       updatesTensor:updatesTensor
                       indicesTensor:indicesTensor
                                axis:axis
                                mode:MPSGraphScatterModeSet
                                name:nil];
  return Error::Ok;
}
206
// Lowers a scatter (src variant). The output is input1 with the values from
// src_id written at the positions given by idx_id along axis dim(), using
// MPSGraphScatterModeSet.
Error
MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) {
  auto scatterNode = nodePtr->mpsnode_union_as_MPSScatter();
  ET_LOG(
    Debug, "%s %d: %d",
    __FUNCTION__, scatterNode->input1_id(), scatterNode->output_id()
  );

  const int64_t scatterAxis = scatterNode->dim();
  MPSGraphTensor* dataTensor = getMPSGraphTensor(scatterNode->input1_id());
  MPSGraphTensor* positionTensor = getMPSGraphTensor(scatterNode->idx_id());
  MPSGraphTensor* sourceTensor = getMPSGraphTensor(scatterNode->src_id());

  MPSGraphTensor* scattered = [_mpsGraph scatterAlongAxis:scatterAxis
                                           withDataTensor:dataTensor
                                            updatesTensor:sourceTensor
                                            indicesTensor:positionTensor
                                                     mode:MPSGraphScatterModeSet
                                                     name:nil];
  _idToMPSGraphTensor[scatterNode->output_id()] = scattered;
  return Error::Ok;
}
229
230
231} // namespace delegate
232} // namespace mps
233} // namespace backends
234} // namespace executorch
235