/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24*c217d954SCole Faust #include "src/cpu/kernels/CpuTransposeKernel.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "arm_compute/core/Error.h"
27*c217d954SCole Faust #include "arm_compute/core/Helpers.h"
28*c217d954SCole Faust #include "arm_compute/core/ITensor.h"
29*c217d954SCole Faust #include "arm_compute/core/TensorInfo.h"
30*c217d954SCole Faust #include "arm_compute/core/Types.h"
31*c217d954SCole Faust #include "arm_compute/core/Validate.h"
32*c217d954SCole Faust #include "arm_compute/core/utils/misc/ShapeCalculator.h"
33*c217d954SCole Faust #include "src/core/helpers/AutoConfiguration.h"
34*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
35*c217d954SCole Faust
36*c217d954SCole Faust #include <arm_neon.h>
37*c217d954SCole Faust
38*c217d954SCole Faust namespace arm_compute
39*c217d954SCole Faust {
40*c217d954SCole Faust namespace cpu
41*c217d954SCole Faust {
42*c217d954SCole Faust namespace kernels
43*c217d954SCole Faust {
44*c217d954SCole Faust namespace
45*c217d954SCole Faust {
num_elems_processed(size_t element_size)46*c217d954SCole Faust unsigned int num_elems_processed(size_t element_size)
47*c217d954SCole Faust {
48*c217d954SCole Faust switch(element_size)
49*c217d954SCole Faust {
50*c217d954SCole Faust case 1:
51*c217d954SCole Faust return 8;
52*c217d954SCole Faust case 2:
53*c217d954SCole Faust case 4:
54*c217d954SCole Faust return 4;
55*c217d954SCole Faust default:
56*c217d954SCole Faust break;
57*c217d954SCole Faust }
58*c217d954SCole Faust
59*c217d954SCole Faust ARM_COMPUTE_ERROR("Element size not supported");
60*c217d954SCole Faust }
61*c217d954SCole Faust
transpose_8bit_elements(const ITensor * in,ITensor * out,const Window & window)62*c217d954SCole Faust void transpose_8bit_elements(const ITensor *in, ITensor *out, const Window &window)
63*c217d954SCole Faust {
64*c217d954SCole Faust const int window_step_x = 8;
65*c217d954SCole Faust const int window_step_y = 8;
66*c217d954SCole Faust const int window_start_x = window.x().start();
67*c217d954SCole Faust const int window_end_x = window.x().end();
68*c217d954SCole Faust const int window_start_y = window.y().start();
69*c217d954SCole Faust const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
70*c217d954SCole Faust const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
71*c217d954SCole Faust const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
72*c217d954SCole Faust const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
73*c217d954SCole Faust
74*c217d954SCole Faust // Check if we need a left-over loop for the y dimension
75*c217d954SCole Faust bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
76*c217d954SCole Faust
77*c217d954SCole Faust Window window_in(window);
78*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
79*c217d954SCole Faust if(left_over_loop_y)
80*c217d954SCole Faust {
81*c217d954SCole Faust // Check if window_end_y_multiple_of is greater than window_start_y
82*c217d954SCole Faust if(window_end_y_multiple_of > window_start_y)
83*c217d954SCole Faust {
84*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
85*c217d954SCole Faust }
86*c217d954SCole Faust else
87*c217d954SCole Faust {
88*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
89*c217d954SCole Faust }
90*c217d954SCole Faust }
91*c217d954SCole Faust
92*c217d954SCole Faust Window window_out(window);
93*c217d954SCole Faust window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
94*c217d954SCole Faust window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
95*c217d954SCole Faust
96*c217d954SCole Faust Iterator output(out, window_out);
97*c217d954SCole Faust
98*c217d954SCole Faust // Run the SIMD path if and only if the input is not a row-vector
99*c217d954SCole Faust if(in->info()->dimension(1) != 1)
100*c217d954SCole Faust {
101*c217d954SCole Faust Iterator input(in, window_in);
102*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
103*c217d954SCole Faust {
104*c217d954SCole Faust // Compute 8x8 elements per iteration
105*c217d954SCole Faust int x = window_start_x;
106*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
107*c217d954SCole Faust {
108*c217d954SCole Faust const uint8x8_t row0 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 0 * input_stride_in_bytes));
109*c217d954SCole Faust const uint8x8_t row1 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 1 * input_stride_in_bytes));
110*c217d954SCole Faust const uint8x8_t row2 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 2 * input_stride_in_bytes));
111*c217d954SCole Faust const uint8x8_t row3 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 3 * input_stride_in_bytes));
112*c217d954SCole Faust const uint8x8_t row4 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 4 * input_stride_in_bytes));
113*c217d954SCole Faust const uint8x8_t row5 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 5 * input_stride_in_bytes));
114*c217d954SCole Faust const uint8x8_t row6 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 6 * input_stride_in_bytes));
115*c217d954SCole Faust const uint8x8_t row7 = vld1_u8(reinterpret_cast<const uint8_t *>(input.ptr() + x + 7 * input_stride_in_bytes));
116*c217d954SCole Faust
117*c217d954SCole Faust // Transpose 2x2
118*c217d954SCole Faust const uint8x8x2_t k0_u8 = vtrn_u8(row0, row1);
119*c217d954SCole Faust const uint8x8x2_t k1_u8 = vtrn_u8(row2, row3);
120*c217d954SCole Faust const uint8x8x2_t k2_u8 = vtrn_u8(row4, row5);
121*c217d954SCole Faust const uint8x8x2_t k3_u8 = vtrn_u8(row6, row7);
122*c217d954SCole Faust
123*c217d954SCole Faust // Transpose 4x4
124*c217d954SCole Faust const uint16x4x2_t k0_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[0]), vreinterpret_u16_u8(k1_u8.val[0]));
125*c217d954SCole Faust const uint16x4x2_t k1_u16 = vtrn_u16(vreinterpret_u16_u8(k0_u8.val[1]), vreinterpret_u16_u8(k1_u8.val[1]));
126*c217d954SCole Faust const uint16x4x2_t k2_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[0]), vreinterpret_u16_u8(k3_u8.val[0]));
127*c217d954SCole Faust const uint16x4x2_t k3_u16 = vtrn_u16(vreinterpret_u16_u8(k2_u8.val[1]), vreinterpret_u16_u8(k3_u8.val[1]));
128*c217d954SCole Faust
129*c217d954SCole Faust // Transpose 8x8
130*c217d954SCole Faust const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k2_u16.val[0]));
131*c217d954SCole Faust const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k2_u16.val[1]));
132*c217d954SCole Faust const uint32x2x2_t k2_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[0]), vreinterpret_u32_u16(k3_u16.val[0]));
133*c217d954SCole Faust const uint32x2x2_t k3_u32 = vtrn_u32(vreinterpret_u32_u16(k1_u16.val[1]), vreinterpret_u32_u16(k3_u16.val[1]));
134*c217d954SCole Faust
135*c217d954SCole Faust // Compute destination address
136*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
137*c217d954SCole Faust
138*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[0])));
139*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[0])));
140*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[0])));
141*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[0])));
142*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 4 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k0_u32.val[1])));
143*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 5 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k2_u32.val[1])));
144*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 6 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k1_u32.val[1])));
145*c217d954SCole Faust vst1_u8(reinterpret_cast<uint8_t *>(output.ptr() + dst_offset_in_bytes + 7 * output_stride_in_bytes), vreinterpret_u8_u16(vreinterpret_u16_u32(k3_u32.val[1])));
146*c217d954SCole Faust }
147*c217d954SCole Faust
148*c217d954SCole Faust // Compute left-over elements along the x dimension (1x8)
149*c217d954SCole Faust for(; x < window_end_x; ++x)
150*c217d954SCole Faust {
151*c217d954SCole Faust const uint8_t val0 = *(input.ptr() + x + 0 * input_stride_in_bytes);
152*c217d954SCole Faust const uint8_t val1 = *(input.ptr() + x + 1 * input_stride_in_bytes);
153*c217d954SCole Faust const uint8_t val2 = *(input.ptr() + x + 2 * input_stride_in_bytes);
154*c217d954SCole Faust const uint8_t val3 = *(input.ptr() + x + 3 * input_stride_in_bytes);
155*c217d954SCole Faust const uint8_t val4 = *(input.ptr() + x + 4 * input_stride_in_bytes);
156*c217d954SCole Faust const uint8_t val5 = *(input.ptr() + x + 5 * input_stride_in_bytes);
157*c217d954SCole Faust const uint8_t val6 = *(input.ptr() + x + 6 * input_stride_in_bytes);
158*c217d954SCole Faust const uint8_t val7 = *(input.ptr() + x + 7 * input_stride_in_bytes);
159*c217d954SCole Faust
160*c217d954SCole Faust uint8x8_t result = vdup_n_u8(0);
161*c217d954SCole Faust result = vset_lane_u8(val0, result, 0);
162*c217d954SCole Faust result = vset_lane_u8(val1, result, 1);
163*c217d954SCole Faust result = vset_lane_u8(val2, result, 2);
164*c217d954SCole Faust result = vset_lane_u8(val3, result, 3);
165*c217d954SCole Faust result = vset_lane_u8(val4, result, 4);
166*c217d954SCole Faust result = vset_lane_u8(val5, result, 5);
167*c217d954SCole Faust result = vset_lane_u8(val6, result, 6);
168*c217d954SCole Faust result = vset_lane_u8(val7, result, 7);
169*c217d954SCole Faust
170*c217d954SCole Faust // Compute destination address
171*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + x * output_stride_in_bytes;
172*c217d954SCole Faust
173*c217d954SCole Faust vst1_u8(output.ptr() + dst_offset_in_bytes, result);
174*c217d954SCole Faust }
175*c217d954SCole Faust },
176*c217d954SCole Faust input, output);
177*c217d954SCole Faust }
178*c217d954SCole Faust
179*c217d954SCole Faust if(left_over_loop_y)
180*c217d954SCole Faust {
181*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
182*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
183*c217d954SCole Faust
184*c217d954SCole Faust Iterator input(in, window_in);
185*c217d954SCole Faust Iterator output(out, window_out);
186*c217d954SCole Faust
187*c217d954SCole Faust // Compute left-over elements along the y dimension (1x1)
188*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
189*c217d954SCole Faust {
190*c217d954SCole Faust const uint8_t val0 = *input.ptr();
191*c217d954SCole Faust
192*c217d954SCole Faust // Compute destination address
193*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint8_t) + id.x() * output_stride_in_bytes;
194*c217d954SCole Faust
195*c217d954SCole Faust *(output.ptr() + dst_offset_in_bytes) = val0;
196*c217d954SCole Faust },
197*c217d954SCole Faust input, output);
198*c217d954SCole Faust }
199*c217d954SCole Faust }
200*c217d954SCole Faust
transpose_16bit_elements(const ITensor * in,ITensor * out,const Window & window)201*c217d954SCole Faust void transpose_16bit_elements(const ITensor *in, ITensor *out, const Window &window)
202*c217d954SCole Faust {
203*c217d954SCole Faust const int window_step_x = 4;
204*c217d954SCole Faust const int window_step_y = 4;
205*c217d954SCole Faust const int window_start_x = window.x().start();
206*c217d954SCole Faust const int window_end_x = window.x().end();
207*c217d954SCole Faust const int window_start_y = window.y().start();
208*c217d954SCole Faust const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
209*c217d954SCole Faust const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
210*c217d954SCole Faust const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
211*c217d954SCole Faust const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
212*c217d954SCole Faust
213*c217d954SCole Faust // Check if we need a left-over loop for the y dimension
214*c217d954SCole Faust bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
215*c217d954SCole Faust
216*c217d954SCole Faust Window window_in(window);
217*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
218*c217d954SCole Faust if(left_over_loop_y)
219*c217d954SCole Faust {
220*c217d954SCole Faust // Check if window_end_y_multiple_of is greater than window_start_y
221*c217d954SCole Faust if(window_end_y_multiple_of > window_start_y)
222*c217d954SCole Faust {
223*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
224*c217d954SCole Faust }
225*c217d954SCole Faust else
226*c217d954SCole Faust {
227*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
228*c217d954SCole Faust }
229*c217d954SCole Faust }
230*c217d954SCole Faust
231*c217d954SCole Faust Window window_out(window);
232*c217d954SCole Faust window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
233*c217d954SCole Faust window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
234*c217d954SCole Faust
235*c217d954SCole Faust Iterator output(out, window_out);
236*c217d954SCole Faust
237*c217d954SCole Faust // Run the SIMD path if and only if the input is not a row-vector
238*c217d954SCole Faust if(in->info()->dimension(1) != 1)
239*c217d954SCole Faust {
240*c217d954SCole Faust Iterator input(in, window_in);
241*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
242*c217d954SCole Faust {
243*c217d954SCole Faust // Compute 4x4 elements per iteration
244*c217d954SCole Faust int x = window_start_x;
245*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
246*c217d954SCole Faust {
247*c217d954SCole Faust const uint16x4_t row0 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
248*c217d954SCole Faust const uint16x4_t row1 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
249*c217d954SCole Faust const uint16x4_t row2 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
250*c217d954SCole Faust const uint16x4_t row3 = vld1_u16(reinterpret_cast<const uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
251*c217d954SCole Faust
252*c217d954SCole Faust // Transpose 2x2
253*c217d954SCole Faust const uint16x4x2_t k0_u16 = vtrn_u16(row0, row1);
254*c217d954SCole Faust const uint16x4x2_t k1_u16 = vtrn_u16(row2, row3);
255*c217d954SCole Faust
256*c217d954SCole Faust // Transpose 4x4
257*c217d954SCole Faust const uint32x2x2_t k0_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[0]), vreinterpret_u32_u16(k1_u16.val[0]));
258*c217d954SCole Faust const uint32x2x2_t k1_u32 = vtrn_u32(vreinterpret_u32_u16(k0_u16.val[1]), vreinterpret_u32_u16(k1_u16.val[1]));
259*c217d954SCole Faust
260*c217d954SCole Faust // Compute destination address
261*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
262*c217d954SCole Faust
263*c217d954SCole Faust vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[0]));
264*c217d954SCole Faust vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[0]));
265*c217d954SCole Faust vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vreinterpret_u16_u32(k0_u32.val[1]));
266*c217d954SCole Faust vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vreinterpret_u16_u32(k1_u32.val[1]));
267*c217d954SCole Faust }
268*c217d954SCole Faust
269*c217d954SCole Faust // Compute left-over elements (1x4)
270*c217d954SCole Faust for(; x < window_end_x; ++x)
271*c217d954SCole Faust {
272*c217d954SCole Faust const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
273*c217d954SCole Faust const uint16_t val1 = *(reinterpret_cast<uint16_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
274*c217d954SCole Faust const uint16_t val2 = *(reinterpret_cast<uint16_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
275*c217d954SCole Faust const uint16_t val3 = *(reinterpret_cast<uint16_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
276*c217d954SCole Faust
277*c217d954SCole Faust uint16x4_t result = vdup_n_u16(0);
278*c217d954SCole Faust result = vset_lane_u16(val0, result, 0);
279*c217d954SCole Faust result = vset_lane_u16(val1, result, 1);
280*c217d954SCole Faust result = vset_lane_u16(val2, result, 2);
281*c217d954SCole Faust result = vset_lane_u16(val3, result, 3);
282*c217d954SCole Faust
283*c217d954SCole Faust // Compute destination address
284*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + x * output_stride_in_bytes;
285*c217d954SCole Faust
286*c217d954SCole Faust vst1_u16(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes), result);
287*c217d954SCole Faust }
288*c217d954SCole Faust },
289*c217d954SCole Faust input, output);
290*c217d954SCole Faust }
291*c217d954SCole Faust
292*c217d954SCole Faust if(left_over_loop_y)
293*c217d954SCole Faust {
294*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
295*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
296*c217d954SCole Faust
297*c217d954SCole Faust Iterator input(in, window_in);
298*c217d954SCole Faust Iterator output(out, window_out);
299*c217d954SCole Faust
300*c217d954SCole Faust // Compute left-over elements along the y dimension (1x1)
301*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
302*c217d954SCole Faust {
303*c217d954SCole Faust const uint16_t val0 = *(reinterpret_cast<uint16_t *>(input.ptr()));
304*c217d954SCole Faust
305*c217d954SCole Faust // Compute destination address
306*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint16_t) + id.x() * output_stride_in_bytes;
307*c217d954SCole Faust
308*c217d954SCole Faust *(reinterpret_cast<uint16_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
309*c217d954SCole Faust },
310*c217d954SCole Faust input, output);
311*c217d954SCole Faust }
312*c217d954SCole Faust }
313*c217d954SCole Faust
transpose_32bit_elements(const ITensor * in,ITensor * out,const Window & window)314*c217d954SCole Faust void transpose_32bit_elements(const ITensor *in, ITensor *out, const Window &window)
315*c217d954SCole Faust {
316*c217d954SCole Faust const int window_step_x = 4;
317*c217d954SCole Faust const int window_step_y = 4;
318*c217d954SCole Faust const int window_start_x = window.x().start();
319*c217d954SCole Faust const int window_end_x = window.x().end();
320*c217d954SCole Faust const int window_start_y = window.y().start();
321*c217d954SCole Faust const int window_end_y = std::min(window.y().end(), static_cast<int>(in->info()->dimension(1)));
322*c217d954SCole Faust const int window_end_y_multiple_of = ((window_end_y - window_start_y) / window_step_y) * window_step_y;
323*c217d954SCole Faust const size_t input_stride_in_bytes = in->info()->strides_in_bytes()[1];
324*c217d954SCole Faust const size_t output_stride_in_bytes = out->info()->strides_in_bytes()[1];
325*c217d954SCole Faust
326*c217d954SCole Faust // Check if we need a left-over loop for the y dimension
327*c217d954SCole Faust bool left_over_loop_y = (((window_end_y - window_start_y) % window_step_y) != 0);
328*c217d954SCole Faust
329*c217d954SCole Faust Window window_in(window);
330*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(0, 1, 1));
331*c217d954SCole Faust if(left_over_loop_y)
332*c217d954SCole Faust {
333*c217d954SCole Faust // Check if window_end_y_multiple_of is greater than window_start_y
334*c217d954SCole Faust if(window_end_y_multiple_of > window_start_y)
335*c217d954SCole Faust {
336*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_start_y, window_end_y_multiple_of, window_step_y));
337*c217d954SCole Faust }
338*c217d954SCole Faust else
339*c217d954SCole Faust {
340*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(0, 0, 1));
341*c217d954SCole Faust }
342*c217d954SCole Faust }
343*c217d954SCole Faust
344*c217d954SCole Faust Window window_out(window);
345*c217d954SCole Faust window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
346*c217d954SCole Faust window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
347*c217d954SCole Faust
348*c217d954SCole Faust Iterator output(out, window_out);
349*c217d954SCole Faust
350*c217d954SCole Faust // Run the SIMD path if and only if the input is not a row-vector
351*c217d954SCole Faust if(in->info()->dimension(1) != 1)
352*c217d954SCole Faust {
353*c217d954SCole Faust Iterator input(in, window_in);
354*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
355*c217d954SCole Faust {
356*c217d954SCole Faust // Compute 4x4 elements per iteration
357*c217d954SCole Faust int x = window_start_x;
358*c217d954SCole Faust for(; x <= (window_end_x - window_step_x); x += window_step_x)
359*c217d954SCole Faust {
360*c217d954SCole Faust const uint32x4_t row0 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
361*c217d954SCole Faust const uint32x4_t row1 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
362*c217d954SCole Faust const uint32x4_t row2 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
363*c217d954SCole Faust const uint32x4_t row3 = vld1q_u32(reinterpret_cast<const uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
364*c217d954SCole Faust
365*c217d954SCole Faust // Transpose 2x2
366*c217d954SCole Faust const uint32x2x2_t k0_u32 = vtrn_u32(vget_low_u32(row0), vget_low_u32(row1));
367*c217d954SCole Faust const uint32x2x2_t k1_u32 = vtrn_u32(vget_high_u32(row2), vget_high_u32(row3));
368*c217d954SCole Faust const uint32x2x2_t k2_u32 = vtrn_u32(vget_high_u32(row0), vget_high_u32(row1));
369*c217d954SCole Faust const uint32x2x2_t k3_u32 = vtrn_u32(vget_low_u32(row2), vget_low_u32(row3));
370*c217d954SCole Faust
371*c217d954SCole Faust // Compute destination address
372*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
373*c217d954SCole Faust
374*c217d954SCole Faust // Swap block 01 with block 10 and store
375*c217d954SCole Faust vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 0 * output_stride_in_bytes), vcombine_u32(k0_u32.val[0], k3_u32.val[0]));
376*c217d954SCole Faust vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 1 * output_stride_in_bytes), vcombine_u32(k0_u32.val[1], k3_u32.val[1]));
377*c217d954SCole Faust vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 2 * output_stride_in_bytes), vcombine_u32(k2_u32.val[0], k1_u32.val[0]));
378*c217d954SCole Faust vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes + 3 * output_stride_in_bytes), vcombine_u32(k2_u32.val[1], k1_u32.val[1]));
379*c217d954SCole Faust }
380*c217d954SCole Faust
381*c217d954SCole Faust // Compute left-over elements (1x4)
382*c217d954SCole Faust for(; x < window_end_x; ++x)
383*c217d954SCole Faust {
384*c217d954SCole Faust const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr() + 0 * input_stride_in_bytes) + x);
385*c217d954SCole Faust const uint32_t val1 = *(reinterpret_cast<uint32_t *>(input.ptr() + 1 * input_stride_in_bytes) + x);
386*c217d954SCole Faust const uint32_t val2 = *(reinterpret_cast<uint32_t *>(input.ptr() + 2 * input_stride_in_bytes) + x);
387*c217d954SCole Faust const uint32_t val3 = *(reinterpret_cast<uint32_t *>(input.ptr() + 3 * input_stride_in_bytes) + x);
388*c217d954SCole Faust
389*c217d954SCole Faust uint32x4_t result = vdupq_n_u32(0);
390*c217d954SCole Faust result = vsetq_lane_u32(val0, result, 0);
391*c217d954SCole Faust result = vsetq_lane_u32(val1, result, 1);
392*c217d954SCole Faust result = vsetq_lane_u32(val2, result, 2);
393*c217d954SCole Faust result = vsetq_lane_u32(val3, result, 3);
394*c217d954SCole Faust
395*c217d954SCole Faust // Compute destination address
396*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + x * output_stride_in_bytes;
397*c217d954SCole Faust
398*c217d954SCole Faust vst1q_u32(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes), result);
399*c217d954SCole Faust }
400*c217d954SCole Faust },
401*c217d954SCole Faust input, output);
402*c217d954SCole Faust }
403*c217d954SCole Faust
404*c217d954SCole Faust if(left_over_loop_y)
405*c217d954SCole Faust {
406*c217d954SCole Faust window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), 1));
407*c217d954SCole Faust window_in.set(Window::DimY, Window::Dimension(window_end_y_multiple_of, window_end_y, 1));
408*c217d954SCole Faust
409*c217d954SCole Faust Iterator input(in, window_in);
410*c217d954SCole Faust Iterator output(out, window_out);
411*c217d954SCole Faust
412*c217d954SCole Faust // Compute left-over elements along the y dimension (1x1)
413*c217d954SCole Faust execute_window_loop(window_in, [&](const Coordinates & id)
414*c217d954SCole Faust {
415*c217d954SCole Faust const uint32_t val0 = *(reinterpret_cast<uint32_t *>(input.ptr()));
416*c217d954SCole Faust
417*c217d954SCole Faust // Compute destination address
418*c217d954SCole Faust const size_t dst_offset_in_bytes = id.y() * sizeof(uint32_t) + id.x() * output_stride_in_bytes;
419*c217d954SCole Faust
420*c217d954SCole Faust *(reinterpret_cast<uint32_t *>(output.ptr() + dst_offset_in_bytes)) = val0;
421*c217d954SCole Faust },
422*c217d954SCole Faust input, output);
423*c217d954SCole Faust }
424*c217d954SCole Faust }
425*c217d954SCole Faust } // namespace
426*c217d954SCole Faust
configure(const ITensorInfo * src,ITensorInfo * dst)427*c217d954SCole Faust void CpuTransposeKernel::configure(const ITensorInfo *src, ITensorInfo *dst)
428*c217d954SCole Faust {
429*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
430*c217d954SCole Faust
431*c217d954SCole Faust // Destination auto inizialitation if not yet initialized
432*c217d954SCole Faust const TensorShape dst_shape = misc::shape_calculator::compute_transposed_shape(*src);
433*c217d954SCole Faust auto_init_if_empty(*dst, src->clone()->set_tensor_shape(dst_shape));
434*c217d954SCole Faust
435*c217d954SCole Faust // Perform validation step
436*c217d954SCole Faust ARM_COMPUTE_ERROR_THROW_ON(validate(src, dst));
437*c217d954SCole Faust
438*c217d954SCole Faust // Note: This kernel performs 16 elements per iteration.
439*c217d954SCole Faust // However, since we use a left-over for loop on both dimensions (X and Y), we cannot have any read or write out of memory
440*c217d954SCole Faust // For this reason num_elems_processed_per_iteration_x is set to 1
441*c217d954SCole Faust const unsigned int num_elems_processed_per_iteration_x = 1;
442*c217d954SCole Faust const unsigned int num_elems_processed_per_iteration_y = num_elems_processed(src->element_size());
443*c217d954SCole Faust
444*c217d954SCole Faust // Configure kernel window
445*c217d954SCole Faust Window win = calculate_max_window(*src, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
446*c217d954SCole Faust
447*c217d954SCole Faust // The CpuTranspose doesn't need padding so update_window_and_padding() can be skipped
448*c217d954SCole Faust Coordinates coord;
449*c217d954SCole Faust coord.set_num_dimensions(dst->num_dimensions());
450*c217d954SCole Faust dst->set_valid_region(ValidRegion(coord, dst->tensor_shape()));
451*c217d954SCole Faust
452*c217d954SCole Faust ICpuKernel::configure(win);
453*c217d954SCole Faust }
454*c217d954SCole Faust
validate(const ITensorInfo * src,const ITensorInfo * dst)455*c217d954SCole Faust Status CpuTransposeKernel::validate(const ITensorInfo *src, const ITensorInfo *dst)
456*c217d954SCole Faust {
457*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src);
458*c217d954SCole Faust //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use CPU FP16 instructions.
459*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON(src->data_type() == DataType::UNKNOWN);
460*c217d954SCole Faust
461*c217d954SCole Faust // Error if input is not 8 bit, 16bit or 32bit
462*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->element_size() != 1 && src->element_size() != 2 && src->element_size() != 4,
463*c217d954SCole Faust "Element size not supported");
464*c217d954SCole Faust
465*c217d954SCole Faust // Validate configured destination
466*c217d954SCole Faust if(dst->total_size() != 0)
467*c217d954SCole Faust {
468*c217d954SCole Faust const TensorShape dst_shape = misc::shape_calculator::compute_transposed_shape(*src);
469*c217d954SCole Faust
470*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape);
471*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
472*c217d954SCole Faust ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
473*c217d954SCole Faust }
474*c217d954SCole Faust
475*c217d954SCole Faust return Status{};
476*c217d954SCole Faust }
477*c217d954SCole Faust
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)478*c217d954SCole Faust void CpuTransposeKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
479*c217d954SCole Faust {
480*c217d954SCole Faust ARM_COMPUTE_UNUSED(info);
481*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
482*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
483*c217d954SCole Faust
484*c217d954SCole Faust const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
485*c217d954SCole Faust auto dst = tensors.get_tensor(TensorType::ACL_DST);
486*c217d954SCole Faust
487*c217d954SCole Faust switch(src->info()->element_size())
488*c217d954SCole Faust {
489*c217d954SCole Faust case 1:
490*c217d954SCole Faust transpose_8bit_elements(src, dst, window);
491*c217d954SCole Faust break;
492*c217d954SCole Faust case 2:
493*c217d954SCole Faust transpose_16bit_elements(src, dst, window);
494*c217d954SCole Faust break;
495*c217d954SCole Faust case 4:
496*c217d954SCole Faust transpose_32bit_elements(src, dst, window);
497*c217d954SCole Faust break;
498*c217d954SCole Faust default:
499*c217d954SCole Faust ARM_COMPUTE_ERROR("Element size not supported");
500*c217d954SCole Faust break;
501*c217d954SCole Faust }
502*c217d954SCole Faust }
503*c217d954SCole Faust
name() const504*c217d954SCole Faust const char *CpuTransposeKernel::name() const
505*c217d954SCole Faust {
506*c217d954SCole Faust return "CpuTransposeKernel";
507*c217d954SCole Faust }
508*c217d954SCole Faust } // namespace kernels
509*c217d954SCole Faust } // namespace cpu
510*c217d954SCole Faust } // namespace arm_compute
511