xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/padding_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 
16 bool check_padding_args(
17     int64_t n,
18     const Tensor& in,
19     exec_aten::ArrayRef<int64_t> padding,
20     Tensor& out,
21     bool reflection = false);
22 
23 void get_padding_out_target_size(
24     int64_t n,
25     const Tensor& in,
26     exec_aten::ArrayRef<int64_t> padding,
27     Tensor::SizesType* out_sizes,
28     size_t* out_ndim);
29 
replication_ix(int64_t j,int64_t size,int64_t pad)30 inline int64_t replication_ix(int64_t j, int64_t size, int64_t pad) {
31   return j < pad ? 0 : j >= pad && j < size + pad ? j - pad : size - 1;
32 }
33 
reflection_ix(int64_t j,int64_t size,int64_t pad)34 inline int64_t reflection_ix(int64_t j, int64_t size, int64_t pad) {
35   return j < pad                   ? pad - j
36       : j >= pad && j < size + pad ? j - pad
37                                    : 2 * size + pad - j - 2;
38 }
39 
40 template <typename CTYPE, typename PaddingIx>
pad1d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)41 void pad1d(
42     const PaddingIx& padding_ix,
43     const Tensor& in,
44     Tensor& out,
45     exec_aten::ArrayRef<int64_t> padding) {
46   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47   CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48 
49   const auto dim = in.dim() - 1;
50   const auto outer = getLeadingDims(out, dim);
51   const auto in_width = in.size(dim);
52   const auto out_width = out.size(dim);
53   const auto pad_left = padding[0];
54 
55   for (size_t i = 0; i < outer; i++) {
56     size_t out_i_base = i * out_width;
57     size_t in_i_base = i * in_width;
58     for (size_t w = 0; w < out_width; w++) {
59       out_data[out_i_base + w] =
60           in_data[in_i_base + padding_ix(w, in_width, pad_left)];
61     }
62   }
63 }
64 
65 template <typename CTYPE, typename PaddingIx>
pad2d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)66 void pad2d(
67     const PaddingIx& padding_ix,
68     const Tensor& in,
69     Tensor& out,
70     exec_aten::ArrayRef<int64_t> padding) {
71   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
72   CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
73 
74   const auto dim = in.dim() - 2;
75   const auto outer = getLeadingDims(out, dim);
76   const auto in_height = in.size(dim);
77   const auto in_width = in.size(dim + 1);
78   const auto out_height = out.size(dim);
79   const auto out_width = out.size(dim + 1);
80   const auto pad_left = padding[0];
81   const auto pad_top = padding[2];
82 
83   for (size_t i = 0; i < outer; i++) {
84     size_t out_i_base = i * out_height * out_width;
85     size_t in_i_base = i * in_height * in_width;
86     for (size_t h = 0; h < out_height; h++) {
87       size_t out_h_base = out_i_base + h * out_width;
88       size_t in_h_base =
89           in_i_base + padding_ix(h, in_height, pad_top) * in_width;
90       for (size_t w = 0; w < out_width; w++) {
91         out_data[out_h_base + w] =
92             in_data[in_h_base + padding_ix(w, in_width, pad_left)];
93       }
94     }
95   }
96 }
97 
98 template <typename CTYPE, typename PaddingIx>
pad3d(const PaddingIx & padding_ix,const Tensor & in,Tensor & out,exec_aten::ArrayRef<int64_t> padding)99 void pad3d(
100     const PaddingIx& padding_ix,
101     const Tensor& in,
102     Tensor& out,
103     exec_aten::ArrayRef<int64_t> padding) {
104   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
105   CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
106 
107   const auto dim = in.dim() - 3;
108   const auto outer = getLeadingDims(out, dim);
109   const auto in_depth = in.size(dim);
110   const auto in_height = in.size(dim + 1);
111   const auto in_width = in.size(dim + 2);
112   const auto out_depth = out.size(dim);
113   const auto out_height = out.size(dim + 1);
114   const auto out_width = out.size(dim + 2);
115   const auto pad_left = padding[0];
116   const auto pad_top = padding[2];
117   const auto pad_front = padding[4];
118 
119   for (size_t i = 0; i < outer; i++) {
120     size_t out_i_base = i * out_depth * out_height * out_width;
121     size_t in_i_base = i * in_depth * in_height * in_width;
122     for (size_t d = 0; d < out_depth; d++) {
123       size_t out_d_base = out_i_base + d * out_height * out_width;
124       size_t in_d_base =
125           in_i_base + padding_ix(d, in_depth, pad_front) * in_height * in_width;
126       for (size_t h = 0; h < out_height; h++) {
127         size_t out_h_base = out_d_base + h * out_width;
128         size_t in_h_base =
129             in_d_base + padding_ix(h, in_height, pad_top) * in_width;
130         for (size_t w = 0; w < out_width; w++) {
131           out_data[out_h_base + w] =
132               in_data[in_h_base + padding_ix(w, in_width, pad_left)];
133         }
134       }
135     }
136   }
137 }
138 
139 } // namespace executor
140 } // namespace torch
141