//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Lowerings for shape / data-movement nodes. Each builder reads its serialized
// node, emits the matching MPSGraph operation on _mpsGraph, and records the
// resulting tensor in _idToMPSGraphTensor under the node's output id so that
// later nodes can fetch it through getMPSGraphTensor().

// Lowers an MPSPermute node: reorders the input's dimensions using the first
// `num_dims` entries of the serialized `perm` vector (via permuteTensor()).
Error
MPSGraphBuilder::mpsPermuteOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSPermute();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__,
    graphNode->input1_id(),
    graphNode->output_id()
  );

  NSMutableArray<NSNumber*>* permutation = [NSMutableArray array];
  for (int64_t i = 0; i < graphNode->num_dims(); i++) {
    [permutation addObject:[NSNumber numberWithInteger:graphNode->perm()->Get(i)]];
  }
  MPSGraphTensor* outputTensor = permuteTensor(
    _mpsGraph, getMPSGraphTensor(graphNode->input1_id()), permutation
  );
  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;

  return Error::Ok;
}

// Lowers an MPSView node: reshapes the input to the serialized `shape`.
Error
MPSGraphBuilder::mpsViewOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSView();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph reshapeTensor:getMPSGraphTensor(graphNode->input1_id())
                   withShape:getMPSShape(graphNode->shape())
                        name:@"view_copy"];

  return Error::Ok;
}

// Lowers an MPSExpand node (torch `expand`): broadcasts the input to the
// serialized target shape.
//
// Following torch semantics, the requested shape may have MORE dimensions than
// the input: new dimensions are prepended, and the input's dimensions line up
// with the trailing entries of the requested shape. A value of -1 keeps the
// corresponding input dimension's size and is only valid for dimensions that
// exist in the input.
//
// Returns Error::Internal if the requested shape has fewer dimensions than the
// input, or if -1 is used for a new leading dimension.
Error
MPSGraphBuilder::mpsExpandOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSExpand();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  const auto* expandShape = graphNode->shape();
  const NSUInteger inputRank = inputTensor.shape.count;
  const NSUInteger outputRank = expandShape->size();

  // expand can add leading dimensions but never remove any; reading the
  // requested shape below assumes it covers every input dimension.
  ET_CHECK_OR_RETURN_ERROR(
    outputRank >= inputRank,
    Internal,
    "expand: requested shape has fewer dimensions than the input tensor.");

  // Number of new dimensions prepended in front of the input's dimensions.
  const NSUInteger numNewDims = outputRank - inputRank;

  NSMutableArray<NSNumber*>* shape = [NSMutableArray arrayWithCapacity:outputRank];
  for (NSUInteger i = 0; i < outputRank; i++) {
    // Kept as int64_t so large sizes are not narrowed.
    const int64_t expandDimVal = expandShape->Get(static_cast<uint32_t>(i));
    if (expandDimVal == -1) {
      // In torch, -1 is passed for dimensions which are to stay the same size;
      // a brand-new leading dimension has no size to keep.
      ET_CHECK_OR_RETURN_ERROR(
        i >= numNewDims,
        Internal,
        "expand: -1 is not allowed for a new leading dimension.");
      [shape addObject:inputTensor.shape[i - numNewDims]];
    } else {
      [shape addObject:[NSNumber numberWithInteger:expandDimVal]];
    }
  }

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph broadcastTensor:inputTensor
                       toShape:shape
                          name:@"expand_copy"];

  return Error::Ok;
}

// Lowers an MPSCat node: concatenates all inputs along `dim`.
// Input ids whose tensor lookup returns nil are skipped rather than passed to
// concatTensors. NOTE(review): presumably these are empty/unmaterialized
// inputs — confirm against getMPSGraphTensor().
Error
MPSGraphBuilder::mpsCatOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSCat();
  ET_LOG(
    Debug, "%s: %d",
    __FUNCTION__, graphNode->output_id()
  );

  NSMutableArray<MPSGraphTensor*>* inputTensors =
    [NSMutableArray arrayWithCapacity:graphNode->input_ids()->size()];
  for (auto id : *graphNode->input_ids()) {
    MPSGraphTensor* catTensor = getMPSGraphTensor(id);
    if (catTensor != nil) {
      [inputTensors addObject:catTensor];
    }
  }
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph concatTensors:inputTensors
                   dimension:graphNode->dim()
                        name:@"cat"];

  return Error::Ok;
}

