xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/PadOps.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1
2//
3//  Copyright (c) 2023 Apple Inc. All rights reserved.
4//  Provided subject to the LICENSE file in the top level directory.
5//
6
7#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
8
9namespace executorch {
10namespace backends {
11namespace mps {
12namespace delegate {
13
14// Pad operations (1D/2D/3D forward)
/**
 * Appends a pad node to `mpsGraph` for a 1D, 2D or 3D padding specification.
 *
 * `padding` holds (before, after) pairs, starting with the last input
 * dimension and moving outwards:
 *   {left, right [, top, bottom [, front, back]]}.
 * Dimensions not covered by `padding` are padded with zero elements.
 *
 * @param mpsGraph       Graph the pad node is added to.
 * @param input          Tensor to pad; its shape is read via getMPSShapeVec().
 * @param padding        Pad amounts; must contain 2, 4 or 6 entries and no
 *                       more pairs than `input` has dimensions.
 * @param mode           Padding mode forwarded to MPSGraph. Non-constant modes
 *                       get extra shape validation; reflect mode additionally
 *                       requires every pad amount to be smaller than the
 *                       dimension it applies to.
 * @param constantValue  Forwarded unchanged as `constantValue:` to the graph.
 * @return The tensor returned by -padTensor:withPaddingMode:leftPadding:
 *         rightPadding:constantValue:name:.
 *
 * Violated preconditions are reported through ET_CHECK_MSG.
 */
static
MPSGraphTensor* padOutTemplate(
  MPSGraph* mpsGraph,
  MPSGraphTensor* input,
  std::vector<size_t> padding,
  MPSGraphPaddingMode mode,
  float constantValue) {

  const int padding_size = (int) padding.size();
  int padding_dim = padding_size / 2; // either 1D, 2D, or 3D

  ET_CHECK_MSG(padding_size == 2 || padding_size == 4 || padding_size == 6,
              "invalid padding argument of size %d", padding_size);

  auto input_sizes = getMPSShapeVec(input.shape);
  int64_t ndims = input_sizes.size();

  ET_CHECK_MSG(
    ndims >= (int64_t)padding_dim,
    "Length of pad should be no more than twice the number of "
    "dimensions of the input. Pad length is %d while the input has %lld dimensions.", padding_size, ndims);

  // Indices of the width/height/depth dimensions, assuming exactly one leading
  // dimension in front of the padded ones; adjusted below for other ranks.
  // number of input dims with ConstantPad could be less than 2
  int dim_w = padding_dim;
  int dim_h = padding_dim - 1;
  int dim_d = padding_dim - 2;

  // Non-constant modes reject empty dimensions and accept only one or two
  // leading dimensions in front of the padded ones.
  if (mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
    bool valid_dims = input_sizes[1] != 0 && input_sizes[padding_dim] != 0;
    ET_CHECK_MSG((ndims == 1 + padding_dim && valid_dims) ||
                (ndims == 2 + padding_dim && valid_dims && input_sizes[1 + padding_dim] != 0),
                "3D or 4D (batch mode) tensor expected for input, but got: %zu", input_sizes.size());
  }

  if (ndims == padding_dim) {
    // No leading dimension at all: every index moves one slot to the left.
    dim_w--;
    dim_h--;
    dim_d--;
  } else if (ndims > padding_dim + 1) {
    const int dim_diff = (int)ndims - padding_dim - 1;
    // this virtually inflates the padding with zeros if ndims > padding_dim + 2
    padding_dim += dim_diff - 1;
    dim_w += dim_diff;
    dim_h += dim_diff;
    dim_d += dim_diff;
  }

  // Pairs that were not supplied count as zero padding.
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding_size > 2 ? padding[2] : 0;
  int64_t pad_b = padding_size > 2 ? padding[3] : 0;
  int64_t pad_front = padding_size > 4 ? padding[4] : 0;
  int64_t pad_back  = padding_size > 4 ? padding[5] : 0;

  int64_t input_w = input_sizes[dim_w];
  int64_t output_w  = input_w + pad_l + pad_r;
  int64_t input_h = padding_dim > 1 ? input_sizes[dim_h] : 0;
  int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
  int64_t input_d = padding_dim > 2 ? input_sizes[dim_d] : 0;

  // NOTE(review): this passes when EITHER condition holds, so it is a fairly
  // lenient check; left unchanged to keep accepting the same inputs - confirm
  // whether `&&` was intended.
  ET_CHECK_MSG(
    output_w >= 1 || output_h >= padding_dim - 1,
    "input (H: %lld, W: %lld) is too small. Calculated "
    "output H: %lld, W: %lld",  input_h, input_w, output_h, output_w);

  // these checks are only relevant for reflection padding (code taken from ReflectionPad.cpp)
  if (mode == MPSGraphPaddingModeReflect) {
    ET_CHECK_MSG(pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size should be less than the corresponding "
      "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
      pad_l, pad_r, dim_w, ndims);

    if (padding_dim > 1) {
      ET_CHECK_MSG(pad_t < input_h && pad_b < input_h,
        "Argument #6: Padding size should be less than the corresponding "
        "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
        pad_t, pad_b, dim_h, ndims);
    }
    if (padding_dim > 2) {
      // Report (front, back, dimension index, rank), matching the width and
      // height checks above.
      ET_CHECK_MSG(pad_front < input_d && pad_back < input_d,
        "Argument #8: Padding size should be less than the corresponding "
        "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld",
        pad_front, pad_back, dim_d, ndims);
    }
  }

  // Per-dimension pad amounts in input-dimension order; untouched entries
  // stay at zero.
  std::vector<NSNumber*> leftPadVec(ndims, @(0));
  std::vector<NSNumber*> rightPadVec(ndims, @(0));

  // Pair `pdim` of `padding` applies to the pdim-th dimension counted from
  // the end of the input shape.
  for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) {
    const int64_t leftIdx  = pdim * 2;
    const int64_t rightIdx = pdim * 2 + 1;
    const int64_t padIdx = ndims - pdim - 1;

    leftPadVec [padIdx] = @(padding[leftIdx]);
    rightPadVec[padIdx] = @(padding[rightIdx]);
  }
  MPSShape *leftPadding  = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
  MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];

  // TODO: check if Bool type works with Constant padding (asserts on pytorch)
  MPSGraphTensor *padTensor = [mpsGraph padTensor: input
                                  withPaddingMode: mode
                                      leftPadding: leftPadding
                                     rightPadding: rightPadding
                                    constantValue: constantValue
                                             name: nil];

  return padTensor;
}
125
/**
 * Lowers an MPSConstantPadND node: pads the node's input tensor in constant
 * mode with the node's pad amounts and fill value, and stores the resulting
 * graph tensor under the node's output id.
 *
 * @param nodePtr Serialized node; read as an MPSConstantPadND union member.
 * @return Error::Ok once the padded tensor has been stored.
 */
Error
MPSGraphBuilder::mpsConstantPadNDOp(NodePtr nodePtr) {
  auto padNode = nodePtr->mpsnode_union_as_MPSConstantPadND();
  ET_LOG(Debug, "%s: %d -> %d", __FUNCTION__, padNode->input1_id(), padNode->output_id());

  // Look up the already-built input tensor for this node.
  MPSGraphTensor* inputTensor = getMPSGraphTensor(padNode->input1_id());

  _idToMPSGraphTensor[padNode->output_id()] = padOutTemplate(
    _mpsGraph,
    inputTensor,
    flatbufferDimsToVector(padNode->pad()),
    MPSGraphPaddingModeConstant,
    padNode->value());

  return Error::Ok;
}
148
149} // namespace delegate
150} // namespace mps
151} // namespace backends
152} // namespace executorch
153