
//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Returns `indexTensor` unchanged when it is already int32, otherwise an int32
// cast of it. Every gather/scatter lowering in this file routes its index
// tensor through here so they all hand MPSGraph the same index dtype (this is
// the cast indexSelect previously did inline).
// NOTE(review): index values outside the int32 range would be truncated by the
// cast; presumably unreachable for realistic tensor sizes.
static MPSGraphTensor* castIndicesToInt32(
    MPSGraphTensor* indexTensor,
    MPSGraph* mpsGraph) {
  if (indexTensor.dataType == MPSDataTypeInt32) {
    return indexTensor;
  }
  return [mpsGraph castTensor:indexTensor
                       toType:MPSDataTypeInt32
                         name:@"castTensor"];
}

// Gathers slices of `inputTensor` along axis `dim` at the positions held in
// `indexTensor` (cast to int32 first if needed), with no batch dimensions.
//
// @param inputTensor Tensor to gather from.
// @param dim         Axis of `inputTensor` to gather along.
// @param indexTensor Positions to gather; any integer dtype castable to int32.
// @param mpsGraph    Graph the new ops are added to.
// @return            The gathered tensor.
MPSGraphTensor* indexSelect(
    MPSGraphTensor* inputTensor,
    int64_t dim,
    MPSGraphTensor* indexTensor,
    MPSGraph* mpsGraph) {
  MPSGraphTensor* castIndexTensor = castIndicesToInt32(indexTensor, mpsGraph);
  return [mpsGraph gatherWithUpdatesTensor:inputTensor
                             indicesTensor:castIndexTensor
                                      axis:dim
                           batchDimensions:0
                                      name:@"indexSelect"];
}

// Lowers an MPSIndexSelect node: output = indexSelect(input1, dim, index).
// Delegates to indexSelect() rather than duplicating its cast + gather.
Error
MPSGraphBuilder::mpsIndexSelectOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexSelect();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] = indexSelect(
      getMPSGraphTensor(graphNode->input1_id()),
      graphNode->dim(),
      getMPSGraphTensor(graphNode->index_id()),
      _mpsGraph);
  return Error::Ok;
}

// Lowers an MPSEmbedding node: gathers rows of the weight tensor (input1) along
// axis 0 at the positions in the indices tensor (input2). When padding_idx is
// not -1, every position whose index equals padding_idx yields zeros instead
// of the gathered row.
Error
MPSGraphBuilder::mpsEmbeddingOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSEmbedding();
  ET_LOG(
    Debug, "%s: (%d, %d) -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->input2_id(),
    graphNode->output_id()
  );

  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->input2_id());
  // The gather is needed on both paths, so build it once and reuse it.
  MPSGraphTensor* valTensor = indexSelect(weightTensor, 0, indicesTensor, _mpsGraph);

  int padding_idx = graphNode->padding_idx();
  if (padding_idx == -1) {
    // -1 means "no padding index": the gathered rows are the result.
    _idToMPSGraphTensor[graphNode->output_id()] = valTensor;
    return Error::Ok;
  }

  MPSGraphTensor* paddingTensor = [_mpsGraph constantWithScalar:padding_idx
                                                          shape:@[@1]
                                                       dataType:indicesTensor.dataType];
  MPSGraphTensor* notPaddingTensor = [_mpsGraph notEqualWithPrimaryTensor:indicesTensor
                                                          secondaryTensor:paddingTensor
                                                                     name:nil];
  // Add a trailing unit axis so the per-index mask broadcasts against the row
  // values gathered into valTensor (assumes a 2-D weight table — confirm).
  MPSGraphTensor* condition = [_mpsGraph expandDimsOfTensor:notPaddingTensor
                                                       axis:-1
                                                       name:@"unsqueeze"];
  // A scalar zero broadcasts in select just like `condition` does, and unlike a
  // constant shaped as valTensor.shape it does not need a fully static shape.
  MPSGraphTensor* zeroTensor = [_mpsGraph constantWithScalar:0
                                                    dataType:valTensor.dataType];
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph selectWithPredicateTensor:condition
                     truePredicateTensor:valTensor
                    falsePredicateTensor:zeroTensor
                                    name:nil];
  return Error::Ok;
}

