1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/runtime/kernel/kernel_includes.h>
12
13 namespace torch {
14 namespace executor {
15
/**
 * Argument check for the n-dimensional padding kernels declared below.
 * Defined out of line; NOTE(review): the exact conditions checked are not
 * visible in this header — confirm against the definition.
 *
 * @param n          Presumably the number of trailing dimensions being padded
 *                   (1, 2 or 3 for pad1d/pad2d/pad3d) — TODO confirm.
 * @param in         Input tensor.
 * @param padding    Padding amounts; pad1d/pad2d/pad3d read the left/top/front
 *                   amounts at indices 0/2/4.
 * @param out        Output tensor.
 * @param reflection Defaults to false; presumably enables extra constraints
 *                   that reflection padding needs (e.g. padding smaller than
 *                   the input extent) — confirm against the definition.
 * @return true when the arguments are accepted (boolean check result).
 */
bool check_padding_args(
    int64_t n,
    const Tensor& in,
    exec_aten::ArrayRef<int64_t> padding,
    Tensor& out,
    bool reflection = false);
22
/**
 * Output-shape helper for the padding kernels. Defined out of line.
 *
 * `out_sizes` and `out_ndim` are non-const out-parameters. NOTE(review):
 * presumably they receive the padded output shape (input sizes with the last
 * `n` dimensions grown by their padding amounts) — confirm against the
 * definition. The caller presumably provides an `out_sizes` buffer large
 * enough for `in.dim()` entries.
 *
 * @param n         Presumably the number of trailing dimensions being padded.
 * @param in        Input tensor whose shape is the starting point.
 * @param padding   Padding amounts.
 * @param out_sizes Written: output sizes.
 * @param out_ndim  Written: number of output dimensions.
 */
void get_padding_out_target_size(
    int64_t n,
    const Tensor& in,
    exec_aten::ArrayRef<int64_t> padding,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim);
29
/**
 * Replication-padding index map: translates output position `j` along one
 * dimension into the source index, for a dimension of extent `size` with
 * `pad` elements of leading padding.
 *
 * Positions before the data repeat the first element (index 0). Positions
 * after it repeat the last element (index size - 1). Interior positions map
 * straight through (j - pad). For size == 0 the trailing branch yields -1,
 * so callers must not invoke this with an empty source dimension.
 */
inline int64_t replication_ix(int64_t j, int64_t size, int64_t pad) {
  const int64_t shifted = j - pad;
  if (shifted < 0) {
    // Leading padding region: clamp to the first element.
    return 0;
  }
  if (shifted >= size) {
    // Trailing padding region: clamp to the last element.
    return size - 1;
  }
  return shifted;
}
33
/**
 * Reflection-padding index map: translates output position `j` along one
 * dimension into the source index, for a dimension of extent `size` with
 * `pad` elements of leading padding.
 *
 * Positions before the data mirror about index 0, excluding the edge element
 * (j = pad - 1 -> 1). Positions after the data mirror about index size - 1,
 * excluding the edge element. Interior positions map straight through
 * (j - pad). The result is only in range when the padding amounts are
 * smaller than `size`.
 */
inline int64_t reflection_ix(int64_t j, int64_t size, int64_t pad) {
  if (j < pad) {
    // Leading region: reflect about the first element.
    return pad - j;
  }
  const int64_t offset = j - pad;
  if (offset < size) {
    return offset;
  }
  // Trailing region: reflect about the last element (index size - 1).
  return 2 * (size - 1) - offset;
}
39
40 template <typename CTYPE, typename PaddingIx>
pad1d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)41 void pad1d(
42 const PaddingIx& padding_ix,
43 const Tensor& in,
44 Tensor& out,
45 exec_aten::ArrayRef<int64_t> padding) {
46 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48
49 const auto dim = in.dim() - 1;
50 const auto outer = getLeadingDims(out, dim);
51 const auto in_width = in.size(dim);
52 const auto out_width = out.size(dim);
53 const auto pad_left = padding[0];
54
55 for (size_t i = 0; i < outer; i++) {
56 size_t out_i_base = i * out_width;
57 size_t in_i_base = i * in_width;
58 for (size_t w = 0; w < out_width; w++) {
59 out_data[out_i_base + w] =
60 in_data[in_i_base + padding_ix(w, in_width, pad_left)];
61 }
62 }
63 }
64
65 template <typename CTYPE, typename PaddingIx>
pad2d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)66 void pad2d(
67 const PaddingIx& padding_ix,
68 const Tensor& in,
69 Tensor& out,
70 exec_aten::ArrayRef<int64_t> padding) {
71 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
72 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
73
74 const auto dim = in.dim() - 2;
75 const auto outer = getLeadingDims(out, dim);
76 const auto in_height = in.size(dim);
77 const auto in_width = in.size(dim + 1);
78 const auto out_height = out.size(dim);
79 const auto out_width = out.size(dim + 1);
80 const auto pad_left = padding[0];
81 const auto pad_top = padding[2];
82
83 for (size_t i = 0; i < outer; i++) {
84 size_t out_i_base = i * out_height * out_width;
85 size_t in_i_base = i * in_height * in_width;
86 for (size_t h = 0; h < out_height; h++) {
87 size_t out_h_base = out_i_base + h * out_width;
88 size_t in_h_base =
89 in_i_base + padding_ix(h, in_height, pad_top) * in_width;
90 for (size_t w = 0; w < out_width; w++) {
91 out_data[out_h_base + w] =
92 in_data[in_h_base + padding_ix(w, in_width, pad_left)];
93 }
94 }
95 }
96 }
97
98 template <typename CTYPE, typename PaddingIx>
pad3d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)99 void pad3d(
100 const PaddingIx& padding_ix,
101 const Tensor& in,
102 Tensor& out,
103 exec_aten::ArrayRef<int64_t> padding) {
104 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
105 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
106
107 const auto dim = in.dim() - 3;
108 const auto outer = getLeadingDims(out, dim);
109 const auto in_depth = in.size(dim);
110 const auto in_height = in.size(dim + 1);
111 const auto in_width = in.size(dim + 2);
112 const auto out_depth = out.size(dim);
113 const auto out_height = out.size(dim + 1);
114 const auto out_width = out.size(dim + 2);
115 const auto pad_left = padding[0];
116 const auto pad_top = padding[2];
117 const auto pad_front = padding[4];
118
119 for (size_t i = 0; i < outer; i++) {
120 size_t out_i_base = i * out_depth * out_height * out_width;
121 size_t in_i_base = i * in_depth * in_height * in_width;
122 for (size_t d = 0; d < out_depth; d++) {
123 size_t out_d_base = out_i_base + d * out_height * out_width;
124 size_t in_d_base =
125 in_i_base + padding_ix(d, in_depth, pad_front) * in_height * in_width;
126 for (size_t h = 0; h < out_height; h++) {
127 size_t out_h_base = out_d_base + h * out_width;
128 size_t in_h_base =
129 in_d_base + padding_ix(h, in_height, pad_top) * in_width;
130 for (size_t w = 0; w < out_width; w++) {
131 out_data[out_h_base + w] =
132 in_data[in_h_base + padding_ix(w, in_width, pad_left)];
133 }
134 }
135 }
136 }
137 }
138
139 } // namespace executor
140 } // namespace torch
141