xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuPermuteKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/CpuPermuteKernel.h"
25 
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Validate.h"
32 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
33 #include "src/core/helpers/AutoConfiguration.h"
34 #include "src/core/helpers/WindowHelpers.h"
35 
36 namespace
37 {
38 #include "src/core/NEON/kernels/convolution/common/shims.hpp"
39 } // namespace
40 
41 namespace arm_compute
42 {
43 namespace cpu
44 {
45 namespace kernels
46 {
47 namespace
48 {
is_permutation_supported(const PermutationVector & v)49 inline bool is_permutation_supported(const PermutationVector &v)
50 {
51     static const std::array<PermutationVector, 2> permutations2 =
52     {
53         {
54             PermutationVector(0U, 1U),
55             PermutationVector(1U, 0U),
56         }
57     };
58     static const std::array<PermutationVector, 6> permutations3 =
59     {
60         {
61             PermutationVector(2U, 0U, 1U),
62             PermutationVector(1U, 2U, 0U),
63             PermutationVector(0U, 1U, 2U),
64             PermutationVector(0U, 2U, 1U),
65             PermutationVector(1U, 0U, 2U),
66             PermutationVector(2U, 1U, 0U),
67         }
68     };
69     static const std::array<PermutationVector, 24> permutations4 =
70     {
71         {
72             PermutationVector(0U, 1U, 2U, 3U),
73             PermutationVector(1U, 0U, 2U, 3U),
74             PermutationVector(2U, 0U, 1U, 3U),
75             PermutationVector(0U, 2U, 1U, 3U),
76             PermutationVector(1U, 2U, 0U, 3U),
77             PermutationVector(2U, 1U, 0U, 3U),
78             PermutationVector(2U, 1U, 3U, 0U),
79             PermutationVector(1U, 2U, 3U, 0U),
80             PermutationVector(3U, 2U, 1U, 0U),
81             PermutationVector(2U, 3U, 1U, 0U),
82             PermutationVector(1U, 3U, 2U, 0U),
83             PermutationVector(3U, 1U, 2U, 0U),
84             PermutationVector(3U, 0U, 2U, 1U),
85             PermutationVector(0U, 3U, 2U, 1U),
86             PermutationVector(2U, 3U, 0U, 1U),
87             PermutationVector(3U, 2U, 0U, 1U),
88             PermutationVector(0U, 2U, 3U, 1U),
89             PermutationVector(2U, 0U, 3U, 1U),
90             PermutationVector(1U, 0U, 3U, 2U),
91             PermutationVector(0U, 1U, 3U, 2U),
92             PermutationVector(3U, 1U, 0U, 2U),
93             PermutationVector(1U, 3U, 0U, 2U),
94             PermutationVector(0U, 3U, 1U, 2U),
95             PermutationVector(3U, 0U, 1U, 2U)
96         }
97     };
98 
99     return (permutations2.end() != std::find(permutations2.begin(), permutations2.end(), v)) || (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v))
100            || (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v));
101 }
102 
validate_arguments(const ITensorInfo * src,const ITensorInfo * dst,const PermutationVector & perm)103 Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
104 {
105     ARM_COMPUTE_RETURN_ERROR_ON(src->data_type() == DataType::UNKNOWN);
106     ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_permutation_supported(perm), "PermutationVector not supported.");
107 
108     const TensorShape dst_shape = misc::shape_calculator::compute_permutation_output_shape(*src, perm);
109 
110     // Validate configured destination
111     if(dst->total_size() != 0)
112     {
113         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape);
114         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
115         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
116     }
117 
118     return Status{};
119 }
120 
121 template <typename T>
run_permute(const Window & window,const ITensor * src,const ITensor * dst,const PermutationVector & perm)122 void run_permute(const Window &window, const ITensor *src, const ITensor *dst, const PermutationVector &perm)
123 {
124     const DataLayout src_layout = src->info()->data_layout();
125 
126     // Source window
127     Window window_src = window;
128 
129     // we only support these two configs in src/core/NEON/kernels/convolution/common/shims.hpp, for all others
130     // we have to fall back to C++
131     if((src_layout == DataLayout::NCHW && perm == PermutationVector{ 2U, 0U, 1U }) || (src_layout == DataLayout::NHWC && perm == PermutationVector{ 1U, 2U, 0U }))
132     {
133         window_src.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
134         window_src.set(Window::DimY, Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start()));
135         window_src.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start()));
136         window_src.set(3, Window::Dimension(window[3].start(), window[3].end(), window[3].end() - window[3].start()));
137     }
138 
139     // Destination window
140     Window                  window_dst(window);
141     const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
142     for(size_t d = 0; d <= dst->info()->num_dimensions(); ++d)
143     {
144         window_dst.set(d, zero_window);
145     }
146 
147     // Create iterators
148     Iterator src_it(src, window_src);
149     Iterator dst_it(dst, window_dst);
150 
151     int in_row_stride     = 0;
152     int in_col_stride     = 0;
153     int in_channel_stride = 0;
154     int in_batch_stride   = 0;
155     int n_cols            = 0;
156     int n_rows            = 0;
157     int n_channels        = 0;
158     int n_batches         = 0;
159 
160     switch(src_layout)
161     {
162         case DataLayout::NCHW:
163         {
164             in_row_stride     = src->info()->strides_in_bytes().y() / sizeof(T);
165             in_channel_stride = src->info()->strides_in_bytes().z() / sizeof(T);
166             in_batch_stride   = src->info()->strides_in_bytes()[3] / sizeof(T);
167             n_cols            = src->info()->tensor_shape().x();
168             n_rows            = window_src.y().step();
169             n_channels        = src->info()->tensor_shape().z();
170             n_batches         = src->info()->tensor_shape()[3];
171             break;
172         }
173         case DataLayout::NHWC:
174         {
175             in_col_stride   = src->info()->strides_in_bytes().y() / sizeof(T);
176             in_row_stride   = src->info()->strides_in_bytes().z() / sizeof(T);
177             in_batch_stride = src->info()->strides_in_bytes()[3] / sizeof(T);
178             n_channels      = src->info()->tensor_shape().x();
179             n_cols          = window_src.y().step();
180             n_rows          = src->info()->tensor_shape().z();
181             n_batches       = src->info()->tensor_shape()[3];
182             break;
183         }
184         default:
185         {
186             ARM_COMPUTE_ERROR("Invalid source data layout.");
187             break;
188         }
189     }
190 
191     // CHW -> HWC
192     if(src_layout == DataLayout::NCHW && perm == PermutationVector{ 2U, 0U, 1U })
193     {
194         const int out_channel_stride = dst->info()->strides_in_bytes().x() / sizeof(T);
195         const int out_col_stride     = dst->info()->strides_in_bytes().y() / sizeof(T);
196         const int out_row_stride     = dst->info()->strides_in_bytes().z() / sizeof(T);
197         const int out_batch_stride   = dst->info()->strides_in_bytes()[3] / sizeof(T);
198         execute_window_loop(window_src, [&](const Coordinates & id)
199         {
200             const int idx = id[0] * out_col_stride + id[1] * out_row_stride + id[2] * out_channel_stride;
201             reorder::nchw_to_nhwc(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx,
202                                   n_batches, n_channels, n_rows, n_cols,
203                                   in_batch_stride, in_channel_stride, in_row_stride,
204                                   out_batch_stride, out_row_stride, out_col_stride);
205         },
206         src_it, dst_it);
207     }
208     // HWC -> CHW
209     else if(src_layout == DataLayout::NHWC && perm == PermutationVector{ 1U, 2U, 0U })
210     {
211         const int out_col_stride     = dst->info()->strides_in_bytes().x() / sizeof(T);
212         const int out_row_stride     = dst->info()->strides_in_bytes().y() / sizeof(T);
213         const int out_channel_stride = dst->info()->strides_in_bytes().z() / sizeof(T);
214         const int out_batch_stride   = dst->info()->strides_in_bytes()[3] / sizeof(T);
215         execute_window_loop(window_src, [&](const Coordinates & id)
216         {
217             const int idx = id[0] * out_channel_stride + id[1] * out_col_stride + id[2] * out_row_stride;
218             reorder::nhwc_to_nchw(reinterpret_cast<const T *>(src_it.ptr()), reinterpret_cast<T *>(dst_it.ptr()) + idx,
219                                   n_batches, n_rows, n_cols, n_channels,
220                                   in_batch_stride, in_row_stride, in_col_stride,
221                                   out_batch_stride, out_channel_stride, out_row_stride);
222         },
223         src_it, dst_it);
224     }
225     else
226     {
227         // All other cases fall back to C++
228         // Permute strides
229         Strides strides      = dst->info()->strides_in_bytes();
230         Strides perm_strides = strides;
231         permute_strides(perm_strides, perm);
232         const int perm_stride_3 = src->info()->num_dimensions() >= 4 ? perm_strides[3] : 0;
233         execute_window_loop(window, [&](const Coordinates & id)
234         {
235             const int idx                                = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_stride_3;
236             *(reinterpret_cast<T *>(dst_it.ptr() + idx)) = *(reinterpret_cast<const T *>(src_it.ptr()));
237         },
238         src_it, dst_it);
239     }
240 }
241 } // namespace
242 
configure(const ITensorInfo * src,ITensorInfo * dst,const PermutationVector & perm)243 void CpuPermuteKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const PermutationVector &perm)
244 {
245     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
246     const TensorShape dst_shape = misc::shape_calculator::compute_permutation_output_shape(*src, perm);
247     // Destination auto inizialitation if not yet initialized
248     auto_init_if_empty(*dst, src->clone()->set_tensor_shape(dst_shape));
249 
250     // Perform validation step
251     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, perm));
252 
253     _perm = perm;
254 
255     // Configure kernel window
256     Window win = calculate_max_window(*src, Steps());
257 
258     // This kernel doesn't need padding so update_window_and_padding() can be skipped
259 
260     ICpuKernel::configure(win);
261 }
262 
validate(const ITensorInfo * src,const ITensorInfo * dst,const PermutationVector & perm)263 Status CpuPermuteKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const PermutationVector &perm)
264 {
265     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, perm));
266     return Status{};
267 }
268 
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)269 void CpuPermuteKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
270 {
271     ARM_COMPUTE_UNUSED(info);
272     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
273     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
274 
275     const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
276     auto       dst = tensors.get_tensor(TensorType::ACL_DST);
277 
278     switch(src->info()->element_size())
279     {
280         case 1:
281             run_permute<uint8_t>(window, src, dst, _perm);
282             break;
283         case 2:
284             run_permute<uint16_t>(window, src, dst, _perm);
285             break;
286         case 4:
287             run_permute<uint32_t>(window, src, dst, _perm);
288             break;
289         default:
290             ARM_COMPUTE_ERROR("Element size not supported");
291             break;
292     }
293 }
294 
name() const295 const char *CpuPermuteKernel::name() const
296 {
297     return "CpuPermuteKernel";
298 }
299 } // namespace kernels
300 } // namespace cpu
301 } // namespace arm_compute
302