xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/NEReductionOperationKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2023 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEReductionOperationKernel.h"
25 
26 #include "arm_compute/core/Coordinates.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Utils.h"
31 #include "arm_compute/core/Validate.h"
32 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
33 #include "src/core/CPP/Validate.h"
34 #include "src/core/NEON/INEKernel.h"
35 #include "src/core/NEON/NEMath.h"
36 #include "src/core/helpers/AutoConfiguration.h"
37 #include "src/core/helpers/WindowHelpers.h"
38 #include "support/SaturateCast.h"
39 
40 #include "src/core/NEON/wrapper/wrapper.h"
41 #include <arm_neon.h>
42 
43 namespace arm_compute
44 {
45 namespace
46 {
// Helper that saturating-narrows two int16x8 accumulators (vqmovun for uint8_t,
// vqmovn otherwise), combines them into one 128-bit vector and stores the result;
// allows templating of RedOpYZW_quantized.
template <typename T>
void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0)
{
    if(!std::is_same<T, uint8_t>::value)
    {
        // Signed destination: narrow with signed saturation
        const auto narrowed = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2));
        wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), narrowed);
    }
    else
    {
        // Unsigned destination: narrow with signed-to-unsigned saturation
        const auto narrowed = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2));
        wrapper::vstore(output.ptr() + offset, narrowed);
    }
}
62 
// Updates the per-lane winning index for argmin/argmax: lanes where the incoming
// value `a` replaced the previous best `b` take the current element index, the
// rest keep the previously recorded index from `c`.
template <typename T>
uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    // For ARG_IDX_MIN `a` holds min(a, b); b > a marks lanes where `a` won (and vice versa for max)
    const uint32x4_t mask = (op == ReductionOperation::ARG_IDX_MIN) ? wrapper::vcgt(b, a) : wrapper::vclt(b, a);

    // Along axis 0 each lane maps to a distinct element index; otherwise every lane shares idx
    uint32x4_t vec_idx = { idx, idx + 1, idx + 2, idx + 3 };
    if(axis != 0)
    {
        vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }

    // Only val[0] is meaningful for 32-bit element types; the remaining entries stay zero
    uint32x4x4_t selected = { { wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0 } };
    return selected;
}
85 
// Quantized (8-bit) variant of calculate_index: updates the per-lane winning index
// for argmin/argmax over 16 lanes. `a` is the candidate min/max vector, `b` the
// previous min/max vector, `c` the previously recorded indices (4 x uint32x4).
template <typename T>
uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x4_t mask{ { 0 } };
    uint8x16_t   mask_u8{ 0 };
    // For ARG_IDX_MIN `a` holds min(a, b); b > a marks lanes where `a` won (and vice versa for max)
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u8 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u8 = wrapper::vclt(b, a);
    }
    // Widen the 8-bit lane mask (0x00/0xFF) into four 32-bit lane masks (0x0/0xFFFFFFFF):
    // OR-ing the left-shifted widened value with the unshifted one replicates the
    // narrow all-ones pattern across the whole wider lane.
    auto wide_u16_1 = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2 = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    mask.val[0]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    mask.val[1]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    mask.val[2]     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    mask.val[3]     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));

    // Candidate indices: consecutive idx..idx+15 along axis 0, broadcast idx otherwise
    uint32x4x4_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
            { idx + 4, idx + 5, idx + 6, idx + 7 },
            { idx + 8, idx + 9, idx + 10, idx + 11 },
            { idx + 12, idx + 13, idx + 14, idx + 15 }
        }
    };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    // Per lane: take the new index where the mask is set, else keep the previous one
    uint32x4x4_t res =
    {
        {
            vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]),
            vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]),
            vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]),
            vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])
        }
    };

    return res;
}
131 
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_min(T in)
{
    // Fold the two halves together, then fold the remaining pair
    const auto halves_min = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    const auto folded     = wrapper::vpmin(halves_min, halves_min);
    return folded;
}
141 
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_min(T in)
{
    // 16 lanes -> 8, then halve three more times until every lane holds the min
    auto folded = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 3; ++i)
    {
        folded = wrapper::vpmin(folded, folded);
    }
    return folded;
}
153 
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
template <typename T>
inline typename std::enable_if < std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value,
       typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type >::type
       calculate_max(T in)
{
    // Fold the two halves together, then fold the remaining pair
    const auto halves_max = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    const auto folded     = wrapper::vpmax(halves_max, halves_max);
    return folded;
}
163 
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
template <typename T>
inline typename std::enable_if < std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value,
       typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type >::type
       calculate_max(T in)
{
    // 16 lanes -> 8, then halve three more times until every lane holds the max
    auto folded = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 3; ++i)
    {
        folded = wrapper::vpmax(folded, folded);
    }
    return folded;
}
175 
// Extracts the scalar argmin/argmax index from the vector state built by calculate_index:
// selects the lanes of vec_res_idx whose value equals the overall min/max of vec_res_value
// and returns the smallest such index.
template <typename T>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4_t res_idx_mask{ 0 };
    uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);

    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        // Lanes equal to the overall minimum keep their index; others become 0
        auto pmin    = calculate_min(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }
    else
    {
        auto pmax    = calculate_max(vec_res_value);
        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
        res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
    }

    // Adding 0xFFFFFFFF (== -1 mod 2^32) turns non-matching lanes (0) into the
    // sentinel 0xFFFFFFFF and matching lanes into idx - 1, so a pairwise min
    // selects the smallest matching index.
    res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones);
    auto pmin    = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask));
    pmin         = wrapper::vpmin(pmin, pmin);
    uint32_t res = wrapper::vgetlane(pmin, 0);

    // Undo the -1 offset applied above (subtracting 0xFFFFFFFF == adding 1 mod 2^32)
    return (res - 0xFFFFFFFF);
}
202 
// Quantized (16-lane) variant of calculate_vector_index: extracts the scalar
// argmin/argmax index from the vector state built by calculate_index_quantized.
template <typename T>
uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
{
    uint32x4x4_t res_idx_mask{ { 0 } };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint8x16_t   mask_u8{ 0 };
    // Mark the lanes equal to the overall min/max of the 16 accumulated values
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u8   = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors: replicate each 8-bit mask lane (0x00/0xFF) across a full
    // 32-bit lane by OR-ing the shifted widened value with the unshifted one
    auto wide_u16_1     = wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8)));
    auto wide_u16_2     = wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8)));
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1)));
    auto wide_u32_3     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2)));
    auto wide_u32_4     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2)));
    // Keep indices only in matching lanes, then apply the -1 offset so that
    // non-matching lanes become the 0xFFFFFFFF sentinel (see calculate_vector_index)
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3);
    res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);
    res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones);
    res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones);

    // Reduce the four vectors to the smallest (offset) index
    uint32_t res  = 0xFFFFFFFF;
    int      iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 4);

    // Undo the -1 offset (subtracting 0xFFFFFFFF == adding 1 mod 2^32)
    return (res - 0xFFFFFFFF);
}
249 
250 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// F16 specialization of calculate_index: updates the per-lane winning index for
// argmin/argmax over 8 half-precision lanes; only res.val[0] and res.val[1] are used.
template <>
uint32x4x4_t calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis)
{
    uint32x4x2_t mask{ 0 };
    uint16x8_t   mask_u16{ 0 };
    // For ARG_IDX_MIN `a` holds min(a, b); b > a marks lanes where `a` won (and vice versa for max)
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        mask_u16 = wrapper::vcgt(b, a);
    }
    else
    {
        mask_u16 = wrapper::vclt(b, a);
    }
    // NOTE(review): vmovl widens a true lane (0xFFFF) to 0x0000FFFF, not all-ones,
    // so the vbsl below only selects the low 16 bits of the index — looks safe only
    // while indices stay below 2^16; confirm for reductions over larger dimensions.
    mask.val[0]          = wrapper::vmovl(wrapper::vgetlow(mask_u16));
    mask.val[1]          = wrapper::vmovl(wrapper::vgethigh(mask_u16));
    // Candidate indices: consecutive idx..idx+7 along axis 0, broadcast idx otherwise
    uint32x4x2_t vec_idx = { { { idx + 0, idx + 1, idx + 2, idx + 3 },
            { idx + 4, idx + 5, idx + 6, idx + 7 }
        }
    };
    if(axis != 0)
    {
        vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
        vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{});
    }
    uint32x4x4_t res = { wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]),
                         wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]),
                         0, 0
                       };

    return res;
}
282 
// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
inline float16x4_t calculate_min(float16x8_t in)
{
    // 8 lanes -> 4, then halve twice more until every lane holds the min
    auto folded = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 2; ++i)
    {
        folded = wrapper::vpmin(folded, folded);
    }
    return folded;
}
// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value.
inline float16x4_t calculate_max(float16x8_t in)
{
    // 8 lanes -> 4, then halve twice more until every lane holds the max
    auto folded = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in));
    for(int i = 0; i < 2; ++i)
    {
        folded = wrapper::vpmax(folded, folded);
    }
    return folded;
}
297 
// F16 specialization of calculate_vector_index: extracts the scalar argmin/argmax
// index from the 8-lane vector state (only vec_res_idx.val[0..1] are populated).
template <>
uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op)
{
    uint32x4x2_t res_idx_mask{ 0 };
    uint32x4_t   mask_ones = vdupq_n_u32(0xFFFFFFFF);
    uint16x8_t   mask_u16;
    // Mark the lanes equal to the overall min/max of the accumulated values
    if(op == ReductionOperation::ARG_IDX_MIN)
    {
        auto pmin = calculate_min(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin));
    }
    else
    {
        auto pmax = calculate_max(vec_res_value);
        mask_u16  = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
    }

    // Widen vectors
    // NOTE(review): the shift here is 8, not 16, so a true lane widens to
    // 0x00FFFFFF rather than all-ones. Benign while indices fit in 24 bits
    // (the subsequent vand/vadd/-1 trick still works) — confirm intended.
    auto wide_u32_1     = wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16)));
    auto wide_u32_2     = wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16)));
    // Keep indices only in matching lanes, then offset by -1 so non-matching
    // lanes become the 0xFFFFFFFF sentinel (see the generic calculate_vector_index)
    res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1);
    res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2);
    res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones);
    res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones);

    // Reduce both vectors to the smallest (offset) index
    uint32_t res  = 0xFFFFFFFF;
    uint32_t iter = 0;
    do
    {
        auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter]));
        pmin      = wrapper::vpmin(pmin, pmin);
        res       = std::min(wrapper::vgetlane(pmin, 0), res);
        iter++;
    }
    while(iter < 2);

    // Undo the -1 offset (subtracting 0xFFFFFFFF == adding 1 mod 2^32)
    return (res - 0xFFFFFFFF);
}
336 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
337 
338 template <class F>
339 class Reducer
340 {
341 public:
reduceX(const Window & window,const ITensor * input,ITensor * output,F f,const ReductionOperation op)342     static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
343     {
344         // Set out window
345         Window out_window(window);
346         out_window.set(Window::DimX, Window::Dimension(0, 1, 1));
347 
348         f(window, out_window, input, output, op);
349     }
reduceY(const Window & window,const ITensor * input,ITensor * output,F f,const ReductionOperation op)350     static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
351     {
352         // Set in window
353         Window in_window(window);
354         Window out_window(window);
355 
356         in_window.set(Window::DimY, Window::Dimension(0, 1, 1));
357         out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1)));
358 
359         f(in_window, out_window, input, output, 1, op);
360     }
reduceZ(const Window & window,const ITensor * input,ITensor * output,F f,const ReductionOperation op)361     static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
362     {
363         // Set in window
364         Window in_window(window);
365         Window out_window(window);
366 
367         in_window.set(Window::DimZ, Window::Dimension(0, 1, 1));
368         out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2)));
369 
370         f(in_window, out_window, input, output, 2, op);
371     }
reduceW(const Window & window,const ITensor * input,ITensor * output,F f,const ReductionOperation op)372     static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op)
373     {
374         // Set in/out window
375         Window in_window(window);
376         Window out_window(window);
377 
378         in_window.set(3, Window::Dimension(0, 1, 1));
379         out_window.set(3, Window::Dimension(0, 1, 1));
380 
381         f(in_window, out_window, input, output, 3, op);
382     }
383 };
384 
// Reduction along the X (innermost) dimension for non-quantized types.
// T is the element type, S the number of lanes in a 128-bit vector of T.
template <typename T, int S>
struct RedOpX
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;

    // Reduces each X row of `in` into a single element of `out` according to `op`.
    // Vectorized main loop plus a scalar tail for the leftover elements.
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
    {
        const size_t input_dim_0    = in->info()->dimension(0);
        const int    window_step_x  = 16 / sizeof(T);
        const auto   window_start_x = static_cast<int>(in_window.x().start());
        const auto   window_end_x   = static_cast<int>(in_window.x().end());

        // Collapse X so the iterators advance once per row; X is traversed manually below
        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_window);

        execute_window_loop(
            in_win_no_pad, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<const T *>(input.ptr());

            // Pick the accumulator identity: first element for min/max-style ops,
            // 1 for PROD, 0 otherwise (sums)
            auto init_res_value = static_cast<T>(0.f);
            switch(op)
            {
                case ReductionOperation::ARG_IDX_MAX:
                case ReductionOperation::ARG_IDX_MIN:
                case ReductionOperation::MIN:
                case ReductionOperation::MAX:
                {
                    init_res_value = static_cast<T>(*input_ptr);
                    break;
                }
                case ReductionOperation::PROD:
                {
                    init_res_value = static_cast<T>(1.f);
                    break;
                }
                default:
                    break;
            }
            auto         vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{});
            uint32x4x4_t vec_res_idx{ { 0 } };

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                const auto vec_elements = wrapper::vloadq(input_ptr + x);
                switch(op)
                {
                    case ReductionOperation::SUM_SQUARE:
                        vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                        break;
                    case ReductionOperation::MEAN_SUM:
                    case ReductionOperation::SUM:
                        vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::PROD:
                        vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                        break;
                    case ReductionOperation::ARG_IDX_MIN:
                    {
                        // Track both the running min and the index where it was found
                        auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::ARG_IDX_MAX:
                    {
                        auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        vec_res_idx             = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                        vec_res_value           = temp_vec_res_value;
                        break;
                    }
                    case ReductionOperation::MIN:
                    {
                        vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                        break;
                    }
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                        break;
                    }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            }

            // Horizontal reduction of the vector accumulator + scalar tail
            switch(op)
            {
                case ReductionOperation::SUM:
                case ReductionOperation::MEAN_SUM:
                case ReductionOperation::SUM_SQUARE:
                {
#ifdef ARM_COMPUTE_DEBUG_ENABLED
                    // Debug builds: sum the lanes sequentially — presumably to keep the
                    // accumulation order deterministic; confirm against release path
                    auto res = static_cast<T>(0.f);
                    for(int i = 0; i < S; ++i)
                    {
                        res += wrapper::vgetlane(vec_res_value, i);
                    }
#else  // ARM_COMPUTE_DEBUG_ENABLED
                    // Pairwise-add fold of the vector accumulator down to one lane
                    auto carry_res = wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    for(int i = 0; i < S / 4; ++i)
                    {
                        carry_res = wrapper::vpadd(carry_res, carry_res);
                    }
                    auto res = wrapper::vgetlane(carry_res, 0);
#endif // ARM_COMPUTE_DEBUG_ENABLED
                    if(op == ReductionOperation::SUM_SQUARE)
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += (*(input_ptr + x)) * (*(input_ptr + x));
                        }
                    }
                    else
                    {
                        // Compute left-over elements
                        for(; x < window_end_x; ++x)
                        {
                            res += *(input_ptr + x);
                        }
                    }

                    if(op == ReductionOperation::MEAN_SUM)
                    {
                        // Note: integer T performs integer division here
                        res /= input_dim_0;
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::PROD:
                {
                    // Multiply the two halves lane-wise, then multiply the remaining lanes together
                    auto carry_res = wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
                    T    res       = 1;
                    for(int i = 0; i < S / 2; ++i)
                    {
                        res *= wrapper::vgetlane(carry_res, i);
                    }

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res *= *(input_ptr + x);
                    }

                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::ARG_IDX_MIN:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements (strict < keeps the earlier index on ties)
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) < res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    // The output holds an index, not an element value
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::ARG_IDX_MAX:
                {
                    auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements (strict > keeps the earlier index on ties)
                    for(; x < window_end_x; ++x)
                    {
                        if(*(input_ptr + x) > res)
                        {
                            idx = x;
                            res = *(input_ptr + x);
                        }
                    }
                    *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
                    break;
                }
                case ReductionOperation::MIN:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                case ReductionOperation::MAX:
                {
                    auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));

                    // Compute left-over elements
                    for(; x < window_end_x; ++x)
                    {
                        res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
                    }
                    *(reinterpret_cast<T *>(output.ptr())) = res;
                    break;
                }
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        },
        input, output);
    }
};
605 
606 template <typename T>
607 struct RedOpX_quantized
608 {
operator ()arm_compute::__anon5929079c0111::RedOpX_quantized609     inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op)
610     {
611         using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
612 
613         const auto oq_info = out->info()->quantization_info().uniform();
614 
615         const TensorInfo              in_info = *(in->info());
616         const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
617 
618         const int  window_step_x  = 16 / sizeof(T);
619         const auto window_start_x = static_cast<int>(in_window.x().start());
620         const auto window_end_x   = static_cast<int>(in_window.x().end());
621 
622         Window in_win_no_pad = in_window;
623         in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1));
624 
625         Iterator input(in, in_win_no_pad);
626         Iterator output(out, out_window);
627 
628         const auto  in_offset = static_cast<float>(iq_info.offset);
629         const float in_scale  = iq_info.scale;
630 
631         const auto  out_offset = static_cast<float>(oq_info.offset);
632         const float out_scale  = oq_info.scale;
633 
634         const auto num_elements = static_cast<float>(in_info.dimension(0));
635 
636         const float A = in_scale / (out_scale * num_elements);
637         const float B = out_offset - (in_scale * in_offset) / (out_scale);
638 
639         execute_window_loop(
640             in_win_no_pad, [&](const Coordinates &)
641         {
642             const auto input_ptr = reinterpret_cast<T *>(input.ptr());
643 
644             auto vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
645             auto vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
646             auto vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
647             auto vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{});
648 
649             auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f));
650             auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f));
651             auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f));
652             auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f));
653 
654             typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = { 0 };
655 
656             if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX)
657             {
658                 vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{});
659             }
660 
661             uint32x4x4_t vec_res_idx{ { 0 } };
662             // Compute window_step_x elements per iteration
663             int x = window_start_x;
664             for(; x <= (window_end_x - window_step_x); x += window_step_x)
665             {
666                 const auto vec_elements = wrapper::vloadq(input_ptr + x);
667                 switch(op)
668                 {
669                     case ReductionOperation::SUM:
670                     case ReductionOperation::MEAN_SUM:
671                     {
672                         const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
673                         const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
674 
675                         const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
676                         const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
677                         const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
678                         const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
679 
680                         vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
681                         vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
682                         vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
683                         vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
684                         break;
685                     }
686                     case ReductionOperation::PROD:
687                     {
688                         const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset);
689                         const auto scale32x4f_4  = vdupq_n_f32(iq_info.scale);
690 
691                         const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
692                         const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
693 
694                         const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
695                         const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
696                         const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
697                         const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
698 
699                         auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
700                         auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
701                         auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
702                         auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
703 
704                         //de-quantize vec_elements
705                         temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4);
706                         temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4);
707                         temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4);
708                         temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4);
709 
710                         vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f);
711                         vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f);
712                         vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f);
713                         vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f);
714                         break;
715                     }
716                     case ReductionOperation::ARG_IDX_MIN:
717                     {
718                         auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
719                         vec_res_idx             = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
720                         vec_res_value           = temp_vec_res_value;
721                         break;
722                     }
723                     case ReductionOperation::ARG_IDX_MAX:
724                     {
725                         auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
726                         vec_res_idx             = calculate_index_quantized<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
727                         vec_res_value           = temp_vec_res_value;
728                         break;
729                     }
730                     case ReductionOperation::MIN:
731                     {
732                         vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
733                         break;
734                     }
735                     case ReductionOperation::MAX:
736                     {
737                         vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
738                         break;
739                     }
740                     default:
741                         ARM_COMPUTE_ERROR("Not supported");
742                 }
743             }
744 
745             switch(op)
746             {
747                 case ReductionOperation::ARG_IDX_MIN:
748                 {
749                     auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
750                     auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
751 
752                     // Compute left-over elements
753                     for(; x < window_end_x; ++x)
754                     {
755                         if(*(input_ptr + x) < res)
756                         {
757                             idx = x;
758                             res = *(input_ptr + x);
759                         }
760                     }
761                     *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
762                     break;
763                 }
764                 case ReductionOperation::ARG_IDX_MAX:
765                 {
766                     auto idx = calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
767                     auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
768 
769                     // Compute left-over elements
770                     for(; x < window_end_x; ++x)
771                     {
772                         if(*(input_ptr + x) > res)
773                         {
774                             idx = x;
775                             res = *(input_ptr + x);
776                         }
777                     }
778                     *(reinterpret_cast<uint32_t *>(output.ptr())) = idx;
779                     break;
780                 }
781                 case ReductionOperation::MIN:
782                 {
783                     auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0));
784 
785                     // Compute left-over elements
786                     for(; x < window_end_x; ++x)
787                     {
788                         res = *(input_ptr + x) < res ? *(input_ptr + x) : res;
789                     }
790                     *(reinterpret_cast<T *>(output.ptr())) = res;
791                     break;
792                 }
793                 case ReductionOperation::MAX:
794                 {
795                     auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0));
796 
797                     // Compute left-over elements
798                     for(; x < window_end_x; ++x)
799                     {
800                         res = *(input_ptr + x) > res ? *(input_ptr + x) : res;
801                     }
802                     *(reinterpret_cast<T *>(output.ptr())) = res;
803                     break;
804                 }
805                 case ReductionOperation::PROD:
806                 {
807                     auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f);
808                     carry_res      = wrapper::vmul(carry_res, vec_res_value3_f);
809                     carry_res      = wrapper::vmul(carry_res, vec_res_value4_f);
810 
811                     float res = wrapper::vgetlane(carry_res, 0);
812                     res *= wrapper::vgetlane(carry_res, 1);
813                     res *= wrapper::vgetlane(carry_res, 2);
814                     res *= wrapper::vgetlane(carry_res, 3);
815 
816                     // Compute left-over elements
817                     for(; x < window_end_x; ++x)
818                     {
819                         //de-quantize input
820                         if(std::is_same<T, uint8_t>::value)
821                         {
822                             res *= dequantize_qasymm8(*(input_ptr + x), iq_info);
823                         }
824                         else
825                         {
826                             res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info);
827                         }
828                     }
829 
830                     //re-quantize result
831                     if(std::is_same<T, uint8_t>::value)
832                     {
833                         res = quantize_qasymm8(res, iq_info);
834                     }
835                     else
836                     {
837                         res = quantize_qasymm8_signed(res, iq_info);
838                     }
839 
840                     *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res);
841                     break;
842                 }
843                 case ReductionOperation::SUM:
844                 case ReductionOperation::MEAN_SUM:
845                 {
846                     auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2);
847                     carry_res      = wrapper::vadd(carry_res, vec_res_value3);
848                     carry_res      = wrapper::vadd(carry_res, vec_res_value4);
849 
850                     auto carry_paddition = wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res));
851                     carry_paddition      = wrapper::vpadd(carry_paddition, carry_paddition);
852                     auto res             = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0));
853 
854                     // Compute left-over elements
855                     for(; x < window_end_x; ++x)
856                     {
857                         res += *(input_ptr + x);
858                     }
859 
860                     if(op == ReductionOperation::MEAN_SUM)
861                     {
862                         const int32_t resFinal = A * (static_cast<float>(res)) + B;
863 
864                         *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal);
865                     }
866                     else
867                     {
868                         // Subtract accumulated offsets
869                         res -= (in_info.dimension(0) - 1) * iq_info.offset;
870                         *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res);
871                     }
872 
873                     break;
874                 }
875                 default:
876                     ARM_COMPUTE_ERROR("Not supported");
877             }
878         },
879         input, output);
880     }
881 };
882 
template <typename T, int S>
struct RedOpYZW
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
    using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;

    /** Reduce along a non-X axis (Y, Z or W).
     *
     * For every x position of the window, the values along @p axis are folded
     * into one result according to @p op. A vector path handles
     * 16 / sizeof(T) x positions per iteration; a scalar tail loop handles the
     * remaining x positions.
     *
     * @param in_window  Input execution window.
     * @param out_window Output execution window.
     * @param in         Input tensor.
     * @param out        Output tensor. Holds uint32_t indices for ARG_IDX_MIN/ARG_IDX_MAX,
     *                   values of type T otherwise.
     * @param axis       Axis being reduced (selects the stride/dimension used below).
     * @param op         Reduction operation to perform.
     */
    inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
    {
        const TensorInfo in_info            = *(in->info());
        const int        window_step_x      = 16 / sizeof(T);
        const auto       window_start_x_tmp = static_cast<int>(in_window.x().start());
        const auto       window_end_x_tmp   = static_cast<int>(in_window.x().end());
        // The kernel is split over the x-axis, so rebuild the window: each
        // window-loop invocation then processes the whole x extent itself.
        const auto window_start_x = static_cast<int>(0);
        const auto window_end_x   = static_cast<int>(in_window.shape().x());

        Window in_win_no_pad = in_window;
        in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
        Window out_win_no_pad = out_window;
        out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));

        Iterator input(in, in_win_no_pad);
        Iterator output(out, out_win_no_pad);

        execute_window_loop(
            in_win_no_pad, [&](const Coordinates &)
        {
            const auto input_ptr = reinterpret_cast<T *>(input.ptr());

            // Compute window_step_x elements per iteration
            int x = window_start_x;
            for(; x <= (window_end_x - window_step_x); x += window_step_x)
            {
                neon_vector vec_res_value = { 0 };
                // Seed the accumulator with the operation's identity value:
                // the first input row for min/max/argmin/argmax, 1 for PROD,
                // 0 for the summing reductions.
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        vec_res_value = wrapper::vloadq(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{});
                        break;
                    }
                    default:
                    {
                        vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
                        break;
                    }
                }
                uint32x4x4_t vec_res_idx{ { 0 } };

                // Walk the reduction axis, folding one vector of inputs per step.
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T   *in_ptr       = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);
                    const auto vec_elements = wrapper::vloadq(in_ptr);
                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            vec_res_value = wrapper::vadd(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value);
                            break;
                        case ReductionOperation::PROD:
                            vec_res_value = wrapper::vmul(vec_elements, vec_res_value);
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            // Track per-lane index of the running minimum before updating the value.
                            auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value           = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            // Track per-lane index of the running maximum before updating the value.
                            auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            vec_res_idx             = calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
                            vec_res_value           = temp_vec_res_value;
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    // Divide the accumulated sum by the axis length (via reciprocal multiply).
                    auto vec_width_inv = wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{}));
                    vec_res_value      = wrapper::vmul(vec_res_value, vec_width_inv);
                }

                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    // Indices are stored as 32-bit values. An FP16 vector carries 8 lanes,
                    // so a second uint32x4 store is needed for lanes 4..7.
                    wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    if(std::is_same<T, float16_t>::value)
                    {
                        wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]);
                    }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                }
                else
                {
                    wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value);
                }
            }

            // Compute left-over elements
            for(; x < window_end_x; ++x)
            {
                // Scalar accumulator (float); seeded with the operation's identity
                // value, mirroring the vector path above.
                auto res_value = 0.f;
                switch(op)
                {
                    case ReductionOperation::ARG_IDX_MAX:
                    case ReductionOperation::ARG_IDX_MIN:
                    case ReductionOperation::MIN:
                    case ReductionOperation::MAX:
                    {
                        res_value = *(input_ptr + x);
                        break;
                    }
                    case ReductionOperation::PROD:
                    {
                        res_value = static_cast<T>(1.f);
                        break;
                    }
                    default:
                    {
                        res_value = static_cast<T>(0.f);
                        break;
                    }
                }

                uint32_t res_idx = 0;
                for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
                {
                    const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim);

                    switch(op)
                    {
                        case ReductionOperation::SUM:
                        case ReductionOperation::MEAN_SUM:
                            res_value += *in_ptr;
                            break;
                        case ReductionOperation::SUM_SQUARE:
                            res_value += *in_ptr * *in_ptr;
                            break;
                        case ReductionOperation::PROD:
                            res_value *= *in_ptr;
                            break;
                        case ReductionOperation::ARG_IDX_MIN:
                        {
                            if(*in_ptr < res_value)
                            {
                                res_value = *in_ptr;
                                res_idx   = dim;
                            }
                            break;
                        }
                        case ReductionOperation::ARG_IDX_MAX:
                        {
                            if(*in_ptr > res_value)
                            {
                                res_value = *in_ptr;
                                res_idx   = dim;
                            }
                            break;
                        }
                        case ReductionOperation::MIN:
                        {
                            res_value = *in_ptr < res_value ? *in_ptr : res_value;
                            break;
                        }
                        case ReductionOperation::MAX:
                        {
                            res_value = *in_ptr > res_value ? *in_ptr : res_value;
                            break;
                        }
                        default:
                            ARM_COMPUTE_ERROR("Not supported");
                    }
                }

                if(op == ReductionOperation::MEAN_SUM)
                {
                    res_value /= in_info.dimension(axis);
                }

                // ARG_IDX_* writes the winning index (uint32_t); everything else writes the value (T).
                if(op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX)
                {
                    *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx;
                }
                else
                {
                    *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value;
                }
            }
        },
        input, output);
    }
};
1102 
1103 template <typename T, int S, int axis, ReductionOperation op>
1104 struct RedOpYZW_complex
1105 {
1106     /** SIMD vector tag type. */
1107     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
1108     using neon_vector  = typename wrapper::traits::neon_vector<T, S>::type;
1109 
operator ()arm_compute::__anon5929079c0111::RedOpYZW_complex1110     inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation)
1111     {
1112         ARM_COMPUTE_ERROR_ON(axis != 2);
1113         ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM);
1114 
1115         const TensorInfo in_info            = *(in->info());
1116         const size_t     stride_z           = in_info.strides_in_bytes()[axis];
1117         const int        window_step_x      = 16 / sizeof(T);
1118         const auto       window_start_x_tmp = static_cast<int>(in_window.x().start());
1119         const auto       window_end_x_tmp   = static_cast<int>(in_window.x().end());
1120         // As it split over x-axis, need to set the correct spiltted window start and end.
1121         const auto window_start_x = static_cast<int>(0);
1122         const auto window_end_x   = static_cast<int>(in_window.shape().x());
1123 
1124         Window in_win_no_pad = in_window;
1125         in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1126         Window out_win_no_pad = out_window;
1127         out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
1128 
1129         Iterator input(in, in_win_no_pad);
1130         Iterator output(out, out_win_no_pad);
1131 
1132         execute_window_loop(
1133             in_win_no_pad, [&](const Coordinates &)
1134         {
1135             // Compute window_step_x elements per iteration
1136             int x = window_start_x;
1137             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1138             {
1139                 neon_vector vec_res_value_0 = { 0 };
1140                 neon_vector vec_res_value_1 = { 0 };
1141 
1142                 vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1143                 vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
1144 
1145                 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1146                 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1147                 {
1148                     T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1149                     T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim);
1150 
1151                     const auto vec_elements_0 = wrapper::vloadq(in_ptr_0);
1152                     const auto vec_elements_1 = wrapper::vloadq(in_ptr_1);
1153 
1154                     vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0);
1155                     vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1);
1156                 }
1157 
1158                 wrapper::vstore(out_ptr, vec_res_value_0);
1159                 wrapper::vstore(out_ptr + 4, vec_res_value_1);
1160             }
1161 
1162             // Compute left-over elements
1163             for(; x < window_end_x; ++x)
1164             {
1165                 auto res_value_0 = 0.f;
1166                 auto res_value_1 = 0.f;
1167 
1168                 T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T));
1169                 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1170                 {
1171                     T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim);
1172                     res_value_0 += *in_ptr;
1173                     res_value_1 += *(in_ptr + 1);
1174                 }
1175                 *out_ptr       = res_value_0;
1176                 *(out_ptr + 1) = res_value_1;
1177             }
1178         },
1179         input, output);
1180     }
1181 };
1182 
1183 template <typename T>
1184 struct RedOpYZW_quantized
1185 {
operator ()arm_compute::__anon5929079c0111::RedOpYZW_quantized1186     inline void operator()(const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int axis, const ReductionOperation op)
1187     {
1188         const TensorInfo              in_info = *(in->info());
1189         const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform();
1190         using PromotedType                    = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type;
1191 
1192         const auto oq_info = out->info()->quantization_info().uniform();
1193 
1194         const int  window_step_x      = 16 / sizeof(T);
1195         const auto window_start_x_tmp = static_cast<int>(in_window.x().start());
1196         const auto window_end_x_tmp   = static_cast<int>(in_window.x().end());
1197         // As it split over x-axis, need to set the correct spiltted window start and end.
1198         const auto window_start_x = static_cast<int>(0);
1199         const auto window_end_x   = static_cast<int>(in_window.shape().x());
1200 
1201         Window in_win_no_pad = in_window;
1202         in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x()));
1203         Window out_win_no_pad = out_window;
1204         out_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x()));
1205 
1206         Iterator input(in, in_win_no_pad);
1207         Iterator output(out, out_win_no_pad);
1208 
1209         using vector_type   = typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type;
1210         using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type;
1211 
1212         vector_type vec_res_value1{};
1213         vector_type vec_res_value2{};
1214         vector_type vec_res_value3{};
1215         vector_type vec_res_value4{};
1216 
1217         vector_type_f vec_res_value1_f{};
1218         vector_type_f vec_res_value2_f{};
1219         vector_type_f vec_res_value3_f{};
1220         vector_type_f vec_res_value4_f{};
1221 
1222         const float in_offset = static_cast<float>(iq_info.offset);
1223         const float in_scale  = iq_info.scale;
1224 
1225         const float out_offset = static_cast<float>(oq_info.offset);
1226         const float out_scale  = oq_info.scale;
1227 
1228         const float num_elements = static_cast<float>(in_info.dimension(axis));
1229 
1230         const float A = in_scale / (out_scale * num_elements);
1231         const float B = out_offset - (in_scale * in_offset) / (out_scale);
1232 
1233         const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{});
1234         const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{});
1235 
1236         execute_window_loop(
1237             in_win_no_pad, [&](const Coordinates &)
1238         {
1239             const auto input_ptr = reinterpret_cast<T *>(input.ptr());
1240 
1241             // Compute window_step_x elements per iteration
1242             int x = window_start_x;
1243             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1244             {
1245                 uint32x4x4_t vec_res_idx{ { 0 } };
1246                 vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1247                 vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1248                 vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1249                 vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{});
1250 
1251                 vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1252                 vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1253                 vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1254                 vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{});
1255 
1256                 auto vec_res_value = wrapper::vloadq(input_ptr + x);
1257 
1258                 for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
1259                 {
1260                     const T   *in_ptr       = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim;
1261                     const auto vec_elements = wrapper::vloadq(in_ptr);
1262                     switch(op)
1263                     {
1264                         case ReductionOperation::SUM:
1265                         case ReductionOperation::MEAN_SUM:
1266                         {
1267                             const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1268                             const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1269 
1270                             const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1271                             const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1272                             const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1273                             const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1274 
1275                             vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1);
1276                             vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2);
1277                             vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3);
1278                             vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4);
1279                             break;
1280                         }
1281                         case ReductionOperation::PROD:
1282                         {
1283                             const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1284                             const auto scale32x4f_4  = wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{});
1285 
1286                             const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements));
1287                             const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements));
1288 
1289                             const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1));
1290                             const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1));
1291                             const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2));
1292                             const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2));
1293 
1294                             auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1);
1295                             auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2);
1296                             auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3);
1297                             auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4);
1298 
1299                             //de-quantize vec_elements
1300                             temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4);
1301                             temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4);
1302                             temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4);
1303                             temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4);
1304 
1305                             vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f);
1306                             vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f);
1307                             vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f);
1308                             vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f);
1309                             break;
1310                         }
1311                         case ReductionOperation::ARG_IDX_MIN:
1312                         {
1313                             auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1314                             vec_res_idx             = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1315                             vec_res_value           = temp_vec_res_value;
1316                             break;
1317                         }
1318                         case ReductionOperation::ARG_IDX_MAX:
1319                         {
1320                             auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1321                             vec_res_idx             = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis);
1322                             vec_res_value           = temp_vec_res_value;
1323                             break;
1324                         }
1325                         case ReductionOperation::MIN:
1326                         {
1327                             vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
1328                             break;
1329                         }
1330                         case ReductionOperation::MAX:
1331                         {
1332                             vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
1333                             break;
1334                         }
1335                         default:
1336                             ARM_COMPUTE_ERROR("Not supported");
1337                     }
1338                 }
1339 
1340                 switch(op)
1341                 {
1342                     case ReductionOperation::ARG_IDX_MIN:
1343                     case ReductionOperation::ARG_IDX_MAX:
1344                     {
1345                         wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]);
1346                         wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]);
1347                         wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]);
1348                         wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, vec_res_idx.val[3]);
1349                         break;
1350                     }
1351                     case ReductionOperation::MIN:
1352                     case ReductionOperation::MAX:
1353                     {
1354                         wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value);
1355                         break;
1356                     }
1357                     case ReductionOperation::SUM:
1358                     {
1359                         // Subtract offsets
1360                         auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset);
1361 
1362                         auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1);
1363                         auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2);
1364                         auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3);
1365                         auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4);
1366 
1367                         vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets);
1368                         vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets);
1369                         vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets);
1370                         vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets);
1371 
1372                         const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2));
1373                         const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4));
1374 
1375                         combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x);
1376                         break;
1377                     }
1378                     case ReductionOperation::MEAN_SUM:
1379                     {
1380                         vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A);
1381                         vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A);
1382                         vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A);
1383                         vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A);
1384 
1385 #ifdef __aarch64__
1386                         vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f);
1387                         vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f);
1388                         vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f);
1389                         vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f);
1390 #else  // defined(__aarch64__)
1391                         vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f);
1392                         vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f);
1393                         vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f);
1394                         vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f);
1395 #endif // __aarch64__
1396 
1397                         const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1398                         const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1399                         auto       res         = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1400 
1401                         wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1402                         break;
1403                     }
1404                     case ReductionOperation::PROD:
1405                     {
1406                         const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{});
1407                         const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale));
1408 
1409                         //re-quantize
1410                         vec_res_value1_f = wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4);
1411                         vec_res_value2_f = wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4);
1412                         vec_res_value3_f = wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4);
1413                         vec_res_value4_f = wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4);
1414 
1415                         vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f);
1416                         vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f);
1417                         vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f);
1418                         vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f);
1419 
1420                         const auto temp16x8t_1 = wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2));
1421                         const auto temp16x8t_2 = wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4));
1422                         auto       res         = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2));
1423 
1424                         wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res);
1425                         break;
1426                     }
1427                     default:
1428                         ARM_COMPUTE_ERROR("Not supported");
1429                 }
1430             }
1431 
1432             // Compute left-over elements
1433             for(; x < window_end_x; ++x)
1434             {
1435                 float   res_value   = 0.f;
1436                 int32_t res_value_q = 0;
1437 
1438                 switch(op)
1439                 {
1440                     case ReductionOperation::ARG_IDX_MAX:
1441                     case ReductionOperation::ARG_IDX_MIN:
1442                     case ReductionOperation::MIN:
1443                     case ReductionOperation::MAX:
1444                     {
1445                         res_value = *(input_ptr + x);
1446                         break;
1447                     }
1448                     case ReductionOperation::PROD:
1449                     {
1450                         res_value = static_cast<T>(1.0f);
1451                         break;
1452                     }
1453                     default:
1454                     {
1455                         res_value = static_cast<T>(0.0f);
1456                         break;
1457                     }
1458                 }
1459                 uint32_t res_idx = 0;
1460 
1461                 for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
1462                 {
1463                     const T *in_ptr = reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim);
1464                     switch(op)
1465                     {
1466                         case ReductionOperation::SUM:
1467                         {
1468                             res_value += *in_ptr;
1469                             break;
1470                         }
1471                         case ReductionOperation::MEAN_SUM:
1472                         {
1473                             res_value_q += *in_ptr;
1474                             break;
1475                         }
1476                         case ReductionOperation::SUM_SQUARE:
1477                         {
1478                             res_value += *in_ptr * *in_ptr;
1479                             break;
1480                         }
1481                         case ReductionOperation::PROD:
1482                         {
1483                             //de-quantize input
1484                             if(std::is_same<T, uint8_t>::value)
1485                             {
1486                                 res_value *= dequantize_qasymm8(*in_ptr, iq_info);
1487                             }
1488                             else
1489                             {
1490                                 res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info);
1491                             }
1492                             break;
1493                         }
1494                         case ReductionOperation::ARG_IDX_MIN:
1495                         {
1496                             if(*in_ptr < res_value)
1497                             {
1498                                 res_value = *in_ptr;
1499                                 res_idx   = dim;
1500                             }
1501                             break;
1502                         }
1503                         case ReductionOperation::ARG_IDX_MAX:
1504                         {
1505                             if(*in_ptr > res_value)
1506                             {
1507                                 res_value = *in_ptr;
1508                                 res_idx   = dim;
1509                             }
1510                             break;
1511                         }
1512                         case ReductionOperation::MIN:
1513                         {
1514                             res_value = *in_ptr < res_value ? *in_ptr : res_value;
1515                             break;
1516                         }
1517                         case ReductionOperation::MAX:
1518                         {
1519                             res_value = *in_ptr > res_value ? *in_ptr : res_value;
1520                             break;
1521                         }
1522                         default:
1523                             ARM_COMPUTE_ERROR("Not supported");
1524                     }
1525                 }
1526 
1527                 switch(op)
1528                 {
1529                     case ReductionOperation::MEAN_SUM:
1530                     {
1531                         // Apply previously calculated coefficients (with rounding on aarch64)
1532 #ifdef  __aarch64__
1533                         const int32_t res                        = arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B);
1534 #else   // defined(__aarch64__)
1535                         const int32_t res                        = A * (static_cast<float>(res_value_q)) + B;
1536 #endif  // __aarch64__
1537                         *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res);
1538                         break;
1539                     }
1540                     case ReductionOperation::SUM:
1541                     {
1542                         // Subtract accumulated offsets
1543                         res_value -= (in_info.dimension(axis) - 1) * iq_info.offset;
1544                         *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value);
1545                         break;
1546                     }
1547                     case ReductionOperation::PROD:
1548                     {
1549                         //re-quantize result
1550                         T res = 0;
1551                         if(std::is_same<T, uint8_t>::value)
1552                         {
1553                             res = quantize_qasymm8(res_value, iq_info);
1554                         }
1555                         else
1556                         {
1557                             res = quantize_qasymm8_signed(res_value, iq_info);
1558                         }
1559                         *(reinterpret_cast<T *>(output.ptr() + x)) = res;
1560                         break;
1561                     }
1562                     case ReductionOperation::ARG_IDX_MIN:
1563                     case ReductionOperation::ARG_IDX_MAX:
1564                     {
1565                         *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx;
1566                         break;
1567                     }
1568                     default:
1569                         *(reinterpret_cast<T *>(output.ptr() + x)) = res_value;
1570                 }
1571             }
1572         },
1573         input, output);
1574     }
1575 };
1576 
/** Dispatch the reduction to the Reducer/RedOp* specialization matching the
 *  requested axis, the input data type and (for 2-channel F32) complex input.
 *
 * @param window Execution window for this thread.
 * @param input  Source tensor.
 * @param output Destination tensor.
 * @param axis   Dimension being reduced (0 = X, 1 = Y, 2 = Z, 3 = W).
 * @param op     Reduction operation to perform.
 */
void reduce_op(const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op)
{
    // Two channels on an F32 tensor encode complex numbers (real, imag).
    const bool is_complex = (input->info()->num_channels() == 2);

    if(is_complex)
    {
        // Complex reduction is only supported as SUM over the Z axis on F32 data
        // (mirrors the constraints enforced in validate_arguments()).
        switch(axis)
        {
            case 2:
                switch(input->info()->data_type())
                {
                    case DataType::F32:
                        switch(op)
                        {
                            case ReductionOperation::SUM:
                                return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ(window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op);
                            default:
                                ARM_COMPUTE_ERROR("Not supported");
                        }
                    default:
                        ARM_COMPUTE_ERROR("Not supported");
                }
            default:
                ARM_COMPUTE_ERROR("Not supported");
        }
        return;
    }

    // Real (single-channel) path: pick the reducer by axis, then by data type.
    // Axis 0 uses the RedOpX family; axes 1-3 share the RedOpYZW family.
    switch(axis)
    {
        case 0:
        {
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                {
                    return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op);
                }
                case DataType::QASYMM8_SIGNED:
                {
                    return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op);
                }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                {
                    return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
                }
                case DataType::S32:
                {
                    return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
                }
                default:
                {
                    ARM_COMPUTE_ERROR("Not supported");
                }
            }
        }
        case 1:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                {
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                }
                case DataType::QASYMM8_SIGNED:
                {
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op);
                }
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 2:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        case 3:
            switch(input->info()->data_type())
            {
                case DataType::QASYMM8:
                    return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op);
                case DataType::QASYMM8_SIGNED:
                    return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op);
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op);
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F32:
                    return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
                case DataType::S32:
                    return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
                default:
                    ARM_COMPUTE_ERROR("Not supported");
            }
        default:
            ARM_COMPUTE_ERROR("Unsupported reduction axis");
    }
}
1699 
/** Validate the configuration of a reduction operation.
 *
 * Checks data types, channel count (1 for real, 2 for complex F32 SUM on Z),
 * axis bounds and — when the output is already initialized — that its type
 * and shape are consistent with the requested reduction.
 *
 * @param input  Source tensor info.
 * @param output Destination tensor info.
 * @param axis   Dimension to reduce.
 * @param op     Reduction operation.
 *
 * @return An empty Status on success, an error Status otherwise.
 */
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
    ARM_COMPUTE_UNUSED(op);

    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);

    if(input->num_channels() == 1)
    {
        // Real input: all supported reduction data types are allowed.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
    }
    else
    {
        // Complex input (2 channels): only F32 SUM over the Z axis is implemented.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON(op != ReductionOperation::SUM);
        ARM_COMPUTE_RETURN_ERROR_ON(axis != 2);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");

    // Only cross-check the output if it has already been configured;
    // an empty output is auto-initialized later in configure().
    if(output->total_size() != 0)
    {
        bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
        if(!is_arg_min_max)
        {
            // Value reductions keep the input's data type and channel count.
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
            ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() != output->num_channels());
        }
        else
        {
            // Arg-min/max outputs indices, so the output must be U32 or S32.
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
        }

        // The output shape must equal the input shape with 'axis' collapsed.
        const TensorShape output_shape         = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis);
        const TensorInfo  tensor_info_reshaped = input->clone()->set_tensor_shape(output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_reshaped);
    }

    return Status{};
}
1741 } // namespace
1742 
NEReductionOperationKernel()1743 NEReductionOperationKernel::NEReductionOperationKernel()
1744     : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE)
1745 {
1746 }
1747 
configure(const ITensor * input,ITensor * output,unsigned int axis,ReductionOperation op)1748 void NEReductionOperationKernel::configure(const ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
1749 {
1750     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
1751 
1752     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), axis, op));
1753 
1754     _input          = input;
1755     _output         = output;
1756     _op             = op;
1757     _reduction_axis = axis;
1758 
1759     // Configure kernel window
1760     Window win = calculate_max_window(*input->info(), Steps());
1761     INEKernel::configure(win);
1762 
1763     // Calculate output shape and set if empty
1764     const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
1765     // Output auto initialization if not yet initialized
1766     const bool is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
1767     DataType   output_data_type = is_arg_min_max ? DataType::S32 : input->info()->data_type();
1768     auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
1769 }
1770 
validate(const ITensorInfo * input,const ITensorInfo * output,unsigned int axis,ReductionOperation op)1771 Status NEReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
1772 {
1773     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op));
1774 
1775     return Status{};
1776 }
1777 
/** Execute the reduction over the given window on the calling thread. */
void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    // The scheduler-provided window must be a sub-window of the configured one.
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);

    // Dispatch to the specialization selected by axis/data type/operation.
    reduce_op(window, _input, _output, _reduction_axis, _op);
}
1786 } // namespace arm_compute
1787