// Lowers an MPSIndexTensor node (index.Tensor). Entries of indices_id equal to
// -1 mark axes without an index tensor. Only the single-index-tensor case is
// supported: it becomes a gather along that axis. More than one index tensor,
// and the Metal-kernel path, abort via ET_CHECK_MSG.
Error
MPSGraphBuilder::mpsIndexTensorOp(NodePtr nodePtr) {
  Error err = Error::Ok;
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexTensor();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  if (_metal_kernel) {
    err = MPSDevice::getInstance()->compilePSO(LibraryType::INDEXING_KERNELS, "index_select");
    ET_CHECK_MSG(false, "Metal kernel path not yet implemented\n");
  } else {
    int validIndices = 0;
    int numIndices = graphNode->indices_id()->size();
    int axis = -1;
    int indexId = -1;
    // Count the axes that carry an index tensor; remember the last one.
    for (int i = 0; i < numIndices; i++) {
      int32_t index_id = graphNode->indices_id()->Get(i);
      if (index_id == -1) {
        continue;
      }
      validIndices++;
      axis = i;
      indexId = index_id;
    }
    ET_LOG(Debug, "index.Tensor with %d indices (axis = %d)", validIndices, axis);
    ET_CHECK(validIndices > 0);

    if (validIndices == 1) {
      // Same cast + gather as index_select, so indices get the int32 cast too.
      _idToMPSGraphTensor[graphNode->output_id()] = indexSelect(
          getMPSGraphTensor(graphNode->input1_id()),
          axis,
          getMPSGraphTensor(indexId),
          _mpsGraph);
    } else {
      ET_CHECK_MSG(false, "Not yet implemented");
    }
  }

  return err;
}

// Lowers an MPSIndexPut node (index_put). Entries of indices_id equal to -1
// mark axes without an index tensor. Only the single-index-tensor case is
// supported: it becomes a scatter (set mode) of values into input1 along that
// axis. Values are first broadcast to values_shape when one is provided.
Error
MPSGraphBuilder::mpsIndexPutOp(NodePtr nodePtr) {
  Error err = Error::Ok;
  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexPut();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  if (_metal_kernel) {
    err = MPSDevice::getInstance()->compilePSO(LibraryType::INDEXING_KERNELS, "index_put");
    ET_CHECK_MSG(false, "Metal kernel path not yet implemented\n");
  } else {
    int validIndices = 0;
    int numIndices = graphNode->indices_id()->size();
    int axis = -1;
    int indexId = -1;
    // Count the axes that carry an index tensor; remember the last one.
    for (int i = 0; i < numIndices; i++) {
      int32_t index_id = graphNode->indices_id()->Get(i);
      if (index_id == -1) {
        continue;
      }
      validIndices++;
      axis = i;
      indexId = index_id;
    }
    ET_LOG(Debug, "index_put with %d indices (axis = %d)", validIndices, axis);
    ET_CHECK(validIndices > 0);

    if (validIndices == 1) {
      MPSGraphTensor* dataTensor = getMPSGraphTensor(graphNode->input1_id());
      MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->values_id());
      MPSGraphTensor* indicesTensor =
        castIndicesToInt32(getMPSGraphTensor(indexId), _mpsGraph);
      if (graphNode->values_shape()->size() != 0) {
        updatesTensor = [_mpsGraph broadcastTensor:updatesTensor
                                           toShape:getMPSShape(graphNode->values_shape())
                                              name:nil];
      }

      _idToMPSGraphTensor[graphNode->output_id()] =
        [_mpsGraph scatterWithDataTensor:dataTensor
                           updatesTensor:updatesTensor
                           indicesTensor:indicesTensor
                                    axis:axis
                                    mode:MPSGraphScatterModeSet
                                    name:nil];
    } else {
      ET_CHECK_MSG(false, "Not yet implemented");
    }
  }

  return err;
}

// Lowers an MPSScatter node: writes src values into input1 along axis `dim`
// at the positions in idx (set mode). idx is cast to int32 like the other
// index tensors in this file.
Error
MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSScatter();
  ET_LOG(
    Debug, "%s %d: %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  int64_t dim = graphNode->dim();
  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSGraphTensor* indicesTensor =
    castIndicesToInt32(getMPSGraphTensor(graphNode->idx_id()), _mpsGraph);
  MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id());

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph scatterAlongAxis:dim
                 withDataTensor:inputTensor
                  updatesTensor:updatesTensor
                  indicesTensor:indicesTensor
                           mode:MPSGraphScatterModeSet
                           name:nil];
  return Error::Ok;
}


} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch