//
// Copyright (c) 2023 Apple Inc. All rights reserved.
// Provided subject to the LICENSE file in the top level directory.
//

#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>

namespace executorch {
namespace backends {
namespace mps {
namespace delegate {

// Pad operations (1D/2D/3D forward).
// Builds an MPSGraph pad op from a flattened padding spec, where padding
// pairs are ordered from the last tensor dimension inward.
static MPSGraphTensor*
padOutTemplate(
    MPSGraph* mpsGraph,
    MPSGraphTensor* input,
    std::vector<size_t> padding,
    MPSGraphPaddingMode mode,
    float constantValue) {
  const int padding_size = (int)padding.size();
  int padding_dim = padding_size / 2; // either 1D, 2D, or 3D

  ET_CHECK_MSG(
      padding_size == 2 || padding_size == 4 || padding_size == 6,
      "invalid padding argument of size %d", padding_size);

  auto input_sizes = getMPSShapeVec(input.shape);
  int64_t ndims = input_sizes.size();

  ET_CHECK_MSG(
      ndims >= (int64_t)padding_dim,
      "Length of pad should be no more than twice the number of "
      "dimensions of the input. Pad length is %d while the input has %lld dimensions.",
      padding_size, ndims);

  // number of input dims with ConstantPad could be less than 2
  int dim_w = padding_dim;
  int dim_h = padding_dim - 1;
  int dim_d = padding_dim - 2;

  if (mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
    bool valid_dims = input_sizes[1] != 0 && input_sizes[padding_dim] != 0;
    ET_CHECK_MSG(
        (ndims == 1 + padding_dim && valid_dims) ||
        (ndims == 2 + padding_dim && valid_dims && input_sizes[1 + padding_dim] != 0),
        "3D or 4D (batch mode) tensor expected for input, but got: %zu dimensions",
        input_sizes.size());
  }

  if (ndims == padding_dim) {
    dim_w--;
    dim_h--;
    dim_d--;
  } else if (ndims > padding_dim + 1) {
    const int dim_diff = (int)ndims - padding_dim - 1;
    // this virtually inflates the padding with zeros if ndims > padding_dim + 2
    padding_dim += dim_diff - 1;
    dim_w += dim_diff;
    dim_h += dim_diff;
    dim_d += dim_diff;
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding_size > 2 ? padding[2] : 0;
  int64_t pad_b = padding_size > 2 ? padding[3] : 0;
  int64_t pad_front = padding_size > 4 ? padding[4] : 0;
  int64_t pad_back = padding_size > 4 ? padding[5] : 0;

  int64_t input_w = input_sizes[dim_w];
  int64_t output_w = input_w + pad_l + pad_r;
  int64_t input_h = padding_dim > 1 ? input_sizes[dim_h] : 0;
  int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
  int64_t input_d = padding_dim > 2 ? input_sizes[dim_d] : 0;

  ET_CHECK_MSG(
      output_w >= 1 || output_h >= padding_dim - 1,
      "input (H: %lld, W: %lld) is too small. Calculated "
      "output H: %lld, W: %lld",
      input_h, input_w, output_h, output_w);

  // these checks are only relevant for reflection padding (code taken from ReflectionPad.cpp)
  if (mode == MPSGraphPaddingModeReflect) {
    ET_CHECK_MSG(
        pad_l < input_w && pad_r < input_w,
        "Argument #4: Padding size should be less than the corresponding "
        "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
        pad_l, pad_r, dim_w, ndims);

    if (padding_dim > 1) {
      ET_CHECK_MSG(
          pad_t < input_h && pad_b < input_h,
          "Argument #6: Padding size should be less than the corresponding "
          "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
          pad_t, pad_b, dim_h, ndims);
    }
    if (padding_dim > 2) {
      ET_CHECK_MSG(
          pad_front < input_d && pad_back < input_d,
          "Argument #8: Padding size should be less than the corresponding "
          "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
          pad_front, pad_back, dim_d, ndims);
    }
  }

  // Padding pairs are given innermost-dimension first, so pair `pdim` maps to
  // tensor dimension `ndims - pdim - 1`; all remaining dimensions stay at 0.
  std::vector<NSNumber*> leftPadVec(ndims, @(0));
  std::vector<NSNumber*> rightPadVec(ndims, @(0));

  for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) {
    const int64_t leftIdx = pdim * 2;
    const int64_t rightIdx = pdim * 2 + 1;
    const int64_t padIdx = ndims - pdim - 1;

    leftPadVec[padIdx] = @(padding[leftIdx]);
    rightPadVec[padIdx] = @(padding[rightIdx]);
  }
  MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
  MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];

  // TODO: check if Bool type works with Constant padding (asserts on pytorch)
  MPSGraphTensor* padTensor = [mpsGraph padTensor:input
                                  withPaddingMode:mode
                                      leftPadding:leftPadding
                                     rightPadding:rightPadding
                                    constantValue:constantValue
                                             name:nil];

  return padTensor;
}

Error
MPSGraphBuilder::mpsConstantPadNDOp(NodePtr nodePtr) {
  auto graphNode = nodePtr->mpsnode_union_as_MPSConstantPadND();
  ET_LOG(
      Debug, "%s: %d -> %d",
      __FUNCTION__,
      graphNode->input1_id(),
      graphNode->output_id());

  _idToMPSGraphTensor[graphNode->output_id()] = padOutTemplate(
      _mpsGraph,
      getMPSGraphTensor(graphNode->input1_id()),
      flatbufferDimsToVector(graphNode->pad()),
      MPSGraphPaddingModeConstant,
      graphNode->value());

  return Error::Ok;
}

} // namespace delegate
} // namespace mps
} // namespace backends
} // namespace executorch