// Lowers an MPSSqueeze node: removes the serialized `dims` axes from the input.
Error
MPSGraphBuilder::mpsSqueezeOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSqueeze();
  ET_LOG(
    Debug, "%s: %d",
    __FUNCTION__, graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph squeezeTensor:getMPSGraphTensor(graphNode->input1_id())
                        axes:getMPSShape(graphNode->dims())
                        name:@"squeeze"];

  return Error::Ok;
}

// Lowers an MPSUnsqueeze node: inserts a size-1 axis at `dim`.
Error
MPSGraphBuilder::mpsUnsqueezeOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSUnsqueeze();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph expandDimsOfTensor:getMPSGraphTensor(graphNode->input1_id())
                             axis:graphNode->dim()
                             name:@"unsqueeze"];

  return Error::Ok;
}

// Lowers an MPSSelect node (torch `select`): takes a length-1 slice at `index`
// along `dim`, then squeezes that axis so the output has one fewer dimension.
// NOTE(review): `dim` and `index` are forwarded unchanged; negative values are
// presumably normalized before serialization — confirm.
Error
MPSGraphBuilder::mpsSelectOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSelect();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* slicedTensor = [_mpsGraph sliceTensor:getMPSGraphTensor(graphNode->input1_id())
                                              dimension:graphNode->dim()
                                                  start:graphNode->index()
                                                 length:1
                                                   name:@"slice"];
  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph squeezeTensor:slicedTensor
                        axis:graphNode->dim()
                        name:@"slice/squeezed"];

  return Error::Ok;
}

// Lowers an MPSPixelShuffle node via depthToSpace2DTensor, treating the last
// three axes as (depth, height, width) = (ndims-3, ndims-2, ndims-1).
//
// Returns Error::Internal if the input has fewer than 3 dimensions, if the
// upscale factor is not positive, or if the depth axis is not divisible by
// upscale_factor^2. An upscale factor of 1 passes the input through unchanged.
Error
MPSGraphBuilder::mpsPixelShuffleOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSPixelShuffle();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  const int ndims = inputTensor.shape.count;
  MPSGraphTensor* outputTensor = nil;
  int32_t upscaleFactor = graphNode->upscale_factor();

  ET_CHECK_OR_RETURN_ERROR(
    ndims >= 3, Internal, "pixel_shuffle requires tensor with at least 3 dimensions.");
  // Guards the divisibility check below against a modulo by zero.
  ET_CHECK_OR_RETURN_ERROR(
    upscaleFactor > 0, Internal, "pixel_shuffle requires a positive upscale factor.");

  if (upscaleFactor == 1) {
    // TODO: move this to AOT
    outputTensor = inputTensor;
  } else {
    // Computed in 64 bits so large factors cannot overflow int32.
    const int64_t blockArea = static_cast<int64_t>(upscaleFactor) * upscaleFactor;
    ET_CHECK_OR_RETURN_ERROR(
      inputTensor.shape[ndims - 3].longLongValue % blockArea == 0,
      Internal,
      "pixel_shuffle channels must be divisible by upscale factor squared.");

    outputTensor = [_mpsGraph depthToSpace2DTensor:inputTensor
                                         widthAxis:ndims - 1
                                        heightAxis:ndims - 2
                                         depthAxis:ndims - 3
                                         blockSize:upscaleFactor
                              usePixelShuffleOrder:true
                                              name:@"pixel_shuffle"];
  }

  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;
  return Error::Ok;
}

