xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEFFTDigitReverseKernel.h"
25 
26 #include "arm_compute/core/ITensor.h"
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/Validate.h"
30 #include "arm_compute/core/Window.h"
31 #include "src/core/helpers/AutoConfiguration.h"
32 #include "src/core/helpers/WindowHelpers.h"
33 
34 #include <set>
35 
36 namespace arm_compute
37 {
38 namespace
39 {
validate_arguments(const ITensorInfo * input,const ITensorInfo * output,const ITensorInfo * idx,const FFTDigitReverseKernelInfo & config)40 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
41 {
42     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
43     ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() > 2);
44     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32);
45     ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({ 0, 1 }).count(config.axis) == 0);
46     ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] != idx->tensor_shape().x());
47 
48     // Checks performed when output is configured
49     if((output != nullptr) && (output->total_size() != 0))
50     {
51         ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() != 2);
52         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
53         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
54     }
55 
56     return Status{};
57 }
58 
validate_and_configure_window(ITensorInfo * input,ITensorInfo * output,ITensorInfo * idx,const FFTDigitReverseKernelInfo & config)59 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
60 {
61     ARM_COMPUTE_UNUSED(idx, config);
62 
63     auto_init_if_empty(*output, input->clone()->set_num_channels(2));
64 
65     Window win = calculate_max_window(*input, Steps());
66 
67     return std::make_pair(Status{}, win);
68 }
69 } // namespace
70 
NEFFTDigitReverseKernel()71 NEFFTDigitReverseKernel::NEFFTDigitReverseKernel()
72     : _func(nullptr), _input(nullptr), _output(nullptr), _idx(nullptr)
73 {
74 }
75 
configure(const ITensor * input,ITensor * output,const ITensor * idx,const FFTDigitReverseKernelInfo & config)76 void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, const ITensor *idx, const FFTDigitReverseKernelInfo &config)
77 {
78     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, idx);
79     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), idx->info(), config));
80 
81     _input  = input;
82     _output = output;
83     _idx    = idx;
84 
85     const size_t axis             = config.axis;
86     const bool   is_conj          = config.conjugate;
87     const bool   is_input_complex = (input->info()->num_channels() == 2);
88 
89     // Configure kernel window
90     auto win_config = validate_and_configure_window(input->info(), output->info(), idx->info(), config);
91     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
92     INEKernel::configure(win_config.second);
93 
94     if(axis == 0)
95     {
96         if(is_input_complex)
97         {
98             if(is_conj)
99             {
100                 _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0<true, true>;
101             }
102             else
103             {
104                 _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0<true, false>;
105             }
106         }
107         else
108         {
109             _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0<false, false>;
110         }
111     }
112     else if(axis == 1)
113     {
114         if(is_input_complex)
115         {
116             if(is_conj)
117             {
118                 _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1<true, true>;
119             }
120             else
121             {
122                 _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1<true, false>;
123             }
124         }
125         else
126         {
127             _func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1<false, false>;
128         }
129     }
130     else
131     {
132         ARM_COMPUTE_ERROR("Not supported");
133     }
134 }
135 
validate(const ITensorInfo * input,const ITensorInfo * output,const ITensorInfo * idx,const FFTDigitReverseKernelInfo & config)136 Status NEFFTDigitReverseKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
137 {
138     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, idx, config));
139     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), idx->clone().get(), config).first);
140     return Status{};
141 }
142 
143 template <bool is_input_complex, bool is_conj>
digit_reverse_kernel_axis_0(const Window & window)144 void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0(const Window &window)
145 {
146     const size_t N = _input->info()->dimension(0);
147 
148     // Copy the look-up buffer to a local array
149     std::vector<unsigned int> buffer_idx(N);
150     std::copy_n(reinterpret_cast<unsigned int *>(_idx->buffer()), N, buffer_idx.data());
151 
152     // Input/output iterators
153     Window slice = window;
154     slice.set(0, Window::DimX);
155     Iterator in(_input, slice);
156     Iterator out(_output, slice);
157 
158     // Row buffers
159     std::vector<float> buffer_row_out(2 * N);
160     std::vector<float> buffer_row_in(2 * N);
161 
162     execute_window_loop(slice, [&](const Coordinates &)
163     {
164         if(is_input_complex)
165         {
166             // Load
167             memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), 2 * N * sizeof(float));
168 
169             // Shuffle
170             for(size_t x = 0; x < 2 * N; x += 2)
171             {
172                 size_t idx            = buffer_idx[x / 2];
173                 buffer_row_out[x]     = buffer_row_in[2 * idx];
174                 buffer_row_out[x + 1] = (is_conj ? -buffer_row_in[2 * idx + 1] : buffer_row_in[2 * idx + 1]);
175             }
176         }
177         else
178         {
179             // Load
180             memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), N * sizeof(float));
181 
182             // Shuffle
183             for(size_t x = 0; x < N; ++x)
184             {
185                 size_t idx            = buffer_idx[x];
186                 buffer_row_out[2 * x] = buffer_row_in[idx];
187             }
188         }
189 
190         // Copy back
191         memcpy(reinterpret_cast<float *>(out.ptr()), buffer_row_out.data(), 2 * N * sizeof(float));
192     },
193     in, out);
194 }
195 
196 template <bool is_input_complex, bool is_conj>
digit_reverse_kernel_axis_1(const Window & window)197 void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1(const Window &window)
198 {
199     const size_t Nx = _input->info()->dimension(0);
200     const size_t Ny = _input->info()->dimension(1);
201 
202     // Copy the look-up buffer to a local array
203     std::vector<unsigned int> buffer_idx(Ny);
204     std::copy_n(reinterpret_cast<unsigned int *>(_idx->buffer()), Ny, buffer_idx.data());
205 
206     // Output iterator
207     Window slice = window;
208     slice.set(0, Window::DimX);
209     Iterator out(_output, slice);
210 
211     // Row buffer
212     std::vector<float> buffer_row(Nx);
213 
214     // Strides
215     const size_t stride_z = _input->info()->strides_in_bytes()[2];
216     const size_t stride_w = _input->info()->strides_in_bytes()[3];
217 
218     execute_window_loop(slice, [&](const Coordinates & id)
219     {
220         auto        *out_ptr    = reinterpret_cast<float *>(out.ptr());
221         auto        *in_ptr     = reinterpret_cast<float *>(_input->buffer() + id.z() * stride_z + id[3] * stride_w);
222         const size_t y_shuffled = buffer_idx[id.y()];
223 
224         if(is_input_complex)
225         {
226             // Shuffle the entire row into the output
227             memcpy(out_ptr, in_ptr + 2 * Nx * y_shuffled, 2 * Nx * sizeof(float));
228 
229             // Conjugate if necessary
230             if(is_conj)
231             {
232                 for(size_t x = 0; x < 2 * Nx; x += 2)
233                 {
234                     out_ptr[x + 1] = -out_ptr[x + 1];
235                 }
236             }
237         }
238         else
239         {
240             // Shuffle the entire row into the buffer
241             memcpy(buffer_row.data(), in_ptr + Nx * y_shuffled, Nx * sizeof(float));
242 
243             // Copy the buffer to the output, with a zero imaginary part
244             for(size_t x = 0; x < 2 * Nx; x += 2)
245             {
246                 out_ptr[x] = buffer_row[x / 2];
247             }
248         }
249     },
250     out);
251 }
252 
run(const Window & window,const ThreadInfo & info)253 void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info)
254 {
255     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
256     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
257     ARM_COMPUTE_UNUSED(info);
258     (this->*_func)(window);
259 }
260 
261 } // namespace arm_compute
262