// Lowers an MPSSlice node as a strided slice along a single dimension.
// Every other dimension is taken whole (start 0, end = full size, stride 1);
// `dim` uses the serialized start/end/step.
//
// A negative `dim` is interpreted relative to the input rank. Returns
// Error::Internal if `dim` is still out of range after that adjustment.
Error
MPSGraphBuilder::mpsSliceOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSlice();
  ET_LOG(
    Debug, "%s %d: %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
  MPSShape* inputShape = inputTensor.shape;
  const NSUInteger rank = inputShape.count;

  int64_t dim = graphNode->dim();
  if (dim < 0) {
    dim += static_cast<int64_t>(rank);
  }
  // Indexing the per-dimension arrays below with a bad dim would throw.
  ET_CHECK_OR_RETURN_ERROR(
    dim >= 0 && dim < static_cast<int64_t>(rank),
    Internal,
    "slice: dim %lld is out of range for a tensor of rank %lu.",
    static_cast<long long>(graphNode->dim()),
    static_cast<unsigned long>(rank));

  // Define input arrays as required by MPSGraph API
  NSMutableArray<NSNumber*>* start_arr = [NSMutableArray arrayWithCapacity:rank];
  NSMutableArray<NSNumber*>* end_arr = [NSMutableArray arrayWithCapacity:rank];
  NSMutableArray<NSNumber*>* step_arr = [NSMutableArray arrayWithCapacity:rank];
  // Step needs to be set to one for all other dims
  for (NSUInteger i = 0; i < rank; i++) {
    [step_arr addObject:@1];
    [end_arr addObject:inputShape[i]];
    [start_arr addObject:@0];
  }

  start_arr[dim] = [NSNumber numberWithInteger:graphNode->start()];
  end_arr[dim] = [NSNumber numberWithInteger:graphNode->end()];
  step_arr[dim] = [NSNumber numberWithInteger:graphNode->step()];

  _idToMPSGraphTensor[graphNode->output_id()] =
    [_mpsGraph sliceTensor:inputTensor
                    starts:start_arr
                      ends:end_arr
                   strides:step_arr
                      name:@"strided_slice"];
  return Error::Ok;
}

// Lowers an MPSSplitWithSizes node: splits the input along `dim` into pieces
// of the serialized `split_sizes`, assigning the i-th piece to the i-th output
// id. Returns Error::Internal if MPSGraph yields a different number of pieces
// than there are output ids.
Error
MPSGraphBuilder::mpsSplitWithSizesOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSSplitWithSizes();
  ET_LOG(
    Debug, "%s: %d -> len(output)=%d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_ids()->size()
  );

  NSArray<MPSGraphTensor*>* mpsGraphResults =
    [_mpsGraph splitTensor:getMPSGraphTensor(graphNode->input1_id())
                splitSizes:getMPSShape(graphNode->split_sizes())
                      axis:graphNode->dim()
                      name:@"split_size"];

  const auto* outputIds = graphNode->output_ids();
  // Each output id consumes one split result; a mismatch would read past the
  // end of mpsGraphResults.
  ET_CHECK_OR_RETURN_ERROR(
    mpsGraphResults.count == outputIds->size(),
    Internal,
    "split_with_sizes: produced %lu tensors but %lu outputs were expected.",
    static_cast<unsigned long>(mpsGraphResults.count),
    static_cast<unsigned long>(outputIds->size()));

  NSUInteger crtIdx = 0;
  for (auto outId : *outputIds) {
    _idToMPSGraphTensor[outId] = mpsGraphResults[crtIdx++];
  }

  return Error::Ok;
}

// Lowers an MPSCast node: converts the input to the serialized `dtype` via
// castMPSTensor().
Error
MPSGraphBuilder::mpsCastOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSCast();
  ET_LOG(
    Debug, "%s: %d -> %d",
    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
  );

  _idToMPSGraphTensor[graphNode->output_id()] = castMPSTensor(
    _mpsGraph, getMPSGraphTensor(graphNode->input1_id()), getMPSDataType(graphNode->dtype()));

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch