xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuMulKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2016-2022 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "src/cpu/kernels/CpuMulKernel.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/ITensor.h"
27*c217d954SCole Faust #include "arm_compute/core/TensorInfo.h"
28*c217d954SCole Faust #include "src/core/CPP/Validate.h"
29*c217d954SCole Faust #include "src/core/NEON/NEAsymm.h"
30*c217d954SCole Faust #include "src/core/NEON/NESymm.h"
31*c217d954SCole Faust #include "src/core/NEON/wrapper/wrapper.h"
32*c217d954SCole Faust #include "src/core/helpers/AutoConfiguration.h"
33*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
34*c217d954SCole Faust 
35*c217d954SCole Faust #include <arm_neon.h>
36*c217d954SCole Faust 
namespace
{
// Workload-size thresholds used by this translation unit.
// NOTE(review): "mws" presumably abbreviates "minimum workload size" and N1/V1 presumably name
// specific CPU models - confirm against the code that consumes these constants.
#if defined(ENABLE_FP32_KERNELS)
constexpr size_t default_mws_N1_fp32_neon = 22447;
constexpr size_t default_mws_V1_fp32_neon = 38982;
#endif /* ENABLE_FP32_KERNELS */
// Threshold for every other platform; per its name it is intended for 1D tensors.
constexpr size_t default_mws_other_platforms_1d_tensor = 10240;
} // namespace
45*c217d954SCole Faust namespace arm_compute
46*c217d954SCole Faust {
47*c217d954SCole Faust namespace cpu
48*c217d954SCole Faust {
49*c217d954SCole Faust namespace kernels
50*c217d954SCole Faust {
51*c217d954SCole Faust namespace
52*c217d954SCole Faust {
53*c217d954SCole Faust const float       scale255_constant      = 1.f / 255.f;
54*c217d954SCole Faust const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant);
55*c217d954SCole Faust const float32x4_t positive_round_f32q    = vdupq_n_f32(0.5f);
56*c217d954SCole Faust 
validate_arguments(const ITensorInfo * src1,const ITensorInfo * src2,const ITensorInfo * dst,float scale,ConvertPolicy overflow_policy,RoundingPolicy rounding_policy)57*c217d954SCole Faust inline Status validate_arguments(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
58*c217d954SCole Faust {
59*c217d954SCole Faust     ARM_COMPUTE_UNUSED(overflow_policy);
60*c217d954SCole Faust     ARM_COMPUTE_UNUSED(rounding_policy);
61*c217d954SCole Faust 
62*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src1);
63*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
64*c217d954SCole Faust                                                          DataType::F32);
65*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::S32, DataType::QSYMM16, DataType::F16,
66*c217d954SCole Faust                                                          DataType::F32);
67*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
68*c217d954SCole Faust                                                          DataType::S16, DataType::QSYMM16,
69*c217d954SCole Faust                                                          DataType::S32, DataType::F16, DataType::F32);
70*c217d954SCole Faust     if(is_data_type_quantized(src1->data_type()) || is_data_type_quantized(src2->data_type()))
71*c217d954SCole Faust     {
72*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src1, src2);
73*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(overflow_policy == ConvertPolicy::WRAP, "ConvertPolicy cannot be WRAP if datatype is quantized");
74*c217d954SCole Faust     }
75*c217d954SCole Faust 
76*c217d954SCole Faust     if(dst->total_size() > 0)
77*c217d954SCole Faust     {
78*c217d954SCole Faust         const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
79*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
80*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
81*c217d954SCole Faust         // clang-format off
82*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(
83*c217d954SCole Faust             !(src1->data_type() == src2->data_type() && src2->data_type() == dst->data_type()) &&
84*c217d954SCole Faust             !(src1->data_type() == DataType::U8 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
85*c217d954SCole Faust             !(src1->data_type() == DataType::U8 && src2->data_type() == DataType::S16 && dst->data_type() == DataType::S16) &&
86*c217d954SCole Faust             !(src1->data_type() == DataType::S16 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
87*c217d954SCole Faust             !(src1->data_type() == DataType::S16 && src2->data_type() == DataType::U8 && dst->data_type() == DataType::S16) &&
88*c217d954SCole Faust             !(src1->data_type() == DataType::QSYMM16 && src2->data_type() == DataType::QSYMM16 && dst->data_type() == DataType::S32)
89*c217d954SCole Faust             , "Invalid data type combination");
90*c217d954SCole Faust         // clang-format on
91*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->data_type() == DataType::S16 && dst->data_type() == DataType::S32 && scale != 1.f, "Unsupported scale for QSYMM16 inputs and S32 dst");
92*c217d954SCole Faust     }
93*c217d954SCole Faust 
94*c217d954SCole Faust     if(std::abs(scale - scale255_constant) < 0.00001f)
95*c217d954SCole Faust     {
96*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN);
97*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->data_type() == DataType::S32 && src2->data_type() == DataType::S32 && dst->data_type() == DataType::S32,
98*c217d954SCole Faust                                         "Scale == 1/255 is not supported if input and dst are of data type S32");
99*c217d954SCole Faust     }
100*c217d954SCole Faust     else
101*c217d954SCole Faust     {
102*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO);
103*c217d954SCole Faust 
104*c217d954SCole Faust         int         exponent            = 0;
105*c217d954SCole Faust         const float normalized_mantissa = std::frexp(scale, &exponent);
106*c217d954SCole Faust 
107*c217d954SCole Faust         // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15
108*c217d954SCole Faust         // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14
109*c217d954SCole Faust         // Moreover, it will be negative as we deal with 1/2^n
110*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255");
111*c217d954SCole Faust     }
112*c217d954SCole Faust 
113*c217d954SCole Faust     return Status{};
114*c217d954SCole Faust }
115*c217d954SCole Faust 
116*c217d954SCole Faust /* Scales a given vector by 1/255.
117*c217d954SCole Faust  *
118*c217d954SCole Faust  * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats.
119*c217d954SCole Faust  *
120*c217d954SCole Faust  * @param in Input vector to scale.
121*c217d954SCole Faust  * @return   Scaled dst rounded to nearest (round half up).
122*c217d954SCole Faust  */
scale255_S32_S32(int32x4_t in)123*c217d954SCole Faust inline int32x4_t scale255_S32_S32(int32x4_t in)
124*c217d954SCole Faust {
125*c217d954SCole Faust     // Scale
126*c217d954SCole Faust     const float32x4_t tmp = vmulq_f32(vcvtq_f32_s32(in), scale255_constant_f32q);
127*c217d954SCole Faust     // Round to nearest (round half up)
128*c217d954SCole Faust     // Add +0.5 for all values
129*c217d954SCole Faust     // Afterwards vcvt rounds toward zero
130*c217d954SCole Faust     return vcvtq_s32_f32(vaddq_f32(tmp, positive_round_f32q));
131*c217d954SCole Faust }
132*c217d954SCole Faust 
scale255_U16_U16(uint16x8_t in)133*c217d954SCole Faust inline uint16x8_t scale255_U16_U16(uint16x8_t in)
134*c217d954SCole Faust {
135*c217d954SCole Faust     const int32x4_t tmp_s1 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(in))));
136*c217d954SCole Faust     const int32x4_t tmp_s2 = scale255_S32_S32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(in))));
137*c217d954SCole Faust     return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
138*c217d954SCole Faust }
139*c217d954SCole Faust 
140*c217d954SCole Faust template <typename T>
141*c217d954SCole Faust inline typename std::enable_if<std::is_same<T, int8_t>::value, int8x16_t>::type
vquantize(float32x4x4_t val,const UniformQuantizationInfo & info)142*c217d954SCole Faust vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
143*c217d954SCole Faust {
144*c217d954SCole Faust     return vquantize_signed(val, info);
145*c217d954SCole Faust }
146*c217d954SCole Faust 
147*c217d954SCole Faust template <typename T>
148*c217d954SCole Faust inline typename std::enable_if<std::is_same<T, uint8_t>::value, uint8x16_t>::type
vquantize(float32x4x4_t val,const UniformQuantizationInfo & info)149*c217d954SCole Faust vquantize(float32x4x4_t val, const UniformQuantizationInfo &info)
150*c217d954SCole Faust {
151*c217d954SCole Faust     return vquantize(val, info);
152*c217d954SCole Faust }
153*c217d954SCole Faust 
154*c217d954SCole Faust template <typename T>
mul_saturate_quantized_8(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,float scale)155*c217d954SCole Faust void mul_saturate_quantized_8(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
156*c217d954SCole Faust {
157*c217d954SCole Faust     // Create input windows
158*c217d954SCole Faust     Window win        = window;
159*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
160*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
161*c217d954SCole Faust 
162*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
163*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
164*c217d954SCole Faust 
165*c217d954SCole Faust     const int  window_step_x         = 16 / sizeof(T);
166*c217d954SCole Faust     const auto window_start_x        = static_cast<int>(window.x().start());
167*c217d954SCole Faust     const auto window_end_x          = static_cast<int>(window.x().end());
168*c217d954SCole Faust     const bool is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
169*c217d954SCole Faust 
170*c217d954SCole Faust     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
171*c217d954SCole Faust     const UniformQuantizationInfo tmp_qua_info    = { output_qua_info.scale / scale, output_qua_info.offset };
172*c217d954SCole Faust 
173*c217d954SCole Faust     if(is_broadcast_across_x)
174*c217d954SCole Faust     {
175*c217d954SCole Faust         const bool                    is_broadcast_input_2 = input2_win.x().step() == 0;
176*c217d954SCole Faust         Window                        broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
177*c217d954SCole Faust         Window                        non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
178*c217d954SCole Faust         const ITensor                *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
179*c217d954SCole Faust         const ITensor                *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
180*c217d954SCole Faust         const UniformQuantizationInfo broadcast_qinfo      = broadcast_tensor->info()->quantization_info().uniform();
181*c217d954SCole Faust         const UniformQuantizationInfo non_broadcast_qinfo  = non_broadcast_tensor->info()->quantization_info().uniform();
182*c217d954SCole Faust 
183*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
184*c217d954SCole Faust         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
185*c217d954SCole Faust 
186*c217d954SCole Faust         Iterator broadcast_input(broadcast_tensor, broadcast_win);
187*c217d954SCole Faust         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
188*c217d954SCole Faust         Iterator dst(out, win);
189*c217d954SCole Faust 
190*c217d954SCole Faust         using ExactTagType = typename wrapper::traits::neon_vector<T, window_step_x>::tag_type;
191*c217d954SCole Faust 
192*c217d954SCole Faust         execute_window_loop(
193*c217d954SCole Faust             win, [&](const Coordinates &)
194*c217d954SCole Faust         {
195*c217d954SCole Faust             const auto non_broadcast_input_ptr = reinterpret_cast<const T *>(non_broadcast_input.ptr());
196*c217d954SCole Faust             const auto output_ptr              = reinterpret_cast<T *>(dst.ptr());
197*c217d954SCole Faust 
198*c217d954SCole Faust             const auto broadcast_value     = *reinterpret_cast<const T *>(broadcast_input.ptr());
199*c217d954SCole Faust             const auto broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
200*c217d954SCole Faust 
201*c217d954SCole Faust             // Compute window_step_x elements per iteration
202*c217d954SCole Faust             int x = window_start_x;
203*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
204*c217d954SCole Faust             {
205*c217d954SCole Faust                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
206*c217d954SCole Faust 
207*c217d954SCole Faust                 // Dequantize inputs
208*c217d954SCole Faust                 const float32x4x4_t in1_f32x4x4 = vdequantize(non_broadcast_v, non_broadcast_qinfo);
209*c217d954SCole Faust                 const float32x4x4_t in2_f32x4x4 = vdequantize(broadcast_value_vec, broadcast_qinfo);
210*c217d954SCole Faust 
211*c217d954SCole Faust                 const float32x4x4_t out_f32x4x4 =
212*c217d954SCole Faust                 {
213*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
214*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
215*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
216*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
217*c217d954SCole Faust                 };
218*c217d954SCole Faust 
219*c217d954SCole Faust                 // Quantize dst
220*c217d954SCole Faust                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
221*c217d954SCole Faust                 wrapper::vstore(output_ptr + x, result);
222*c217d954SCole Faust             }
223*c217d954SCole Faust 
224*c217d954SCole Faust             // Compute left-over elements
225*c217d954SCole Faust             for(; x < window_end_x; ++x)
226*c217d954SCole Faust             {
227*c217d954SCole Faust                 // Dequantize inputs
228*c217d954SCole Faust                 const T     src1    = *(non_broadcast_input_ptr + x);
229*c217d954SCole Faust                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(src1, non_broadcast_qinfo);
230*c217d954SCole Faust                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(broadcast_value, broadcast_qinfo);
231*c217d954SCole Faust                 const float tmp_f   = tmp_in1 * tmp_in2;
232*c217d954SCole Faust 
233*c217d954SCole Faust                 // Quantize dst
234*c217d954SCole Faust                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
235*c217d954SCole Faust                 *(output_ptr + x)  = tmp_qua;
236*c217d954SCole Faust             }
237*c217d954SCole Faust         },
238*c217d954SCole Faust         broadcast_input, non_broadcast_input, dst);
239*c217d954SCole Faust     }
240*c217d954SCole Faust     else
241*c217d954SCole Faust     {
242*c217d954SCole Faust         const UniformQuantizationInfo input1_qua_info = src1->info()->quantization_info().uniform();
243*c217d954SCole Faust         const UniformQuantizationInfo input2_qua_info = src2->info()->quantization_info().uniform();
244*c217d954SCole Faust 
245*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
246*c217d954SCole Faust         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
247*c217d954SCole Faust         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
248*c217d954SCole Faust 
249*c217d954SCole Faust         Iterator input1(src1, input1_win);
250*c217d954SCole Faust         Iterator input2(src2, input2_win);
251*c217d954SCole Faust         Iterator dst(out, win);
252*c217d954SCole Faust 
253*c217d954SCole Faust         execute_window_loop(
254*c217d954SCole Faust             win, [&](const Coordinates &)
255*c217d954SCole Faust         {
256*c217d954SCole Faust             const auto input1_ptr = reinterpret_cast<const T *>(input1.ptr());
257*c217d954SCole Faust             const auto input2_ptr = reinterpret_cast<const T *>(input2.ptr());
258*c217d954SCole Faust             const auto output_ptr = reinterpret_cast<T *>(dst.ptr());
259*c217d954SCole Faust 
260*c217d954SCole Faust             // Compute window_step_x elements per iteration
261*c217d954SCole Faust             int x = window_start_x;
262*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
263*c217d954SCole Faust             {
264*c217d954SCole Faust                 const auto input1_q = wrapper::vloadq(input1_ptr + x);
265*c217d954SCole Faust                 const auto input2_q = wrapper::vloadq(input2_ptr + x);
266*c217d954SCole Faust 
267*c217d954SCole Faust                 // Dequantize inputs
268*c217d954SCole Faust                 const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
269*c217d954SCole Faust                 const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
270*c217d954SCole Faust 
271*c217d954SCole Faust                 const float32x4x4_t out_f32x4x4 =
272*c217d954SCole Faust                 {
273*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
274*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
275*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
276*c217d954SCole Faust                     vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
277*c217d954SCole Faust                 };
278*c217d954SCole Faust 
279*c217d954SCole Faust                 // Quantize dst
280*c217d954SCole Faust                 const auto result = vquantize<T>(out_f32x4x4, tmp_qua_info);
281*c217d954SCole Faust                 wrapper::vstore(output_ptr + x, result);
282*c217d954SCole Faust             }
283*c217d954SCole Faust 
284*c217d954SCole Faust             // Compute left-over elements
285*c217d954SCole Faust             for(; x < window_end_x; ++x)
286*c217d954SCole Faust             {
287*c217d954SCole Faust                 // Dequantize inputs
288*c217d954SCole Faust                 const T     src1    = *(input1_ptr + x);
289*c217d954SCole Faust                 const T     src2    = *(input2_ptr + x);
290*c217d954SCole Faust                 const float tmp_in1 = Qasymm8QuantizationHelper<T>::dequantize(src1, input1_qua_info);
291*c217d954SCole Faust                 const float tmp_in2 = Qasymm8QuantizationHelper<T>::dequantize(src2, input2_qua_info);
292*c217d954SCole Faust                 const float tmp_f   = tmp_in1 * tmp_in2;
293*c217d954SCole Faust 
294*c217d954SCole Faust                 // Quantize dst
295*c217d954SCole Faust                 const auto tmp_qua = Qasymm8QuantizationHelper<T>::quantize(tmp_f, tmp_qua_info);
296*c217d954SCole Faust                 *(output_ptr + x)  = tmp_qua;
297*c217d954SCole Faust             }
298*c217d954SCole Faust         },
299*c217d954SCole Faust         input1, input2, dst);
300*c217d954SCole Faust     }
301*c217d954SCole Faust }
302*c217d954SCole Faust 
mul_q8_neon_fixedpoint_possible(const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst,float scale)303*c217d954SCole Faust bool mul_q8_neon_fixedpoint_possible(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, float scale)
304*c217d954SCole Faust {
305*c217d954SCole Faust     const auto iq0 = src0->quantization_info().uniform();
306*c217d954SCole Faust     const auto iq1 = src1->quantization_info().uniform();
307*c217d954SCole Faust     const auto oq  = dst->quantization_info().uniform();
308*c217d954SCole Faust 
309*c217d954SCole Faust     const auto multiplier = ((iq0.scale * iq1.scale) / oq.scale) * scale;
310*c217d954SCole Faust 
311*c217d954SCole Faust     if(multiplier < -8191.f || multiplier > 8191.f)
312*c217d954SCole Faust     {
313*c217d954SCole Faust         //The multiplier cannot be stored as a 14.18 signed fixed-point number
314*c217d954SCole Faust         return false;
315*c217d954SCole Faust     }
316*c217d954SCole Faust 
317*c217d954SCole Faust     const auto offset_out = float(oq.offset);
318*c217d954SCole Faust 
319*c217d954SCole Faust     const auto max_result = multiplier * (256) * (256) + offset_out;
320*c217d954SCole Faust 
321*c217d954SCole Faust     if(max_result > 8191.f)
322*c217d954SCole Faust     {
323*c217d954SCole Faust         //It might not be possible to store the result as a 14.18 signed fixed-point number.
324*c217d954SCole Faust         return false;
325*c217d954SCole Faust     }
326*c217d954SCole Faust 
327*c217d954SCole Faust     return true;
328*c217d954SCole Faust }
329*c217d954SCole Faust 
330*c217d954SCole Faust template <typename ScalarType>
mul_q8_neon_fixedpoint(const ITensor * src0,const ITensor * src1,ITensor * dst,const Window & window,float scale)331*c217d954SCole Faust void mul_q8_neon_fixedpoint(const ITensor *src0, const ITensor *src1, ITensor *dst, const Window &window, float scale)
332*c217d954SCole Faust {
333*c217d954SCole Faust     const auto in0_info = src0->info();
334*c217d954SCole Faust     const auto in1_info = src1->info();
335*c217d954SCole Faust 
336*c217d954SCole Faust     const auto &in0_shape = in0_info->tensor_shape();
337*c217d954SCole Faust     const auto &in1_shape = in1_info->tensor_shape();
338*c217d954SCole Faust 
339*c217d954SCole Faust     // Create input windows.
340*c217d954SCole Faust     Window in0_win = window.broadcast_if_dimension_le_one(in0_shape);
341*c217d954SCole Faust     Window in1_win = window.broadcast_if_dimension_le_one(in1_shape);
342*c217d954SCole Faust 
343*c217d954SCole Faust     // Clear the x dimension on the execution window as we process the whole row each iteration.
344*c217d954SCole Faust     Window win = window;
345*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
346*c217d954SCole Faust 
347*c217d954SCole Faust     constexpr int window_step_x         = 16;
348*c217d954SCole Faust     const auto    window_start_x        = window.x().start();
349*c217d954SCole Faust     const auto    window_end_x          = window.x().end();
350*c217d954SCole Faust     const auto    is_broadcast_across_x = in0_shape.x() != in1_shape.x();
351*c217d954SCole Faust 
352*c217d954SCole Faust     const auto iq0_info = in0_info->quantization_info().uniform();
353*c217d954SCole Faust     const auto iq1_info = in1_info->quantization_info().uniform();
354*c217d954SCole Faust     const auto oq_info  = dst->info()->quantization_info().uniform();
355*c217d954SCole Faust 
356*c217d954SCole Faust     const auto in0_offset = iq0_info.offset;
357*c217d954SCole Faust     const auto in1_offset = iq1_info.offset;
358*c217d954SCole Faust     const auto out_offset = oq_info.offset;
359*c217d954SCole Faust     const auto multiplier = ((iq0_info.scale * iq1_info.scale) / oq_info.scale) * scale;
360*c217d954SCole Faust 
361*c217d954SCole Faust     constexpr int32_t two_pwr18i = 262144;
362*c217d954SCole Faust     constexpr float   two_pwr18f = 262144.f;
363*c217d954SCole Faust 
364*c217d954SCole Faust     const auto in0_offset_16p0  = static_cast<int16_t>(in0_offset);
365*c217d954SCole Faust     const auto in1_offset_16p0  = static_cast<int16_t>(in1_offset);
366*c217d954SCole Faust     const auto out_offset_14p18 = static_cast<int32_t>(out_offset * two_pwr18i);
367*c217d954SCole Faust     const auto multiplier_14p18 = static_cast<int32_t>(multiplier * two_pwr18f);
368*c217d954SCole Faust 
369*c217d954SCole Faust     if(is_broadcast_across_x)
370*c217d954SCole Faust     {
371*c217d954SCole Faust         // Prefix: a = non-broadcast, b = broadcast.
372*c217d954SCole Faust 
373*c217d954SCole Faust         const auto is_broadcast_input_1 = in1_win.x().step() == 0;
374*c217d954SCole Faust         auto       a_win                = is_broadcast_input_1 ? in0_win : in1_win;
375*c217d954SCole Faust         auto       b_win                = is_broadcast_input_1 ? in1_win : in0_win;
376*c217d954SCole Faust         const auto a_tensor             = is_broadcast_input_1 ? src0 : src1;
377*c217d954SCole Faust         const auto b_tensor             = is_broadcast_input_1 ? src1 : src0;
378*c217d954SCole Faust 
379*c217d954SCole Faust         const auto a_offset_16p0 = is_broadcast_input_1 ? in0_offset_16p0 : in1_offset_16p0;
380*c217d954SCole Faust         const auto b_offset_16p0 = is_broadcast_input_1 ? in1_offset : in0_offset;
381*c217d954SCole Faust #ifndef __aarch64__
382*c217d954SCole Faust         const auto a_offset = is_broadcast_input_1 ? in0_offset : in1_offset;
383*c217d954SCole Faust         const auto b_offset = is_broadcast_input_1 ? in1_offset : in0_offset;
384*c217d954SCole Faust #endif //__aarch64__
385*c217d954SCole Faust         const auto a_voffset_16p0 = wrapper::vdup_n(a_offset_16p0, wrapper::traits::vector_64_tag());
386*c217d954SCole Faust 
387*c217d954SCole Faust         // Clear the x dimension on the execution window as we process the whole row each iteration.
388*c217d954SCole Faust         a_win.set(Window::DimX, Window::Dimension(0, 1, 1));
389*c217d954SCole Faust 
390*c217d954SCole Faust         Iterator a_input_it(a_tensor, a_win);
391*c217d954SCole Faust         Iterator b_input_it(b_tensor, b_win);
392*c217d954SCole Faust         Iterator out_it(dst, win);
393*c217d954SCole Faust 
394*c217d954SCole Faust         execute_window_loop(
395*c217d954SCole Faust             win, [&](const Coordinates &)
396*c217d954SCole Faust         {
397*c217d954SCole Faust             const auto a_ptr   = reinterpret_cast<const ScalarType *>(a_input_it.ptr());
398*c217d954SCole Faust             const auto b_ptr   = reinterpret_cast<const ScalarType *>(b_input_it.ptr());
399*c217d954SCole Faust             const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
400*c217d954SCole Faust 
401*c217d954SCole Faust             const auto b_val            = *b_ptr;
402*c217d954SCole Faust             const auto b_offseted_32p0  = static_cast<int32_t>(b_val - b_offset_16p0);
403*c217d954SCole Faust             const auto b_voffseted_32p0 = wrapper::vdup_n(b_offseted_32p0, wrapper::traits::vector_128_tag());
404*c217d954SCole Faust 
405*c217d954SCole Faust             const auto vmultiplier_14p18 = wrapper::vdup_n(multiplier_14p18, wrapper::traits::vector_128_tag());
406*c217d954SCole Faust             const auto voffsetout_14p18  = wrapper::vdup_n(out_offset_14p18, wrapper::traits::vector_128_tag());
407*c217d954SCole Faust 
408*c217d954SCole Faust             int x = window_start_x;
409*c217d954SCole Faust 
410*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
411*c217d954SCole Faust             {
412*c217d954SCole Faust                 // Load the inputs.
413*c217d954SCole Faust                 const auto a_vin_8p0 = wrapper::vloadq(a_ptr + x);
414*c217d954SCole Faust 
415*c217d954SCole Faust                 // Widen the non-broadcast elements to signed 16-bit regardless of the input signedness.
416*c217d954SCole Faust                 const auto a_vin_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(a_vin_8p0)));
417*c217d954SCole Faust                 const auto a_vin_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(a_vin_8p0)));
418*c217d954SCole Faust 
419*c217d954SCole Faust                 const auto voffseted_32p0_00 = wrapper::vsubl(wrapper::vgetlow(a_vin_16p0_0), a_voffset_16p0);
420*c217d954SCole Faust                 const auto voffseted_32p0_01 = wrapper::vsubl(wrapper::vgethigh(a_vin_16p0_0), a_voffset_16p0);
421*c217d954SCole Faust                 const auto voffseted_32p0_10 = wrapper::vsubl(wrapper::vgetlow(a_vin_16p0_1), a_voffset_16p0);
422*c217d954SCole Faust                 const auto voffseted_32p0_11 = wrapper::vsubl(wrapper::vgethigh(a_vin_16p0_1), a_voffset_16p0);
423*c217d954SCole Faust 
424*c217d954SCole Faust                 const auto vinnermul_32p0_00 = wrapper::vmul(voffseted_32p0_00, b_voffseted_32p0);
425*c217d954SCole Faust                 const auto vinnermul_32p0_01 = wrapper::vmul(voffseted_32p0_01, b_voffseted_32p0);
426*c217d954SCole Faust                 const auto vinnermul_32p0_10 = wrapper::vmul(voffseted_32p0_10, b_voffseted_32p0);
427*c217d954SCole Faust                 const auto vinnermul_32p0_11 = wrapper::vmul(voffseted_32p0_11, b_voffseted_32p0);
428*c217d954SCole Faust 
429*c217d954SCole Faust                 const auto vout_14p18_00 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_00, vmultiplier_14p18);
430*c217d954SCole Faust                 const auto vout_14p18_01 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_01, vmultiplier_14p18);
431*c217d954SCole Faust                 const auto vout_14p18_10 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_10, vmultiplier_14p18);
432*c217d954SCole Faust                 const auto vout_14p18_11 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_11, vmultiplier_14p18);
433*c217d954SCole Faust 
434*c217d954SCole Faust                 // These shift rights are to revert the multiplication by twopwr18. Hard limit of a maximum shift by 8 requires multiple shift instructions to achieve this.
435*c217d954SCole Faust                 const auto vout_15p1_00 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_00));
436*c217d954SCole Faust                 const auto vout_15p1_01 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_01));
437*c217d954SCole Faust                 const auto vout_15p1_10 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_10));
438*c217d954SCole Faust                 const auto vout_15p1_11 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_11));
439*c217d954SCole Faust 
440*c217d954SCole Faust                 const auto vout_15p1_0 = wrapper::vcombine(
441*c217d954SCole Faust                                              vout_15p1_00,
442*c217d954SCole Faust                                              vout_15p1_01);
443*c217d954SCole Faust 
444*c217d954SCole Faust                 const auto vout_15p1_1 = wrapper::vcombine(
445*c217d954SCole Faust                                              vout_15p1_10,
446*c217d954SCole Faust                                              vout_15p1_11);
447*c217d954SCole Faust                 const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
448*c217d954SCole Faust 
449*c217d954SCole Faust                 const auto vout_8p0 = wrapper::vcombine(
450*c217d954SCole Faust                                           wrapper::vqrshrn<2>(vout_15p1_0),
451*c217d954SCole Faust                                           wrapper::vqrshrn<2>(vout_15p1_1));
452*c217d954SCole Faust                 wrapper::vstore(out_ptr + x, vout_8p0);
453*c217d954SCole Faust             }
454*c217d954SCole Faust 
455*c217d954SCole Faust             //Process the left-over elements.
456*c217d954SCole Faust             for(; x < window_end_x; ++x)
457*c217d954SCole Faust             {
458*c217d954SCole Faust #ifdef __aarch64__
459*c217d954SCole Faust                 out_ptr[x] = wrapper::vqrshrn<2>(wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>((multiplier_14p18 * (int32_t(a_ptr[x]) - a_offset_16p0) * (int32_t(
460*c217d954SCole Faust                                                                                                              b_val) - b_offset_16p0)) + out_offset_14p18)));
461*c217d954SCole Faust #else  //__aarch64__
462*c217d954SCole Faust                 out_ptr[x] = utility::clamp<int32_t, ScalarType>(support::cpp11::lround(multiplier * ((float(a_ptr[x]) - a_offset) * (float(b_val) - b_offset)) + float(out_offset)));
463*c217d954SCole Faust #endif //__aarch64__
464*c217d954SCole Faust             }
465*c217d954SCole Faust         },
466*c217d954SCole Faust         a_input_it, b_input_it, out_it);
467*c217d954SCole Faust     }
468*c217d954SCole Faust     else
469*c217d954SCole Faust     {
470*c217d954SCole Faust         const auto voffset0_16p0     = wrapper::vdup_n(in0_offset_16p0, wrapper::traits::vector_64_tag());
471*c217d954SCole Faust         const auto voffset1_16p0     = wrapper::vdup_n(in1_offset_16p0, wrapper::traits::vector_64_tag());
472*c217d954SCole Faust         const auto voffsetout_14p18  = wrapper::vdup_n(out_offset_14p18, wrapper::traits::vector_128_tag());
473*c217d954SCole Faust         const auto vmultiplier_14p18 = wrapper::vdup_n(multiplier_14p18, wrapper::traits::vector_128_tag());
474*c217d954SCole Faust 
475*c217d954SCole Faust         // Clear the x dimension on the execution window as we process the whole row each iteration.
476*c217d954SCole Faust         in0_win.set(Window::DimX, Window::Dimension(0, 1, 1));
477*c217d954SCole Faust         in1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
478*c217d954SCole Faust 
479*c217d954SCole Faust         Iterator in0_it(src0, in0_win);
480*c217d954SCole Faust         Iterator in1_it(src1, in1_win);
481*c217d954SCole Faust         Iterator out_it(dst, win);
482*c217d954SCole Faust 
483*c217d954SCole Faust         execute_window_loop(
484*c217d954SCole Faust             win, [&](const Coordinates &)
485*c217d954SCole Faust         {
486*c217d954SCole Faust             const auto in0_ptr = reinterpret_cast<const ScalarType *>(in0_it.ptr());
487*c217d954SCole Faust             const auto in1_ptr = reinterpret_cast<const ScalarType *>(in1_it.ptr());
488*c217d954SCole Faust             const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr());
489*c217d954SCole Faust 
490*c217d954SCole Faust             int x = window_start_x;
491*c217d954SCole Faust 
492*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
493*c217d954SCole Faust             {
494*c217d954SCole Faust                 // Load the inputs.
495*c217d954SCole Faust                 const auto vin0_8p0 = wrapper::vloadq(in0_ptr + x);
496*c217d954SCole Faust                 const auto vin1_8p0 = wrapper::vloadq(in1_ptr + x);
497*c217d954SCole Faust 
498*c217d954SCole Faust                 // Widen the input elements to signed 16-bit regardless of the input signedness.
499*c217d954SCole Faust                 const auto vin0_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(vin0_8p0)));
500*c217d954SCole Faust                 const auto vin0_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(vin0_8p0)));
501*c217d954SCole Faust                 const auto vin1_16p0_0 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgetlow(vin1_8p0)));
502*c217d954SCole Faust                 const auto vin1_16p0_1 = wrapper::vreinterpret(wrapper::vmovl(wrapper::vgethigh(vin1_8p0)));
503*c217d954SCole Faust 
504*c217d954SCole Faust                 const auto voffseted0_32p0_00 = wrapper::vsubl(wrapper::vgetlow(vin0_16p0_0), voffset0_16p0);
505*c217d954SCole Faust                 const auto voffseted0_32p0_01 = wrapper::vsubl(wrapper::vgethigh(vin0_16p0_0), voffset0_16p0);
506*c217d954SCole Faust                 const auto voffseted0_32p0_10 = wrapper::vsubl(wrapper::vgetlow(vin0_16p0_1), voffset0_16p0);
507*c217d954SCole Faust                 const auto voffseted0_32p0_11 = wrapper::vsubl(wrapper::vgethigh(vin0_16p0_1), voffset0_16p0);
508*c217d954SCole Faust 
509*c217d954SCole Faust                 const auto voffseted1_32p0_00 = wrapper::vsubl(wrapper::vgetlow(vin1_16p0_0), voffset1_16p0);
510*c217d954SCole Faust                 const auto voffseted1_32p0_01 = wrapper::vsubl(wrapper::vgethigh(vin1_16p0_0), voffset1_16p0);
511*c217d954SCole Faust                 const auto voffseted1_32p0_10 = wrapper::vsubl(wrapper::vgetlow(vin1_16p0_1), voffset1_16p0);
512*c217d954SCole Faust                 const auto voffseted1_32p0_11 = wrapper::vsubl(wrapper::vgethigh(vin1_16p0_1), voffset1_16p0);
513*c217d954SCole Faust 
514*c217d954SCole Faust                 const auto vinnermul_32p0_00 = wrapper::vmul(voffseted0_32p0_00, voffseted1_32p0_00);
515*c217d954SCole Faust                 const auto vinnermul_32p0_01 = wrapper::vmul(voffseted0_32p0_01, voffseted1_32p0_01);
516*c217d954SCole Faust                 const auto vinnermul_32p0_10 = wrapper::vmul(voffseted0_32p0_10, voffseted1_32p0_10);
517*c217d954SCole Faust                 const auto vinnermul_32p0_11 = wrapper::vmul(voffseted0_32p0_11, voffseted1_32p0_11);
518*c217d954SCole Faust 
519*c217d954SCole Faust                 const auto vout_14p18_00 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_00, vmultiplier_14p18);
520*c217d954SCole Faust                 const auto vout_14p18_01 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_01, vmultiplier_14p18);
521*c217d954SCole Faust                 const auto vout_14p18_10 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_10, vmultiplier_14p18);
522*c217d954SCole Faust                 const auto vout_14p18_11 = wrapper::vmla(voffsetout_14p18, vinnermul_32p0_11, vmultiplier_14p18);
523*c217d954SCole Faust 
524*c217d954SCole Faust                 // These shift rights are to revert the multiplication by twopwr18. Hard limit of a maximum shift by 8 requires multiple shift instructions to achieve this.
525*c217d954SCole Faust                 const auto vout_14p2_00 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_00));
526*c217d954SCole Faust                 const auto vout_14p2_01 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_01));
527*c217d954SCole Faust                 const auto vout_14p2_10 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_10));
528*c217d954SCole Faust                 const auto vout_14p2_11 = wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>(vout_14p18_11));
529*c217d954SCole Faust 
530*c217d954SCole Faust                 const auto vout_14p2_0 = wrapper::vcombine(
531*c217d954SCole Faust                                              vout_14p2_00,
532*c217d954SCole Faust                                              vout_14p2_01);
533*c217d954SCole Faust 
534*c217d954SCole Faust                 const auto vout_14p2_1 = wrapper::vcombine(
535*c217d954SCole Faust                                              vout_14p2_10,
536*c217d954SCole Faust                                              vout_14p2_11);
537*c217d954SCole Faust 
538*c217d954SCole Faust                 const auto vout_8p0 = wrapper::vcombine(
539*c217d954SCole Faust                                           wrapper::vqrshrn<2>(vout_14p2_0),
540*c217d954SCole Faust                                           wrapper::vqrshrn<2>(vout_14p2_1));
541*c217d954SCole Faust                 wrapper::vstore(out_ptr + x, vout_8p0);
542*c217d954SCole Faust             }
543*c217d954SCole Faust 
544*c217d954SCole Faust             //Process the left-over elements.
545*c217d954SCole Faust             for(; x < window_end_x; ++x)
546*c217d954SCole Faust             {
547*c217d954SCole Faust #ifdef __aarch64__
548*c217d954SCole Faust                 out_ptr[x] = wrapper::vqrshrn<2>(wrapper::vqrshrn_ex<8, ScalarType>(wrapper::vshrq_n<8>((multiplier_14p18 * (int32_t(in0_ptr[x]) - in0_offset_16p0) * (int32_t(
549*c217d954SCole Faust                                                                                                              in1_ptr[x]) - in1_offset_16p0)) + out_offset_14p18)));
550*c217d954SCole Faust #else  //__aarch64__
551*c217d954SCole Faust                 out_ptr[x] = utility::clamp<int32_t, ScalarType>(support::cpp11::lround(multiplier * ((float(in0_ptr[x]) - in0_offset) * (float(in1_ptr[x]) - in1_offset)) + float(out_offset)));
552*c217d954SCole Faust #endif //__aarch64__
553*c217d954SCole Faust             }
554*c217d954SCole Faust         },
555*c217d954SCole Faust         in0_it, in1_it, out_it);
556*c217d954SCole Faust     }
557*c217d954SCole Faust }
558*c217d954SCole Faust 
mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,float scale)559*c217d954SCole Faust void mul_saturate_QSYMM16_QSYMM16_QSYMM16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
560*c217d954SCole Faust {
561*c217d954SCole Faust     const UniformQuantizationInfo input1_qua_info = src1->info()->quantization_info().uniform();
562*c217d954SCole Faust     const UniformQuantizationInfo input2_qua_info = src2->info()->quantization_info().uniform();
563*c217d954SCole Faust     const UniformQuantizationInfo output_qua_info = out->info()->quantization_info().uniform();
564*c217d954SCole Faust 
565*c217d954SCole Faust     // Create input windows
566*c217d954SCole Faust     Window win        = window;
567*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
568*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
569*c217d954SCole Faust 
570*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
571*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
572*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
573*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
574*c217d954SCole Faust 
575*c217d954SCole Faust     Iterator input1(src1, input1_win);
576*c217d954SCole Faust     Iterator input2(src2, input2_win);
577*c217d954SCole Faust     Iterator dst(out, win);
578*c217d954SCole Faust 
579*c217d954SCole Faust     const int  window_step_x  = 16;
580*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
581*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
582*c217d954SCole Faust 
583*c217d954SCole Faust     const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
584*c217d954SCole Faust 
585*c217d954SCole Faust     execute_window_loop(
586*c217d954SCole Faust         win, [&](const Coordinates &)
587*c217d954SCole Faust     {
588*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
589*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
590*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<qsymm16_t *>(dst.ptr());
591*c217d954SCole Faust 
592*c217d954SCole Faust         // Compute window_step_x elements per iteration
593*c217d954SCole Faust         int x = window_start_x;
594*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
595*c217d954SCole Faust         {
596*c217d954SCole Faust             const qsymm16x8x2_t input1_q =
597*c217d954SCole Faust             {
598*c217d954SCole Faust                 {
599*c217d954SCole Faust                     vld1q_s16(input1_ptr + x),
600*c217d954SCole Faust                     vld1q_s16(input1_ptr + x + 8),
601*c217d954SCole Faust                 }
602*c217d954SCole Faust             };
603*c217d954SCole Faust             const qsymm16x8x2_t input2_q =
604*c217d954SCole Faust             {
605*c217d954SCole Faust                 {
606*c217d954SCole Faust                     vld1q_s16(input2_ptr + x),
607*c217d954SCole Faust                     vld1q_s16(input2_ptr + x + 8),
608*c217d954SCole Faust                 }
609*c217d954SCole Faust             };
610*c217d954SCole Faust 
611*c217d954SCole Faust             // Dequantize inputs
612*c217d954SCole Faust             const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
613*c217d954SCole Faust             const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
614*c217d954SCole Faust 
615*c217d954SCole Faust             const float32x4x4_t out_f32x4x4 =
616*c217d954SCole Faust             {
617*c217d954SCole Faust                 vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
618*c217d954SCole Faust                 vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
619*c217d954SCole Faust                 vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
620*c217d954SCole Faust                 vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]),
621*c217d954SCole Faust             };
622*c217d954SCole Faust 
623*c217d954SCole Faust             const qsymm16x8x2_t result = vquantize_qsymm16(out_f32x4x4, tmp_qua_info);
624*c217d954SCole Faust             vst1q_s16(output_ptr + x, result.val[0]);
625*c217d954SCole Faust             vst1q_s16(output_ptr + x + 8, result.val[1]);
626*c217d954SCole Faust         }
627*c217d954SCole Faust 
628*c217d954SCole Faust         // Compute left-over elements
629*c217d954SCole Faust         for(; x < window_end_x; ++x)
630*c217d954SCole Faust         {
631*c217d954SCole Faust             // Dequantize inputs
632*c217d954SCole Faust             float tmp_in1 = static_cast<float>(*(input1_ptr + x)) * input1_qua_info.scale;
633*c217d954SCole Faust             float tmp_in2 = static_cast<float>(*(input2_ptr + x)) * input2_qua_info.scale;
634*c217d954SCole Faust             float tmp_f   = tmp_in1 * tmp_in2;
635*c217d954SCole Faust 
636*c217d954SCole Faust             // Quantize dst, lrintf() has same rounding mode as vcombine_s16
637*c217d954SCole Faust             int32_t   tmp     = lrintf(tmp_f / tmp_qua_info.scale);
638*c217d954SCole Faust             qsymm16_t tmp_qua = static_cast<qsymm16_t>(tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
639*c217d954SCole Faust             *(output_ptr + x) = tmp_qua;
640*c217d954SCole Faust         }
641*c217d954SCole Faust     },
642*c217d954SCole Faust     input1, input2, dst);
643*c217d954SCole Faust }
644*c217d954SCole Faust 
mul_QSYMM16_QSYMM16_S32(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int scale)645*c217d954SCole Faust void mul_QSYMM16_QSYMM16_S32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int scale)
646*c217d954SCole Faust {
647*c217d954SCole Faust     ARM_COMPUTE_UNUSED(scale);
648*c217d954SCole Faust 
649*c217d954SCole Faust     // Create input windows
650*c217d954SCole Faust     Window win        = window;
651*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
652*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
653*c217d954SCole Faust 
654*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
655*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
656*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
657*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
658*c217d954SCole Faust 
659*c217d954SCole Faust     Iterator input1(src1, input1_win);
660*c217d954SCole Faust     Iterator input2(src2, input2_win);
661*c217d954SCole Faust     Iterator dst(out, win);
662*c217d954SCole Faust 
663*c217d954SCole Faust     const int  window_step_x  = 16;
664*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
665*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
666*c217d954SCole Faust 
667*c217d954SCole Faust     execute_window_loop(
668*c217d954SCole Faust         win, [&](const Coordinates &)
669*c217d954SCole Faust     {
670*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const qsymm16_t *>(input1.ptr());
671*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const qsymm16_t *>(input2.ptr());
672*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<int32_t *>(dst.ptr());
673*c217d954SCole Faust 
674*c217d954SCole Faust         // Compute window_step_x elements per iteration
675*c217d954SCole Faust         int x = window_start_x;
676*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
677*c217d954SCole Faust         {
678*c217d954SCole Faust             const qsymm16x8x2_t input1_q =
679*c217d954SCole Faust             {
680*c217d954SCole Faust                 {
681*c217d954SCole Faust                     vld1q_s16(input1_ptr + x),
682*c217d954SCole Faust                     vld1q_s16(input1_ptr + x + 8),
683*c217d954SCole Faust                 }
684*c217d954SCole Faust             };
685*c217d954SCole Faust             const qsymm16x8x2_t input2_q =
686*c217d954SCole Faust             {
687*c217d954SCole Faust                 {
688*c217d954SCole Faust                     vld1q_s16(input2_ptr + x),
689*c217d954SCole Faust                     vld1q_s16(input2_ptr + x + 8),
690*c217d954SCole Faust                 }
691*c217d954SCole Faust             };
692*c217d954SCole Faust 
693*c217d954SCole Faust             const int32x4x4_t in1_s32 =
694*c217d954SCole Faust             {
695*c217d954SCole Faust                 {
696*c217d954SCole Faust                     vmovl_s16(vget_low_s16(input1_q.val[0])),
697*c217d954SCole Faust                     vmovl_s16(vget_high_s16(input1_q.val[0])),
698*c217d954SCole Faust                     vmovl_s16(vget_low_s16(input1_q.val[1])),
699*c217d954SCole Faust                     vmovl_s16(vget_high_s16(input1_q.val[1])),
700*c217d954SCole Faust                 }
701*c217d954SCole Faust             };
702*c217d954SCole Faust             const int32x4x4_t in2_s32 =
703*c217d954SCole Faust             {
704*c217d954SCole Faust                 {
705*c217d954SCole Faust                     vmovl_s16(vget_low_s16(input2_q.val[0])),
706*c217d954SCole Faust                     vmovl_s16(vget_high_s16(input2_q.val[0])),
707*c217d954SCole Faust                     vmovl_s16(vget_low_s16(input2_q.val[1])),
708*c217d954SCole Faust                     vmovl_s16(vget_high_s16(input2_q.val[1])),
709*c217d954SCole Faust                 }
710*c217d954SCole Faust             };
711*c217d954SCole Faust 
712*c217d954SCole Faust             const int32x4x4_t result =
713*c217d954SCole Faust             {
714*c217d954SCole Faust                 {
715*c217d954SCole Faust                     vmulq_s32(in1_s32.val[0], in2_s32.val[0]),
716*c217d954SCole Faust                     vmulq_s32(in1_s32.val[1], in2_s32.val[1]),
717*c217d954SCole Faust                     vmulq_s32(in1_s32.val[2], in2_s32.val[2]),
718*c217d954SCole Faust                     vmulq_s32(in1_s32.val[3], in2_s32.val[3]),
719*c217d954SCole Faust                 }
720*c217d954SCole Faust             };
721*c217d954SCole Faust 
722*c217d954SCole Faust             vst1q_s32(output_ptr + x, result.val[0]);
723*c217d954SCole Faust             vst1q_s32(output_ptr + x + 4, result.val[1]);
724*c217d954SCole Faust             vst1q_s32(output_ptr + x + 8, result.val[2]);
725*c217d954SCole Faust             vst1q_s32(output_ptr + x + 12, result.val[3]);
726*c217d954SCole Faust         }
727*c217d954SCole Faust 
728*c217d954SCole Faust         // Compute left-over elements
729*c217d954SCole Faust         for(; x < window_end_x; ++x)
730*c217d954SCole Faust         {
731*c217d954SCole Faust             int32_t tmp       = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
732*c217d954SCole Faust             *(output_ptr + x) = tmp;
733*c217d954SCole Faust         }
734*c217d954SCole Faust     },
735*c217d954SCole Faust     input1, input2, dst);
736*c217d954SCole Faust }
737*c217d954SCole Faust 
738*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_U8_U8_U8(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)739*c217d954SCole Faust void mul_U8_U8_U8(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
740*c217d954SCole Faust {
741*c217d954SCole Faust     // Create input windows
742*c217d954SCole Faust     Window win        = window;
743*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
744*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
745*c217d954SCole Faust 
746*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
747*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
748*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
749*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
750*c217d954SCole Faust 
751*c217d954SCole Faust     Iterator input1(src1, input1_win);
752*c217d954SCole Faust     Iterator input2(src2, input2_win);
753*c217d954SCole Faust     Iterator dst(out, win);
754*c217d954SCole Faust 
755*c217d954SCole Faust     const int  window_step_x  = 16 / sizeof(uint8_t);
756*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
757*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
758*c217d954SCole Faust 
759*c217d954SCole Faust     execute_window_loop(
760*c217d954SCole Faust         win, [&](const Coordinates &)
761*c217d954SCole Faust     {
762*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
763*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
764*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
765*c217d954SCole Faust 
766*c217d954SCole Faust         // Compute window_step_x elements per iteration
767*c217d954SCole Faust         int x = window_start_x;
768*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
769*c217d954SCole Faust         {
770*c217d954SCole Faust             const uint8x16_t ta1 = wrapper::vloadq(input1_ptr + x);
771*c217d954SCole Faust             const uint8x16_t ta2 = wrapper::vloadq(input2_ptr + x);
772*c217d954SCole Faust 
773*c217d954SCole Faust             uint16x8_t       tmp1_high = vmovl_u8(vget_high_u8(ta1));
774*c217d954SCole Faust             const uint16x8_t tmp2_high = vmovl_u8(vget_high_u8(ta2));
775*c217d954SCole Faust             uint16x8_t       tmp1_low  = vmovl_u8(vget_low_u8(ta1));
776*c217d954SCole Faust             const uint16x8_t tmp2_low  = vmovl_u8(vget_low_u8(ta2));
777*c217d954SCole Faust 
778*c217d954SCole Faust             tmp1_high = vmulq_u16(tmp1_high, tmp2_high);
779*c217d954SCole Faust             tmp1_low  = vmulq_u16(tmp1_low, tmp2_low);
780*c217d954SCole Faust 
781*c217d954SCole Faust             if(is_scale255)
782*c217d954SCole Faust             {
783*c217d954SCole Faust                 tmp1_high = scale255_U16_U16(tmp1_high);
784*c217d954SCole Faust                 tmp1_low  = scale255_U16_U16(tmp1_low);
785*c217d954SCole Faust             }
786*c217d954SCole Faust             else
787*c217d954SCole Faust             {
788*c217d954SCole Faust                 const int16x8_t vn = vdupq_n_s16(-n);
789*c217d954SCole Faust 
790*c217d954SCole Faust                 if(is_sat)
791*c217d954SCole Faust                 {
792*c217d954SCole Faust                     tmp1_high = vqshlq_u16(tmp1_high, vn);
793*c217d954SCole Faust                     tmp1_low  = vqshlq_u16(tmp1_low, vn);
794*c217d954SCole Faust                 }
795*c217d954SCole Faust                 else
796*c217d954SCole Faust                 {
797*c217d954SCole Faust                     tmp1_high = vshlq_u16(tmp1_high, vn);
798*c217d954SCole Faust                     tmp1_low  = vshlq_u16(tmp1_low, vn);
799*c217d954SCole Faust                 }
800*c217d954SCole Faust             }
801*c217d954SCole Faust             if(is_sat)
802*c217d954SCole Faust             {
803*c217d954SCole Faust                 vst1q_u8(output_ptr + x, vcombine_u8(vqmovn_u16(tmp1_low), vqmovn_u16(tmp1_high)));
804*c217d954SCole Faust             }
805*c217d954SCole Faust             else
806*c217d954SCole Faust             {
807*c217d954SCole Faust                 vst1q_u8(output_ptr + x, vcombine_u8(vmovn_u16(tmp1_low), vmovn_u16(tmp1_high)));
808*c217d954SCole Faust             }
809*c217d954SCole Faust         }
810*c217d954SCole Faust 
811*c217d954SCole Faust         // Compute left-over elements
812*c217d954SCole Faust         for(; x < window_end_x; ++x)
813*c217d954SCole Faust         {
814*c217d954SCole Faust             uint16_t tmp = static_cast<uint16_t>(*(input1_ptr + x)) * static_cast<uint16_t>(*(input2_ptr + x));
815*c217d954SCole Faust 
816*c217d954SCole Faust             if(is_scale255)
817*c217d954SCole Faust             {
818*c217d954SCole Faust                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
819*c217d954SCole Faust                 tmp         = static_cast<uint16_t>(tmp_f + 0.5f);
820*c217d954SCole Faust             }
821*c217d954SCole Faust             else
822*c217d954SCole Faust             {
823*c217d954SCole Faust                 tmp >>= n;
824*c217d954SCole Faust             }
825*c217d954SCole Faust             if(is_sat && tmp > 255)
826*c217d954SCole Faust             {
827*c217d954SCole Faust                 tmp = 255;
828*c217d954SCole Faust             }
829*c217d954SCole Faust             *(output_ptr + x) = static_cast<uint8_t>(tmp);
830*c217d954SCole Faust         }
831*c217d954SCole Faust     },
832*c217d954SCole Faust     input1, input2, dst);
833*c217d954SCole Faust }
834*c217d954SCole Faust 
835*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_S16_S16_S16_n_loop(const int16x8_t & src1,const int16x8_t & src2,int n)836*c217d954SCole Faust inline int16x8_t mul_S16_S16_S16_n_loop(const int16x8_t &src1, const int16x8_t &src2, int n)
837*c217d954SCole Faust {
838*c217d954SCole Faust     int32x4_t       tmp1_high = vmovl_s16(vget_high_s16(src1));
839*c217d954SCole Faust     const int32x4_t tmp2_high = vmovl_s16(vget_high_s16(src2));
840*c217d954SCole Faust     int32x4_t       tmp1_low  = vmovl_s16(vget_low_s16(src1));
841*c217d954SCole Faust     const int32x4_t tmp2_low  = vmovl_s16(vget_low_s16(src2));
842*c217d954SCole Faust 
843*c217d954SCole Faust     tmp1_high = vmulq_s32(tmp1_high, tmp2_high);
844*c217d954SCole Faust     tmp1_low  = vmulq_s32(tmp1_low, tmp2_low);
845*c217d954SCole Faust 
846*c217d954SCole Faust     if(is_scale255)
847*c217d954SCole Faust     {
848*c217d954SCole Faust         tmp1_high = scale255_S32_S32(tmp1_high);
849*c217d954SCole Faust         tmp1_low  = scale255_S32_S32(tmp1_low);
850*c217d954SCole Faust     }
851*c217d954SCole Faust     else
852*c217d954SCole Faust     {
853*c217d954SCole Faust         // Right shift amount
854*c217d954SCole Faust         const int32x4_t vn = vdupq_n_s32(-n);
855*c217d954SCole Faust         // Left shift amount
856*c217d954SCole Faust         const int32x4_t vnl = vdupq_n_s32(n);
857*c217d954SCole Faust         // Calculate conversion bit
858*c217d954SCole Faust         const uint32x4_t tmp1_high_u  = vreinterpretq_u32_s32(tmp1_high);
859*c217d954SCole Faust         const uint32x4_t tmp1_low_u   = vreinterpretq_u32_s32(tmp1_low);
860*c217d954SCole Faust         const uint32x4_t sign_high    = vshrq_n_u32(tmp1_high_u, 31);
861*c217d954SCole Faust         const uint32x4_t sign_low     = vshrq_n_u32(tmp1_low_u, 31);
862*c217d954SCole Faust         const int32x4_t  sign_high_s  = vreinterpretq_s32_u32(sign_high);
863*c217d954SCole Faust         const int32x4_t  sign_low_s   = vreinterpretq_s32_u32(sign_low);
864*c217d954SCole Faust         const int32x4_t  convert_high = vsubq_s32(vshlq_s32(sign_high_s, vnl), sign_high_s);
865*c217d954SCole Faust         const int32x4_t  convert_low  = vsubq_s32(vshlq_s32(sign_low_s, vnl), sign_low_s);
866*c217d954SCole Faust         if(is_sat)
867*c217d954SCole Faust         {
868*c217d954SCole Faust             tmp1_high = vqshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
869*c217d954SCole Faust             tmp1_low  = vqshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
870*c217d954SCole Faust         }
871*c217d954SCole Faust         else
872*c217d954SCole Faust         {
873*c217d954SCole Faust             tmp1_high = vshlq_s32(vaddq_s32(tmp1_high, convert_high), vn);
874*c217d954SCole Faust             tmp1_low  = vshlq_s32(vaddq_s32(tmp1_low, convert_low), vn);
875*c217d954SCole Faust         }
876*c217d954SCole Faust     }
877*c217d954SCole Faust 
878*c217d954SCole Faust     if(is_sat)
879*c217d954SCole Faust     {
880*c217d954SCole Faust         return vcombine_s16(vqmovn_s32(tmp1_low), vqmovn_s32(tmp1_high));
881*c217d954SCole Faust     }
882*c217d954SCole Faust     else
883*c217d954SCole Faust     {
884*c217d954SCole Faust         return vcombine_s16(vmovn_s32(tmp1_low), vmovn_s32(tmp1_high));
885*c217d954SCole Faust     }
886*c217d954SCole Faust }
887*c217d954SCole Faust 
888*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_S16_S16_S16_n_k(const int16x8x2_t & src1,const int16x8x2_t & src2,int n)889*c217d954SCole Faust inline int16x8x2_t mul_S16_S16_S16_n_k(const int16x8x2_t &src1, const int16x8x2_t &src2, int n)
890*c217d954SCole Faust {
891*c217d954SCole Faust     const int16x8x2_t result =
892*c217d954SCole Faust     {
893*c217d954SCole Faust         {
894*c217d954SCole Faust             // First 8 elements
895*c217d954SCole Faust             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(src1.val[0], src2.val[0], n),
896*c217d954SCole Faust             // Second 8 elements
897*c217d954SCole Faust             mul_S16_S16_S16_n_loop<is_scale255, is_sat>(src1.val[1], src2.val[1], n)
898*c217d954SCole Faust         }
899*c217d954SCole Faust     };
900*c217d954SCole Faust 
901*c217d954SCole Faust     return result;
902*c217d954SCole Faust }
903*c217d954SCole Faust 
904*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_S16_S16_S16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)905*c217d954SCole Faust void mul_S16_S16_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
906*c217d954SCole Faust {
907*c217d954SCole Faust     // Create input windows
908*c217d954SCole Faust     Window win        = window;
909*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
910*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
911*c217d954SCole Faust 
912*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
913*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
914*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
915*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
916*c217d954SCole Faust 
917*c217d954SCole Faust     Iterator input1(src1, input1_win);
918*c217d954SCole Faust     Iterator input2(src2, input2_win);
919*c217d954SCole Faust     Iterator dst(out, win);
920*c217d954SCole Faust 
921*c217d954SCole Faust     const int  window_step_x  = 16;
922*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
923*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
924*c217d954SCole Faust 
925*c217d954SCole Faust     execute_window_loop(
926*c217d954SCole Faust         win, [&](const Coordinates &)
927*c217d954SCole Faust     {
928*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
929*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const int16_t *>(input2.ptr());
930*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
931*c217d954SCole Faust 
932*c217d954SCole Faust         // Compute window_step_x elements per iteration
933*c217d954SCole Faust         int x = window_start_x;
934*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
935*c217d954SCole Faust         {
936*c217d954SCole Faust             const int16x8x2_t ta1 =
937*c217d954SCole Faust             {
938*c217d954SCole Faust                 {
939*c217d954SCole Faust                     vld1q_s16(input1_ptr + x),
940*c217d954SCole Faust                     vld1q_s16(input1_ptr + x + 8),
941*c217d954SCole Faust                 }
942*c217d954SCole Faust             };
943*c217d954SCole Faust             const int16x8x2_t ta2 =
944*c217d954SCole Faust             {
945*c217d954SCole Faust                 {
946*c217d954SCole Faust                     vld1q_s16(input2_ptr + x),
947*c217d954SCole Faust                     vld1q_s16(input2_ptr + x + 8),
948*c217d954SCole Faust                 }
949*c217d954SCole Faust             };
950*c217d954SCole Faust             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
951*c217d954SCole Faust 
952*c217d954SCole Faust             vst1q_s16(output_ptr + x, result.val[0]);
953*c217d954SCole Faust             vst1q_s16(output_ptr + x + 8, result.val[1]);
954*c217d954SCole Faust         }
955*c217d954SCole Faust 
956*c217d954SCole Faust         // Compute left-over elements
957*c217d954SCole Faust         for(; x < window_end_x; ++x)
958*c217d954SCole Faust         {
959*c217d954SCole Faust             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
960*c217d954SCole Faust 
961*c217d954SCole Faust             if(is_scale255)
962*c217d954SCole Faust             {
963*c217d954SCole Faust                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
964*c217d954SCole Faust 
965*c217d954SCole Faust                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
966*c217d954SCole Faust             }
967*c217d954SCole Faust             else
968*c217d954SCole Faust             {
969*c217d954SCole Faust                 if(tmp >= 0)
970*c217d954SCole Faust                 {
971*c217d954SCole Faust                     tmp >>= n;
972*c217d954SCole Faust                 }
973*c217d954SCole Faust                 else
974*c217d954SCole Faust                 {
975*c217d954SCole Faust                     uint32_t mask = (1u << n) - 1;
976*c217d954SCole Faust                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
977*c217d954SCole Faust                 }
978*c217d954SCole Faust             }
979*c217d954SCole Faust             if(is_sat)
980*c217d954SCole Faust             {
981*c217d954SCole Faust                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
982*c217d954SCole Faust             }
983*c217d954SCole Faust             *(output_ptr + x) = static_cast<int16_t>(tmp);
984*c217d954SCole Faust         }
985*c217d954SCole Faust     },
986*c217d954SCole Faust     input1, input2, dst);
987*c217d954SCole Faust }
988*c217d954SCole Faust 
989*c217d954SCole Faust template <bool is_sat>
mul_S32_S32_S32_n_loop(const int32x4_t & src1,const int32x4_t & src2,int n)990*c217d954SCole Faust inline int32x4_t mul_S32_S32_S32_n_loop(const int32x4_t &src1, const int32x4_t &src2, int n)
991*c217d954SCole Faust {
992*c217d954SCole Faust     const int32x2_t input1_1 = vget_low_s32(src1);
993*c217d954SCole Faust     const int32x2_t input2_1 = vget_low_s32(src2);
994*c217d954SCole Faust     const int32x2_t input1_2 = vget_high_s32(src1);
995*c217d954SCole Faust     const int32x2_t input2_2 = vget_high_s32(src2);
996*c217d954SCole Faust 
997*c217d954SCole Faust     int64x2_t tmp_1 = vmull_s32(input1_1, input2_1);
998*c217d954SCole Faust     int64x2_t tmp_2 = vmull_s32(input1_2, input2_2);
999*c217d954SCole Faust 
1000*c217d954SCole Faust     // Apply scaling, conversion and rounding (round to zero)
1001*c217d954SCole Faust     // Right shift amount
1002*c217d954SCole Faust     const int64x2_t vn = vdupq_n_s64(-n);
1003*c217d954SCole Faust     // Left shift amount
1004*c217d954SCole Faust     const int64x2_t vnl = vdupq_n_s64(n);
1005*c217d954SCole Faust     // Calculate conversion bit
1006*c217d954SCole Faust     const uint64x2_t tmp_1_u   = vreinterpretq_u64_s64(tmp_1);
1007*c217d954SCole Faust     const uint64x2_t sign_1    = vshrq_n_u64(tmp_1_u, 63);
1008*c217d954SCole Faust     const int64x2_t  sign_1_s  = vreinterpretq_s64_u64(sign_1);
1009*c217d954SCole Faust     const int64x2_t  convert_1 = vsubq_s64(vshlq_s64(sign_1_s, vnl), sign_1_s);
1010*c217d954SCole Faust 
1011*c217d954SCole Faust     const uint64x2_t tmp_2_u   = vreinterpretq_u64_s64(tmp_2);
1012*c217d954SCole Faust     const uint64x2_t sign_2    = vshrq_n_u64(tmp_2_u, 63);
1013*c217d954SCole Faust     const int64x2_t  sign_2_s  = vreinterpretq_s64_u64(sign_2);
1014*c217d954SCole Faust     const int64x2_t  convert_2 = vsubq_s64(vshlq_s64(sign_2_s, vnl), sign_2_s);
1015*c217d954SCole Faust     if(is_sat)
1016*c217d954SCole Faust     {
1017*c217d954SCole Faust         tmp_1 = vqshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
1018*c217d954SCole Faust         tmp_2 = vqshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
1019*c217d954SCole Faust         return vcombine_s32(vqmovn_s64(tmp_1), vqmovn_s64(tmp_2));
1020*c217d954SCole Faust     }
1021*c217d954SCole Faust     else
1022*c217d954SCole Faust     {
1023*c217d954SCole Faust         tmp_1 = vshlq_s64(vaddq_s64(tmp_1, convert_1), vn);
1024*c217d954SCole Faust         tmp_2 = vshlq_s64(vaddq_s64(tmp_2, convert_2), vn);
1025*c217d954SCole Faust         return vcombine_s32(vmovn_s64(tmp_1), vmovn_s64(tmp_2));
1026*c217d954SCole Faust     }
1027*c217d954SCole Faust }
1028*c217d954SCole Faust 
1029*c217d954SCole Faust template <bool is_sat>
mul_S32_S32_S32_n_k(const int32x4x2_t & src1,const int32x4x2_t & src2,int n)1030*c217d954SCole Faust inline int32x4x2_t mul_S32_S32_S32_n_k(const int32x4x2_t &src1, const int32x4x2_t &src2, int n)
1031*c217d954SCole Faust {
1032*c217d954SCole Faust     const int32x4x2_t result =
1033*c217d954SCole Faust     {
1034*c217d954SCole Faust         {
1035*c217d954SCole Faust             // First 4 elements
1036*c217d954SCole Faust             mul_S32_S32_S32_n_loop<is_sat>(src1.val[0], src2.val[0], n),
1037*c217d954SCole Faust             // Second 4 elements
1038*c217d954SCole Faust             mul_S32_S32_S32_n_loop<is_sat>(src1.val[1], src2.val[1], n)
1039*c217d954SCole Faust         }
1040*c217d954SCole Faust     };
1041*c217d954SCole Faust 
1042*c217d954SCole Faust     return result;
1043*c217d954SCole Faust }
1044*c217d954SCole Faust 
1045*c217d954SCole Faust template <bool is_sat>
mul_S32_S32_S32(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)1046*c217d954SCole Faust void mul_S32_S32_S32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1047*c217d954SCole Faust {
1048*c217d954SCole Faust     // Create input windows
1049*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1050*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1051*c217d954SCole Faust 
1052*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1053*c217d954SCole Faust     Window win = window;
1054*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1055*c217d954SCole Faust 
1056*c217d954SCole Faust     const int  window_step_x         = 8;
1057*c217d954SCole Faust     const auto window_start_x        = static_cast<int>(window.x().start());
1058*c217d954SCole Faust     const auto window_end_x          = static_cast<int>(window.x().end());
1059*c217d954SCole Faust     const bool is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1060*c217d954SCole Faust 
1061*c217d954SCole Faust     if(is_broadcast_across_x)
1062*c217d954SCole Faust     {
1063*c217d954SCole Faust         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1064*c217d954SCole Faust         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1065*c217d954SCole Faust         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1066*c217d954SCole Faust         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1067*c217d954SCole Faust         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1068*c217d954SCole Faust 
1069*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1070*c217d954SCole Faust         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1071*c217d954SCole Faust 
1072*c217d954SCole Faust         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1073*c217d954SCole Faust         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1074*c217d954SCole Faust         Iterator dst(out, win);
1075*c217d954SCole Faust 
1076*c217d954SCole Faust         execute_window_loop(
1077*c217d954SCole Faust             win, [&](const Coordinates &)
1078*c217d954SCole Faust         {
1079*c217d954SCole Faust             const auto non_broadcast_input_ptr = reinterpret_cast<const int32_t *>(non_broadcast_input.ptr());
1080*c217d954SCole Faust             const auto output_ptr              = reinterpret_cast<int32_t *>(dst.ptr());
1081*c217d954SCole Faust 
1082*c217d954SCole Faust             const int32_t broadcast_value     = *reinterpret_cast<const int32_t *>(broadcast_input.ptr());
1083*c217d954SCole Faust             const auto    broadcast_value_vec = vdupq_n_s32(broadcast_value);
1084*c217d954SCole Faust 
1085*c217d954SCole Faust             // Compute window_step_x elements per iteration
1086*c217d954SCole Faust             int x = window_start_x;
1087*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1088*c217d954SCole Faust             {
1089*c217d954SCole Faust                 const int32x4x2_t broadcast_v =
1090*c217d954SCole Faust                 {
1091*c217d954SCole Faust                     {
1092*c217d954SCole Faust                         broadcast_value_vec,
1093*c217d954SCole Faust                         broadcast_value_vec,
1094*c217d954SCole Faust                     }
1095*c217d954SCole Faust                 };
1096*c217d954SCole Faust                 const int32x4x2_t non_broadcast_v =
1097*c217d954SCole Faust                 {
1098*c217d954SCole Faust                     {
1099*c217d954SCole Faust                         vld1q_s32(non_broadcast_input_ptr + x),
1100*c217d954SCole Faust                         vld1q_s32(non_broadcast_input_ptr + x + 4),
1101*c217d954SCole Faust                     }
1102*c217d954SCole Faust                 };
1103*c217d954SCole Faust                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(broadcast_v, non_broadcast_v, n);
1104*c217d954SCole Faust 
1105*c217d954SCole Faust                 vst1q_s32(output_ptr + x, result.val[0]);
1106*c217d954SCole Faust                 vst1q_s32(output_ptr + x + 4, result.val[1]);
1107*c217d954SCole Faust             }
1108*c217d954SCole Faust 
1109*c217d954SCole Faust             // Compute left-over elements
1110*c217d954SCole Faust             for(; x < window_end_x; ++x)
1111*c217d954SCole Faust             {
1112*c217d954SCole Faust                 int64_t tmp = static_cast<int64_t>(broadcast_value) * static_cast<int64_t>(*(non_broadcast_input_ptr + x));
1113*c217d954SCole Faust 
1114*c217d954SCole Faust                 if(tmp >= 0)
1115*c217d954SCole Faust                 {
1116*c217d954SCole Faust                     tmp >>= n;
1117*c217d954SCole Faust                 }
1118*c217d954SCole Faust                 else
1119*c217d954SCole Faust                 {
1120*c217d954SCole Faust                     uint64_t mask = ((uint64_t)1u << n) - 1;
1121*c217d954SCole Faust                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
1122*c217d954SCole Faust                 }
1123*c217d954SCole Faust                 if(is_sat)
1124*c217d954SCole Faust                 {
1125*c217d954SCole Faust                     tmp = utility::clamp<int64_t, int32_t>(tmp);
1126*c217d954SCole Faust                 }
1127*c217d954SCole Faust                 *(output_ptr + x) = static_cast<int32_t>(tmp);
1128*c217d954SCole Faust             }
1129*c217d954SCole Faust         },
1130*c217d954SCole Faust         broadcast_input, non_broadcast_input, dst);
1131*c217d954SCole Faust     }
1132*c217d954SCole Faust     else
1133*c217d954SCole Faust     {
1134*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1135*c217d954SCole Faust         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1136*c217d954SCole Faust         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1137*c217d954SCole Faust 
1138*c217d954SCole Faust         Iterator input1(src1, input1_win);
1139*c217d954SCole Faust         Iterator input2(src2, input2_win);
1140*c217d954SCole Faust         Iterator dst(out, win);
1141*c217d954SCole Faust 
1142*c217d954SCole Faust         execute_window_loop(
1143*c217d954SCole Faust             win, [&](const Coordinates &)
1144*c217d954SCole Faust         {
1145*c217d954SCole Faust             const auto input1_ptr = reinterpret_cast<const int32_t *>(input1.ptr());
1146*c217d954SCole Faust             const auto input2_ptr = reinterpret_cast<const int32_t *>(input2.ptr());
1147*c217d954SCole Faust             const auto output_ptr = reinterpret_cast<int32_t *>(dst.ptr());
1148*c217d954SCole Faust 
1149*c217d954SCole Faust             // Compute window_step_x elements per iteration
1150*c217d954SCole Faust             int x = window_start_x;
1151*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1152*c217d954SCole Faust             {
1153*c217d954SCole Faust                 const int32x4x2_t ta1 =
1154*c217d954SCole Faust                 {
1155*c217d954SCole Faust                     {
1156*c217d954SCole Faust                         vld1q_s32(input1_ptr + x),
1157*c217d954SCole Faust                         vld1q_s32(input1_ptr + x + 4),
1158*c217d954SCole Faust                     }
1159*c217d954SCole Faust                 };
1160*c217d954SCole Faust                 const int32x4x2_t ta2 =
1161*c217d954SCole Faust                 {
1162*c217d954SCole Faust                     {
1163*c217d954SCole Faust                         vld1q_s32(input2_ptr + x),
1164*c217d954SCole Faust                         vld1q_s32(input2_ptr + x + 4),
1165*c217d954SCole Faust                     }
1166*c217d954SCole Faust                 };
1167*c217d954SCole Faust                 const int32x4x2_t result = mul_S32_S32_S32_n_k<is_sat>(ta1, ta2, n);
1168*c217d954SCole Faust 
1169*c217d954SCole Faust                 vst1q_s32(output_ptr + x, result.val[0]);
1170*c217d954SCole Faust                 vst1q_s32(output_ptr + x + 4, result.val[1]);
1171*c217d954SCole Faust             }
1172*c217d954SCole Faust 
1173*c217d954SCole Faust             // Compute left-over elements
1174*c217d954SCole Faust             for(; x < window_end_x; ++x)
1175*c217d954SCole Faust             {
1176*c217d954SCole Faust                 int64_t tmp = static_cast<int64_t>(*(input1_ptr + x)) * static_cast<int64_t>(*(input2_ptr + x));
1177*c217d954SCole Faust 
1178*c217d954SCole Faust                 if(tmp >= 0)
1179*c217d954SCole Faust                 {
1180*c217d954SCole Faust                     tmp >>= n;
1181*c217d954SCole Faust                 }
1182*c217d954SCole Faust                 else
1183*c217d954SCole Faust                 {
1184*c217d954SCole Faust                     uint64_t mask = ((uint64_t)1u << n) - 1;
1185*c217d954SCole Faust                     tmp           = (tmp + static_cast<int64_t>(mask)) >> n;
1186*c217d954SCole Faust                 }
1187*c217d954SCole Faust                 if(is_sat)
1188*c217d954SCole Faust                 {
1189*c217d954SCole Faust                     tmp = utility::clamp<int64_t, int32_t>(tmp);
1190*c217d954SCole Faust                 }
1191*c217d954SCole Faust                 *(output_ptr + x) = static_cast<int32_t>(tmp);
1192*c217d954SCole Faust             }
1193*c217d954SCole Faust         },
1194*c217d954SCole Faust         input1, input2, dst);
1195*c217d954SCole Faust     }
1196*c217d954SCole Faust }
1197*c217d954SCole Faust 
mul_F32_F32_F32(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,float scale)1198*c217d954SCole Faust void mul_F32_F32_F32(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
1199*c217d954SCole Faust {
1200*c217d954SCole Faust     // Create input windows
1201*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1202*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1203*c217d954SCole Faust 
1204*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1205*c217d954SCole Faust     Window win = window;
1206*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1207*c217d954SCole Faust 
1208*c217d954SCole Faust     constexpr int window_step_x         = 16 / sizeof(float);
1209*c217d954SCole Faust     const auto    window_start_x        = static_cast<int>(window.x().start());
1210*c217d954SCole Faust     const auto    window_end_x          = static_cast<int>(window.x().end());
1211*c217d954SCole Faust     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1212*c217d954SCole Faust 
1213*c217d954SCole Faust     using ExactTagType = typename wrapper::traits::neon_vector<float, window_step_x>::tag_type;
1214*c217d954SCole Faust 
1215*c217d954SCole Faust     if(is_broadcast_across_x)
1216*c217d954SCole Faust     {
1217*c217d954SCole Faust         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1218*c217d954SCole Faust         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1219*c217d954SCole Faust         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1220*c217d954SCole Faust         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1221*c217d954SCole Faust         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1222*c217d954SCole Faust 
1223*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1224*c217d954SCole Faust         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1225*c217d954SCole Faust 
1226*c217d954SCole Faust         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1227*c217d954SCole Faust         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1228*c217d954SCole Faust         Iterator dst(out, win);
1229*c217d954SCole Faust 
1230*c217d954SCole Faust         execute_window_loop(
1231*c217d954SCole Faust             win, [&](const Coordinates &)
1232*c217d954SCole Faust         {
1233*c217d954SCole Faust             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
1234*c217d954SCole Faust             const auto output_ptr              = reinterpret_cast<float *>(dst.ptr());
1235*c217d954SCole Faust 
1236*c217d954SCole Faust             const float broadcast_value     = *reinterpret_cast<const float *>(broadcast_input.ptr());
1237*c217d954SCole Faust             const auto  broadcast_value_vec = wrapper::vdup_n(broadcast_value, ExactTagType{});
1238*c217d954SCole Faust             const auto  scale_vec           = wrapper::vdup_n(scale, ExactTagType{});
1239*c217d954SCole Faust 
1240*c217d954SCole Faust             // Compute window_step_x elements per iteration
1241*c217d954SCole Faust             int x = window_start_x;
1242*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1243*c217d954SCole Faust             {
1244*c217d954SCole Faust                 const auto non_broadcast_v = wrapper::vloadq(non_broadcast_input_ptr + x);
1245*c217d954SCole Faust                 auto       res             = wrapper::vmul(wrapper::vmul(broadcast_value_vec, non_broadcast_v), scale_vec);
1246*c217d954SCole Faust                 wrapper::vstore(output_ptr + x, res);
1247*c217d954SCole Faust             }
1248*c217d954SCole Faust 
1249*c217d954SCole Faust             // Compute left-over elements
1250*c217d954SCole Faust             for(; x < window_end_x; ++x)
1251*c217d954SCole Faust             {
1252*c217d954SCole Faust                 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
1253*c217d954SCole Faust                 *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
1254*c217d954SCole Faust             }
1255*c217d954SCole Faust         },
1256*c217d954SCole Faust         broadcast_input, non_broadcast_input, dst);
1257*c217d954SCole Faust     }
1258*c217d954SCole Faust     else
1259*c217d954SCole Faust     {
1260*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1261*c217d954SCole Faust         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1262*c217d954SCole Faust         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1263*c217d954SCole Faust 
1264*c217d954SCole Faust         Iterator input1(src1, input1_win);
1265*c217d954SCole Faust         Iterator input2(src2, input2_win);
1266*c217d954SCole Faust         Iterator dst(out, win);
1267*c217d954SCole Faust 
1268*c217d954SCole Faust         execute_window_loop(
1269*c217d954SCole Faust             win, [&](const Coordinates &)
1270*c217d954SCole Faust         {
1271*c217d954SCole Faust             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
1272*c217d954SCole Faust             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
1273*c217d954SCole Faust             const auto output_ptr = reinterpret_cast<float *>(dst.ptr());
1274*c217d954SCole Faust 
1275*c217d954SCole Faust             // Compute window_step_x elements per iteration
1276*c217d954SCole Faust             int x = window_start_x;
1277*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1278*c217d954SCole Faust             {
1279*c217d954SCole Faust                 const auto ta1       = wrapper::vloadq(input1_ptr + x);
1280*c217d954SCole Faust                 const auto ta2       = wrapper::vloadq(input2_ptr + x);
1281*c217d954SCole Faust                 const auto scale_vec = wrapper::vdup_n(scale, ExactTagType{});
1282*c217d954SCole Faust                 const auto res       = wrapper::vmul(wrapper::vmul(ta1, ta2), scale_vec);
1283*c217d954SCole Faust                 wrapper::vstore(output_ptr + x, res);
1284*c217d954SCole Faust             }
1285*c217d954SCole Faust 
1286*c217d954SCole Faust             // Compute left-over elements
1287*c217d954SCole Faust             for(; x < window_end_x; ++x)
1288*c217d954SCole Faust             {
1289*c217d954SCole Faust                 const auto ta1    = *(input1_ptr + x);
1290*c217d954SCole Faust                 const auto ta2    = *(input2_ptr + x);
1291*c217d954SCole Faust                 *(output_ptr + x) = ta1 * ta2 * scale;
1292*c217d954SCole Faust             }
1293*c217d954SCole Faust         },
1294*c217d954SCole Faust         input1, input2, dst);
1295*c217d954SCole Faust     }
1296*c217d954SCole Faust }
1297*c217d954SCole Faust 
c_mul_F32_F32_F32_n(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window)1298*c217d954SCole Faust void c_mul_F32_F32_F32_n(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window)
1299*c217d954SCole Faust {
1300*c217d954SCole Faust     // Create input windows
1301*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1302*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1303*c217d954SCole Faust 
1304*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1305*c217d954SCole Faust     Window win = window;
1306*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1307*c217d954SCole Faust 
1308*c217d954SCole Faust     constexpr int window_step_x         = 8 / sizeof(float);
1309*c217d954SCole Faust     const auto    window_start_x        = static_cast<int>(window.x().start());
1310*c217d954SCole Faust     const auto    window_end_x          = static_cast<int>(window.x().end());
1311*c217d954SCole Faust     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1312*c217d954SCole Faust 
1313*c217d954SCole Faust     using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
1314*c217d954SCole Faust 
1315*c217d954SCole Faust     if(is_broadcast_across_x)
1316*c217d954SCole Faust     {
1317*c217d954SCole Faust         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1318*c217d954SCole Faust         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1319*c217d954SCole Faust         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1320*c217d954SCole Faust         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1321*c217d954SCole Faust         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1322*c217d954SCole Faust 
1323*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1324*c217d954SCole Faust         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1325*c217d954SCole Faust 
1326*c217d954SCole Faust         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1327*c217d954SCole Faust         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1328*c217d954SCole Faust         Iterator dst(out, win);
1329*c217d954SCole Faust 
1330*c217d954SCole Faust         execute_window_loop(
1331*c217d954SCole Faust             win, [&](const Coordinates &)
1332*c217d954SCole Faust         {
1333*c217d954SCole Faust             const auto non_broadcast_input_ptr = reinterpret_cast<const float *>(non_broadcast_input.ptr());
1334*c217d954SCole Faust             const auto output_ptr              = reinterpret_cast<float *>(dst.ptr());
1335*c217d954SCole Faust 
1336*c217d954SCole Faust             const float broadcast_value = *reinterpret_cast<const float *>(broadcast_input.ptr());
1337*c217d954SCole Faust 
1338*c217d954SCole Faust             // Compute window_step_x elements per iteration
1339*c217d954SCole Faust             int x = window_start_x;
1340*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1341*c217d954SCole Faust             {
1342*c217d954SCole Faust                 const auto  a = wrapper::vloadq(non_broadcast_input_ptr + 2 * x);
1343*c217d954SCole Faust                 float32x4_t b = vdupq_n_f32(broadcast_value);
1344*c217d954SCole Faust 
1345*c217d954SCole Faust                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1346*c217d954SCole Faust                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1347*c217d954SCole Faust                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1348*c217d954SCole Faust                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1349*c217d954SCole Faust                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1350*c217d954SCole Faust 
1351*c217d954SCole Faust                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1352*c217d954SCole Faust                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1353*c217d954SCole Faust 
1354*c217d954SCole Faust                 float32x4_t res = wrapper::vmul(tmp0, b);
1355*c217d954SCole Faust                 b               = wrapper::vmul(b, mask);
1356*c217d954SCole Faust 
1357*c217d954SCole Faust                 res = wrapper::vmla(res, tmp1, b);
1358*c217d954SCole Faust                 wrapper::vstore(output_ptr + 2 * x, res);
1359*c217d954SCole Faust             }
1360*c217d954SCole Faust 
1361*c217d954SCole Faust             // Compute left-over elements
1362*c217d954SCole Faust             for(; x < window_end_x; ++x)
1363*c217d954SCole Faust             {
1364*c217d954SCole Faust                 const auto non_broadcast_value0 = *(non_broadcast_input_ptr + 2 * x);
1365*c217d954SCole Faust                 const auto non_broadcast_value1 = *(non_broadcast_input_ptr + 2 * x + 1);
1366*c217d954SCole Faust                 auto       res1                 = broadcast_value * (non_broadcast_value0 - non_broadcast_value1);
1367*c217d954SCole Faust                 auto       res2                 = broadcast_value * (non_broadcast_value1 + non_broadcast_value0);
1368*c217d954SCole Faust                 *(output_ptr + 2 * x)           = res1;
1369*c217d954SCole Faust                 *(output_ptr + 2 * x + 1)       = res2;
1370*c217d954SCole Faust             }
1371*c217d954SCole Faust         },
1372*c217d954SCole Faust         broadcast_input, non_broadcast_input, dst);
1373*c217d954SCole Faust     }
1374*c217d954SCole Faust     else
1375*c217d954SCole Faust     {
1376*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1377*c217d954SCole Faust         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1378*c217d954SCole Faust         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1379*c217d954SCole Faust 
1380*c217d954SCole Faust         Iterator input1(src1, input1_win);
1381*c217d954SCole Faust         Iterator input2(src2, input2_win);
1382*c217d954SCole Faust         Iterator dst(out, win);
1383*c217d954SCole Faust 
1384*c217d954SCole Faust         execute_window_loop(
1385*c217d954SCole Faust             win, [&](const Coordinates &)
1386*c217d954SCole Faust         {
1387*c217d954SCole Faust             const auto input1_ptr = reinterpret_cast<const float *>(input1.ptr());
1388*c217d954SCole Faust             const auto input2_ptr = reinterpret_cast<const float *>(input2.ptr());
1389*c217d954SCole Faust             const auto output_ptr = reinterpret_cast<float *>(dst.ptr());
1390*c217d954SCole Faust 
1391*c217d954SCole Faust             // Compute window_step_x elements per iteration
1392*c217d954SCole Faust             int x = window_start_x;
1393*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1394*c217d954SCole Faust             {
1395*c217d954SCole Faust                 const float32x4_t a = wrapper::vloadq(input1_ptr + 2 * x);
1396*c217d954SCole Faust                 float32x4_t       b = wrapper::vloadq(input2_ptr + 2 * x);
1397*c217d954SCole Faust 
1398*c217d954SCole Faust                 const float32x4_t mask  = { -1.0f, 1.0f, -1.0f, 1.0f };
1399*c217d954SCole Faust                 const float32x2_t tmp00 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
1400*c217d954SCole Faust                 const float32x2_t tmp01 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
1401*c217d954SCole Faust                 const float32x2_t tmp10 = wrapper::vdup_n(wrapper::vgetlane(a, 2), ExactTagType{});
1402*c217d954SCole Faust                 const float32x2_t tmp11 = wrapper::vdup_n(wrapper::vgetlane(a, 3), ExactTagType{});
1403*c217d954SCole Faust 
1404*c217d954SCole Faust                 const float32x4_t tmp0 = wrapper::vcombine(tmp00, tmp10);
1405*c217d954SCole Faust                 const float32x4_t tmp1 = wrapper::vcombine(tmp01, tmp11);
1406*c217d954SCole Faust 
1407*c217d954SCole Faust                 float32x4_t res = wrapper::vmul(tmp0, b);
1408*c217d954SCole Faust 
1409*c217d954SCole Faust                 b = wrapper::vrev64(b);
1410*c217d954SCole Faust                 b = wrapper::vmul(b, mask);
1411*c217d954SCole Faust 
1412*c217d954SCole Faust                 res = wrapper::vmla(res, tmp1, b);
1413*c217d954SCole Faust                 wrapper::vstore(output_ptr + 2 * x, res);
1414*c217d954SCole Faust             }
1415*c217d954SCole Faust 
1416*c217d954SCole Faust             // Compute left-over elements
1417*c217d954SCole Faust             for(; x < window_end_x; ++x)
1418*c217d954SCole Faust             {
1419*c217d954SCole Faust                 const auto a0             = *(input1_ptr + 2 * x);
1420*c217d954SCole Faust                 const auto a1             = *(input1_ptr + 2 * x + 1);
1421*c217d954SCole Faust                 const auto b0             = *(input2_ptr + 2 * x);
1422*c217d954SCole Faust                 const auto b1             = *(input2_ptr + 2 * x + 1);
1423*c217d954SCole Faust                 auto       res1           = a0 * b0 - a1 * b1;
1424*c217d954SCole Faust                 auto       res2           = a0 * b1 + a1 * b0;
1425*c217d954SCole Faust                 *(output_ptr + 2 * x)     = res1;
1426*c217d954SCole Faust                 *(output_ptr + 2 * x + 1) = res2;
1427*c217d954SCole Faust             }
1428*c217d954SCole Faust         },
1429*c217d954SCole Faust         input1, input2, dst);
1430*c217d954SCole Faust     }
1431*c217d954SCole Faust }
1432*c217d954SCole Faust 
1433*c217d954SCole Faust #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
mul_F16_F16_F16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,float scale)1434*c217d954SCole Faust void mul_F16_F16_F16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, float scale)
1435*c217d954SCole Faust {
1436*c217d954SCole Faust     // Create input windows
1437*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1438*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1439*c217d954SCole Faust 
1440*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1441*c217d954SCole Faust     Window win = window;
1442*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1443*c217d954SCole Faust     constexpr int window_step_x         = 16;
1444*c217d954SCole Faust     const auto    window_start_x        = static_cast<int>(window.x().start());
1445*c217d954SCole Faust     const auto    window_end_x          = static_cast<int>(window.x().end());
1446*c217d954SCole Faust     const bool    is_broadcast_across_x = src1->info()->tensor_shape().x() != src2->info()->tensor_shape().x();
1447*c217d954SCole Faust     if(is_broadcast_across_x)
1448*c217d954SCole Faust     {
1449*c217d954SCole Faust         const bool     is_broadcast_input_2 = input2_win.x().step() == 0;
1450*c217d954SCole Faust         Window         broadcast_win        = is_broadcast_input_2 ? input2_win : input1_win;
1451*c217d954SCole Faust         Window         non_broadcast_win    = !is_broadcast_input_2 ? input2_win : input1_win;
1452*c217d954SCole Faust         const ITensor *broadcast_tensor     = is_broadcast_input_2 ? src2 : src1;
1453*c217d954SCole Faust         const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? src2 : src1;
1454*c217d954SCole Faust         // Clear X Dimension on execution window as we handle manually
1455*c217d954SCole Faust         non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1456*c217d954SCole Faust         Iterator broadcast_input(broadcast_tensor, broadcast_win);
1457*c217d954SCole Faust         Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win);
1458*c217d954SCole Faust         Iterator dst(out, win);
1459*c217d954SCole Faust         execute_window_loop(
1460*c217d954SCole Faust             win, [&](const Coordinates &)
1461*c217d954SCole Faust         {
1462*c217d954SCole Faust             const auto          non_broadcast_input_ptr = reinterpret_cast<const float16_t *>(non_broadcast_input.ptr());
1463*c217d954SCole Faust             const auto          output_ptr              = reinterpret_cast<float16_t *>(dst.ptr());
1464*c217d954SCole Faust             const auto          broadcast_value         = *reinterpret_cast<const float16_t *>(broadcast_input.ptr());
1465*c217d954SCole Faust             const float16x8x2_t broadcast_value_vec     =
1466*c217d954SCole Faust             {
1467*c217d954SCole Faust                 {
1468*c217d954SCole Faust                     vdupq_n_f16(broadcast_value),
1469*c217d954SCole Faust                     vdupq_n_f16(broadcast_value),
1470*c217d954SCole Faust                 }
1471*c217d954SCole Faust             };
1472*c217d954SCole Faust             const auto scale_vec = vdupq_n_f16(scale);
1473*c217d954SCole Faust             // Compute window_step_x elements per iteration
1474*c217d954SCole Faust             int x = window_start_x;
1475*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1476*c217d954SCole Faust             {
1477*c217d954SCole Faust                 const float16x8x2_t non_broadcast_v =
1478*c217d954SCole Faust                 {
1479*c217d954SCole Faust                     {
1480*c217d954SCole Faust                         vld1q_f16(non_broadcast_input_ptr + x),
1481*c217d954SCole Faust                         vld1q_f16(non_broadcast_input_ptr + x + 8),
1482*c217d954SCole Faust                     }
1483*c217d954SCole Faust                 };
1484*c217d954SCole Faust                 const float16x8x2_t result =
1485*c217d954SCole Faust                 {
1486*c217d954SCole Faust                     {
1487*c217d954SCole Faust                         vmulq_f16(vmulq_f16(broadcast_value_vec.val[0], non_broadcast_v.val[0]), scale_vec),
1488*c217d954SCole Faust                         vmulq_f16(vmulq_f16(broadcast_value_vec.val[1], non_broadcast_v.val[1]), scale_vec),
1489*c217d954SCole Faust                     }
1490*c217d954SCole Faust                 };
1491*c217d954SCole Faust                 vst1q_f16(output_ptr + x, result.val[0]);
1492*c217d954SCole Faust                 vst1q_f16(output_ptr + x + 8, result.val[1]);
1493*c217d954SCole Faust             }
1494*c217d954SCole Faust             // Compute left-over elements
1495*c217d954SCole Faust             for(; x < window_end_x; ++x)
1496*c217d954SCole Faust             {
1497*c217d954SCole Faust                 const auto non_broadcast_v = *(non_broadcast_input_ptr + x);
1498*c217d954SCole Faust                 *(output_ptr + x)          = broadcast_value * non_broadcast_v * scale;
1499*c217d954SCole Faust             }
1500*c217d954SCole Faust         },
1501*c217d954SCole Faust         broadcast_input, non_broadcast_input, dst);
1502*c217d954SCole Faust     }
1503*c217d954SCole Faust     else
1504*c217d954SCole Faust     {
1505*c217d954SCole Faust         input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1506*c217d954SCole Faust         input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1507*c217d954SCole Faust         Iterator input1(src1, input1_win);
1508*c217d954SCole Faust         Iterator input2(src2, input2_win);
1509*c217d954SCole Faust         Iterator dst(out, win);
1510*c217d954SCole Faust         execute_window_loop(
1511*c217d954SCole Faust             win, [&](const Coordinates &)
1512*c217d954SCole Faust         {
1513*c217d954SCole Faust             const auto input1_ptr = reinterpret_cast<const float16_t *>(input1.ptr());
1514*c217d954SCole Faust             const auto input2_ptr = reinterpret_cast<const float16_t *>(input2.ptr());
1515*c217d954SCole Faust             const auto output_ptr = reinterpret_cast<float16_t *>(dst.ptr());
1516*c217d954SCole Faust             // Compute window_step_x elements per iteration
1517*c217d954SCole Faust             int x = window_start_x;
1518*c217d954SCole Faust             for(; x <= (window_end_x - window_step_x); x += window_step_x)
1519*c217d954SCole Faust             {
1520*c217d954SCole Faust                 const float16x8x2_t ta1 =
1521*c217d954SCole Faust                 {
1522*c217d954SCole Faust                     {
1523*c217d954SCole Faust                         vld1q_f16(input1_ptr + x),
1524*c217d954SCole Faust                         vld1q_f16(input1_ptr + x + 8),
1525*c217d954SCole Faust                     }
1526*c217d954SCole Faust                 };
1527*c217d954SCole Faust                 const float16x8x2_t ta2 =
1528*c217d954SCole Faust                 {
1529*c217d954SCole Faust                     {
1530*c217d954SCole Faust                         vld1q_f16(input2_ptr + x),
1531*c217d954SCole Faust                         vld1q_f16(input2_ptr + x + 8),
1532*c217d954SCole Faust                     }
1533*c217d954SCole Faust                 };
1534*c217d954SCole Faust                 const float16x8_t   scale_vec = vdupq_n_f16(scale);
1535*c217d954SCole Faust                 const float16x8x2_t result    =
1536*c217d954SCole Faust                 {
1537*c217d954SCole Faust                     {
1538*c217d954SCole Faust                         vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
1539*c217d954SCole Faust                         vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
1540*c217d954SCole Faust                     }
1541*c217d954SCole Faust                 };
1542*c217d954SCole Faust                 vst1q_f16(output_ptr + x, result.val[0]);
1543*c217d954SCole Faust                 vst1q_f16(output_ptr + x + 8, result.val[1]);
1544*c217d954SCole Faust             }
1545*c217d954SCole Faust             // Compute left-over elements
1546*c217d954SCole Faust             for(; x < window_end_x; ++x)
1547*c217d954SCole Faust             {
1548*c217d954SCole Faust                 const auto ta1    = *(input1_ptr + x);
1549*c217d954SCole Faust                 const auto ta2    = *(input2_ptr + x);
1550*c217d954SCole Faust                 *(output_ptr + x) = ta1 * ta2 * scale;
1551*c217d954SCole Faust             }
1552*c217d954SCole Faust         },
1553*c217d954SCole Faust         input1, input2, dst);
1554*c217d954SCole Faust     }
1555*c217d954SCole Faust }
1556*c217d954SCole Faust #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1557*c217d954SCole Faust 
1558*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_U8_U8_S16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)1559*c217d954SCole Faust void mul_U8_U8_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1560*c217d954SCole Faust {
1561*c217d954SCole Faust     // Create input windows
1562*c217d954SCole Faust     Window win        = window;
1563*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1564*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1565*c217d954SCole Faust 
1566*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1567*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1568*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1569*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1570*c217d954SCole Faust 
1571*c217d954SCole Faust     Iterator input1(src1, input1_win);
1572*c217d954SCole Faust     Iterator input2(src2, input2_win);
1573*c217d954SCole Faust     Iterator dst(out, win);
1574*c217d954SCole Faust 
1575*c217d954SCole Faust     const int  window_step_x  = 16 / sizeof(uint8_t);
1576*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
1577*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
1578*c217d954SCole Faust 
1579*c217d954SCole Faust     execute_window_loop(
1580*c217d954SCole Faust         win, [&](const Coordinates &)
1581*c217d954SCole Faust     {
1582*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
1583*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1584*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
1585*c217d954SCole Faust 
1586*c217d954SCole Faust         // Compute window_step_x elements per iteration
1587*c217d954SCole Faust         int x = window_start_x;
1588*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1589*c217d954SCole Faust         {
1590*c217d954SCole Faust             const uint8x16_t bv = wrapper::vloadq(input2_ptr + x);
1591*c217d954SCole Faust             const uint8x16_t av = wrapper::vloadq(input1_ptr + x);
1592*c217d954SCole Faust 
1593*c217d954SCole Faust             uint16x8_t tmp_low  = vmovl_u8(vget_low_u8(av));
1594*c217d954SCole Faust             uint16x8_t tmp_high = vmovl_u8(vget_high_u8(av));
1595*c217d954SCole Faust             tmp_low             = vmulq_u16(tmp_low, vmovl_u8(vget_low_u8(bv)));
1596*c217d954SCole Faust             tmp_high            = vmulq_u16(tmp_high, vmovl_u8(vget_high_u8(bv)));
1597*c217d954SCole Faust 
1598*c217d954SCole Faust             if(is_scale255)
1599*c217d954SCole Faust             {
1600*c217d954SCole Faust                 tmp_low  = scale255_U16_U16(tmp_low);
1601*c217d954SCole Faust                 tmp_high = scale255_U16_U16(tmp_high);
1602*c217d954SCole Faust             }
1603*c217d954SCole Faust             else
1604*c217d954SCole Faust             {
1605*c217d954SCole Faust                 const int16x8_t vn = vdupq_n_s16(-n);
1606*c217d954SCole Faust 
1607*c217d954SCole Faust                 if(is_sat)
1608*c217d954SCole Faust                 {
1609*c217d954SCole Faust                     tmp_low  = vqshlq_u16(tmp_low, vn);
1610*c217d954SCole Faust                     tmp_high = vqshlq_u16(tmp_high, vn);
1611*c217d954SCole Faust                 }
1612*c217d954SCole Faust                 else
1613*c217d954SCole Faust                 {
1614*c217d954SCole Faust                     tmp_low  = vshlq_u16(tmp_low, vn);
1615*c217d954SCole Faust                     tmp_high = vshlq_u16(tmp_high, vn);
1616*c217d954SCole Faust                 }
1617*c217d954SCole Faust             }
1618*c217d954SCole Faust 
1619*c217d954SCole Faust             if(is_sat)
1620*c217d954SCole Faust             {
1621*c217d954SCole Faust                 static const uint16x8_t max = vdupq_n_u16(SHRT_MAX);
1622*c217d954SCole Faust 
1623*c217d954SCole Faust                 tmp_low  = vminq_u16(tmp_low, max);
1624*c217d954SCole Faust                 tmp_high = vminq_u16(tmp_high, max);
1625*c217d954SCole Faust             }
1626*c217d954SCole Faust 
1627*c217d954SCole Faust             vst1q_s16(output_ptr + x, vreinterpretq_s16_u16(tmp_low));
1628*c217d954SCole Faust             vst1q_s16(output_ptr + x + 8, vreinterpretq_s16_u16(tmp_high));
1629*c217d954SCole Faust         }
1630*c217d954SCole Faust 
1631*c217d954SCole Faust         // Compute left-over elements
1632*c217d954SCole Faust         for(; x < window_end_x; ++x)
1633*c217d954SCole Faust         {
1634*c217d954SCole Faust             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1635*c217d954SCole Faust 
1636*c217d954SCole Faust             if(is_scale255)
1637*c217d954SCole Faust             {
1638*c217d954SCole Faust                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1639*c217d954SCole Faust                 tmp         = static_cast<int32_t>(tmp_f + 0.5f);
1640*c217d954SCole Faust             }
1641*c217d954SCole Faust             else
1642*c217d954SCole Faust             {
1643*c217d954SCole Faust                 tmp >>= n;
1644*c217d954SCole Faust             }
1645*c217d954SCole Faust 
1646*c217d954SCole Faust             if(is_sat)
1647*c217d954SCole Faust             {
1648*c217d954SCole Faust                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : tmp;
1649*c217d954SCole Faust             }
1650*c217d954SCole Faust 
1651*c217d954SCole Faust             *(output_ptr + x) = static_cast<int16_t>(tmp);
1652*c217d954SCole Faust         }
1653*c217d954SCole Faust     },
1654*c217d954SCole Faust     input1, input2, dst);
1655*c217d954SCole Faust }
1656*c217d954SCole Faust 
1657*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_S16_U8_S16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)1658*c217d954SCole Faust void mul_S16_U8_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1659*c217d954SCole Faust {
1660*c217d954SCole Faust     // Create input windows
1661*c217d954SCole Faust     Window win        = window;
1662*c217d954SCole Faust     Window input1_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
1663*c217d954SCole Faust     Window input2_win = window.broadcast_if_dimension_le_one(src2->info()->tensor_shape());
1664*c217d954SCole Faust 
1665*c217d954SCole Faust     // Clear X Dimension on execution window as we handle manually
1666*c217d954SCole Faust     win.set(Window::DimX, Window::Dimension(0, 1, 1));
1667*c217d954SCole Faust     input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1668*c217d954SCole Faust     input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
1669*c217d954SCole Faust 
1670*c217d954SCole Faust     Iterator input1(src1, input1_win);
1671*c217d954SCole Faust     Iterator input2(src2, input2_win);
1672*c217d954SCole Faust     Iterator dst(out, win);
1673*c217d954SCole Faust 
1674*c217d954SCole Faust     const int  window_step_x  = 16;
1675*c217d954SCole Faust     const auto window_start_x = static_cast<int>(window.x().start());
1676*c217d954SCole Faust     const auto window_end_x   = static_cast<int>(window.x().end());
1677*c217d954SCole Faust 
1678*c217d954SCole Faust     execute_window_loop(
1679*c217d954SCole Faust         win, [&](const Coordinates &)
1680*c217d954SCole Faust     {
1681*c217d954SCole Faust         const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
1682*c217d954SCole Faust         const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
1683*c217d954SCole Faust         const auto output_ptr = reinterpret_cast<int16_t *>(dst.ptr());
1684*c217d954SCole Faust 
1685*c217d954SCole Faust         // Compute window_step_x elements per iteration
1686*c217d954SCole Faust         int x = window_start_x;
1687*c217d954SCole Faust         for(; x <= (window_end_x - window_step_x); x += window_step_x)
1688*c217d954SCole Faust         {
1689*c217d954SCole Faust             const int16x8x2_t ta1 =
1690*c217d954SCole Faust             {
1691*c217d954SCole Faust                 {
1692*c217d954SCole Faust                     vld1q_s16(input1_ptr + x),
1693*c217d954SCole Faust                     vld1q_s16(input1_ptr + x + 8),
1694*c217d954SCole Faust                 }
1695*c217d954SCole Faust             };
1696*c217d954SCole Faust             const uint8x8x2_t ta2u =
1697*c217d954SCole Faust             {
1698*c217d954SCole Faust                 {
1699*c217d954SCole Faust                     vld1_u8(input2_ptr + x),
1700*c217d954SCole Faust                     vld1_u8(input2_ptr + x + 8),
1701*c217d954SCole Faust                 }
1702*c217d954SCole Faust             };
1703*c217d954SCole Faust             const int16x8x2_t ta2 =
1704*c217d954SCole Faust             {
1705*c217d954SCole Faust                 {
1706*c217d954SCole Faust                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[0])),
1707*c217d954SCole Faust                     vreinterpretq_s16_u16(vmovl_u8(ta2u.val[1]))
1708*c217d954SCole Faust                 }
1709*c217d954SCole Faust             };
1710*c217d954SCole Faust 
1711*c217d954SCole Faust             const int16x8x2_t result = mul_S16_S16_S16_n_k<is_scale255, is_sat>(ta1, ta2, n);
1712*c217d954SCole Faust 
1713*c217d954SCole Faust             vst1q_s16(output_ptr + x, result.val[0]);
1714*c217d954SCole Faust             vst1q_s16(output_ptr + x + 8, result.val[1]);
1715*c217d954SCole Faust         }
1716*c217d954SCole Faust 
1717*c217d954SCole Faust         // Compute left-over elements
1718*c217d954SCole Faust         for(; x < window_end_x; ++x)
1719*c217d954SCole Faust         {
1720*c217d954SCole Faust             int32_t tmp = static_cast<int32_t>(*(input1_ptr + x)) * static_cast<int32_t>(*(input2_ptr + x));
1721*c217d954SCole Faust 
1722*c217d954SCole Faust             if(is_scale255)
1723*c217d954SCole Faust             {
1724*c217d954SCole Faust                 float tmp_f = static_cast<float>(tmp) * scale255_constant;
1725*c217d954SCole Faust 
1726*c217d954SCole Faust                 tmp = static_cast<int32_t>(tmp_f + 0.5f);
1727*c217d954SCole Faust             }
1728*c217d954SCole Faust             else
1729*c217d954SCole Faust             {
1730*c217d954SCole Faust                 if(tmp >= 0)
1731*c217d954SCole Faust                 {
1732*c217d954SCole Faust                     tmp >>= n;
1733*c217d954SCole Faust                 }
1734*c217d954SCole Faust                 else
1735*c217d954SCole Faust                 {
1736*c217d954SCole Faust                     uint32_t mask = (1u << n) - 1;
1737*c217d954SCole Faust                     tmp           = (tmp + static_cast<int32_t>(mask)) >> n;
1738*c217d954SCole Faust                 }
1739*c217d954SCole Faust             }
1740*c217d954SCole Faust             if(is_sat)
1741*c217d954SCole Faust             {
1742*c217d954SCole Faust                 tmp = (tmp > SHRT_MAX) ? SHRT_MAX : ((tmp < SHRT_MIN) ? SHRT_MIN : tmp);
1743*c217d954SCole Faust             }
1744*c217d954SCole Faust             *(output_ptr + x) = static_cast<int16_t>(tmp);
1745*c217d954SCole Faust         }
1746*c217d954SCole Faust     },
1747*c217d954SCole Faust     input1, input2, dst);
1748*c217d954SCole Faust }
1749*c217d954SCole Faust 
1750*c217d954SCole Faust template <bool is_scale255, bool is_sat>
mul_U8_S16_S16(const ITensor * src1,const ITensor * src2,ITensor * out,const Window & window,int n)1751*c217d954SCole Faust void mul_U8_S16_S16(const ITensor *src1, const ITensor *src2, ITensor *out, const Window &window, int n)
1752*c217d954SCole Faust {
1753*c217d954SCole Faust     // Simply swap the two input buffers
1754*c217d954SCole Faust     mul_S16_U8_S16<is_scale255, is_sat>(src2, src1, out, window, n);
1755*c217d954SCole Faust }
1756*c217d954SCole Faust } // namespace
1757*c217d954SCole Faust 
configure(ITensorInfo * src1,ITensorInfo * src2,ITensorInfo * dst,float scale,ConvertPolicy overflow_policy,RoundingPolicy rounding_policy)1758*c217d954SCole Faust void CpuMulKernel::configure(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
1759*c217d954SCole Faust {
1760*c217d954SCole Faust     ARM_COMPUTE_UNUSED(rounding_policy);
1761*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
1762*c217d954SCole Faust 
1763*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src1, src2, dst, scale, overflow_policy, rounding_policy));
1764*c217d954SCole Faust 
1765*c217d954SCole Faust     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
1766*c217d954SCole Faust 
1767*c217d954SCole Faust     // Auto initialize dst if not initialized
1768*c217d954SCole Faust     set_shape_if_empty(*dst, out_shape);
1769*c217d954SCole Faust 
1770*c217d954SCole Faust     _scale          = scale;
1771*c217d954SCole Faust     _scale_exponent = 0;
1772*c217d954SCole Faust     _func_quantized = nullptr;
1773*c217d954SCole Faust     _func_int       = nullptr;
1774*c217d954SCole Faust     _func_float     = nullptr;
1775*c217d954SCole Faust 
1776*c217d954SCole Faust     bool is_scale_255 = false;
1777*c217d954SCole Faust     // Check and validate scaling factor
1778*c217d954SCole Faust     if(std::abs(scale - scale255_constant) < 0.00001f)
1779*c217d954SCole Faust     {
1780*c217d954SCole Faust         is_scale_255 = true;
1781*c217d954SCole Faust     }
1782*c217d954SCole Faust     else
1783*c217d954SCole Faust     {
1784*c217d954SCole Faust         int exponent = 0;
1785*c217d954SCole Faust 
1786*c217d954SCole Faust         std::frexp(scale, &exponent);
1787*c217d954SCole Faust 
1788*c217d954SCole Faust         // Store the positive exponent. We know that we compute 1/2^n
1789*c217d954SCole Faust         // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
1790*c217d954SCole Faust         _scale_exponent = std::abs(exponent - 1);
1791*c217d954SCole Faust     }
1792*c217d954SCole Faust 
1793*c217d954SCole Faust     const DataType dt_input1 = src1->data_type();
1794*c217d954SCole Faust     const DataType dt_input2 = src2->data_type();
1795*c217d954SCole Faust     const DataType dt_output = dst->data_type();
1796*c217d954SCole Faust     const bool     is_sat    = (overflow_policy == ConvertPolicy::SATURATE);
1797*c217d954SCole Faust 
1798*c217d954SCole Faust     switch(dt_input1)
1799*c217d954SCole Faust     {
1800*c217d954SCole Faust         case DataType::QASYMM8:
1801*c217d954SCole Faust             if(dt_input2 == DataType::QASYMM8 && dt_output == DataType::QASYMM8)
1802*c217d954SCole Faust             {
1803*c217d954SCole Faust                 if(mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
1804*c217d954SCole Faust                 {
1805*c217d954SCole Faust                     _func_quantized = &mul_q8_neon_fixedpoint<uint8_t>;
1806*c217d954SCole Faust                 }
1807*c217d954SCole Faust                 else
1808*c217d954SCole Faust                 {
1809*c217d954SCole Faust                     _func_quantized = &mul_saturate_quantized_8<uint8_t>;
1810*c217d954SCole Faust                 }
1811*c217d954SCole Faust             }
1812*c217d954SCole Faust             break;
1813*c217d954SCole Faust         case DataType::QASYMM8_SIGNED:
1814*c217d954SCole Faust             if(dt_input2 == DataType::QASYMM8_SIGNED)
1815*c217d954SCole Faust             {
1816*c217d954SCole Faust                 if(mul_q8_neon_fixedpoint_possible(src1, src2, dst, scale))
1817*c217d954SCole Faust                 {
1818*c217d954SCole Faust                     _func_quantized = &mul_q8_neon_fixedpoint<int8_t>;
1819*c217d954SCole Faust                 }
1820*c217d954SCole Faust                 else
1821*c217d954SCole Faust                 {
1822*c217d954SCole Faust                     _func_quantized = &mul_saturate_quantized_8<int8_t>;
1823*c217d954SCole Faust                 }
1824*c217d954SCole Faust             }
1825*c217d954SCole Faust             break;
1826*c217d954SCole Faust         case DataType::QSYMM16:
1827*c217d954SCole Faust             if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::QSYMM16)
1828*c217d954SCole Faust             {
1829*c217d954SCole Faust                 _func_quantized = &mul_saturate_QSYMM16_QSYMM16_QSYMM16;
1830*c217d954SCole Faust             }
1831*c217d954SCole Faust             else if(dt_input2 == DataType::QSYMM16 && dt_output == DataType::S32)
1832*c217d954SCole Faust             {
1833*c217d954SCole Faust                 _func_int = &mul_QSYMM16_QSYMM16_S32;
1834*c217d954SCole Faust             }
1835*c217d954SCole Faust             break;
1836*c217d954SCole Faust         case DataType::S16:
1837*c217d954SCole Faust             if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1838*c217d954SCole Faust             {
1839*c217d954SCole Faust                 if(is_scale_255)
1840*c217d954SCole Faust                 {
1841*c217d954SCole Faust                     _func_int = is_sat ? &mul_S16_U8_S16<true, true> : &mul_S16_U8_S16<true, false>;
1842*c217d954SCole Faust                 }
1843*c217d954SCole Faust                 else
1844*c217d954SCole Faust                 {
1845*c217d954SCole Faust                     _func_int = is_sat ? &mul_S16_U8_S16<false, true> : &mul_S16_U8_S16<false, false>;
1846*c217d954SCole Faust                 }
1847*c217d954SCole Faust             }
1848*c217d954SCole Faust             if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1849*c217d954SCole Faust             {
1850*c217d954SCole Faust                 if(is_scale_255)
1851*c217d954SCole Faust                 {
1852*c217d954SCole Faust                     _func_int = is_sat ? &mul_S16_S16_S16<true, true> : &mul_S16_S16_S16<true, false>;
1853*c217d954SCole Faust                 }
1854*c217d954SCole Faust                 else
1855*c217d954SCole Faust                 {
1856*c217d954SCole Faust                     _func_int = is_sat ? &mul_S16_S16_S16<false, true> : &mul_S16_S16_S16<false, false>;
1857*c217d954SCole Faust                 }
1858*c217d954SCole Faust             }
1859*c217d954SCole Faust             break;
1860*c217d954SCole Faust         case DataType::S32:
1861*c217d954SCole Faust             if(DataType::S32 == dt_input2 && DataType::S32 == dt_output)
1862*c217d954SCole Faust             {
1863*c217d954SCole Faust                 _func_int = is_sat ? &mul_S32_S32_S32<true> : &mul_S32_S32_S32<false>;
1864*c217d954SCole Faust             }
1865*c217d954SCole Faust             break;
1866*c217d954SCole Faust         case DataType::U8:
1867*c217d954SCole Faust             if(DataType::U8 == dt_input2 && DataType::U8 == dt_output)
1868*c217d954SCole Faust             {
1869*c217d954SCole Faust                 if(is_scale_255)
1870*c217d954SCole Faust                 {
1871*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_U8_U8<true, true> : &mul_U8_U8_U8<true, false>;
1872*c217d954SCole Faust                 }
1873*c217d954SCole Faust                 else
1874*c217d954SCole Faust                 {
1875*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_U8_U8<false, true> : &mul_U8_U8_U8<false, false>;
1876*c217d954SCole Faust                 }
1877*c217d954SCole Faust             }
1878*c217d954SCole Faust             else if(DataType::U8 == dt_input2 && DataType::S16 == dt_output)
1879*c217d954SCole Faust             {
1880*c217d954SCole Faust                 if(is_scale_255)
1881*c217d954SCole Faust                 {
1882*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_U8_S16<true, true> : &mul_U8_U8_S16<true, false>;
1883*c217d954SCole Faust                 }
1884*c217d954SCole Faust                 else
1885*c217d954SCole Faust                 {
1886*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_U8_S16<false, true> : &mul_U8_U8_S16<false, false>;
1887*c217d954SCole Faust                 }
1888*c217d954SCole Faust             }
1889*c217d954SCole Faust             else if(DataType::S16 == dt_input2 && DataType::S16 == dt_output)
1890*c217d954SCole Faust             {
1891*c217d954SCole Faust                 if(is_scale_255)
1892*c217d954SCole Faust                 {
1893*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_S16_S16<true, true> : &mul_U8_S16_S16<true, false>;
1894*c217d954SCole Faust                 }
1895*c217d954SCole Faust                 else
1896*c217d954SCole Faust                 {
1897*c217d954SCole Faust                     _func_int = is_sat ? &mul_U8_S16_S16<false, true> : &mul_U8_S16_S16<false, false>;
1898*c217d954SCole Faust                 }
1899*c217d954SCole Faust             }
1900*c217d954SCole Faust             break;
1901*c217d954SCole Faust #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
1902*c217d954SCole Faust         case DataType::F16:
1903*c217d954SCole Faust             _func_float = &mul_F16_F16_F16;
1904*c217d954SCole Faust             break;
1905*c217d954SCole Faust #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
1906*c217d954SCole Faust         case DataType::F32:
1907*c217d954SCole Faust             _func_float = &mul_F32_F32_F32;
1908*c217d954SCole Faust             break;
1909*c217d954SCole Faust         default:
1910*c217d954SCole Faust             ARM_COMPUTE_ERROR("You called with the wrong img formats");
1911*c217d954SCole Faust     }
1912*c217d954SCole Faust 
1913*c217d954SCole Faust     // Configure kernel window
1914*c217d954SCole Faust     Window win;
1915*c217d954SCole Faust     std::tie(win, _split_dimension) = calculate_squashed_or_max_window(*src1, *src2);
1916*c217d954SCole Faust 
1917*c217d954SCole Faust     ICpuKernel::configure(win);
1918*c217d954SCole Faust }
1919*c217d954SCole Faust 
get_mws(const CPUInfo & platform,size_t thread_count) const1920*c217d954SCole Faust size_t CpuMulKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
1921*c217d954SCole Faust {
1922*c217d954SCole Faust     ARM_COMPUTE_UNUSED(thread_count);
1923*c217d954SCole Faust 
1924*c217d954SCole Faust #if defined(ENABLE_FP32_KERNELS)
1925*c217d954SCole Faust     if(this->_func_float == &mul_F32_F32_F32)
1926*c217d954SCole Faust     {
1927*c217d954SCole Faust         size_t mws = ICPPKernel::default_mws;
1928*c217d954SCole Faust         if(platform.get_cpu_model() == CPUModel::N1)
1929*c217d954SCole Faust         {
1930*c217d954SCole Faust             mws = default_mws_N1_fp32_neon;
1931*c217d954SCole Faust         }
1932*c217d954SCole Faust         else if(platform.get_cpu_model() == CPUModel::V1)
1933*c217d954SCole Faust         {
1934*c217d954SCole Faust             mws = default_mws_V1_fp32_neon;
1935*c217d954SCole Faust         }
1936*c217d954SCole Faust         else
1937*c217d954SCole Faust         {
1938*c217d954SCole Faust             if(_split_dimension == Window::DimX)
1939*c217d954SCole Faust             {
1940*c217d954SCole Faust                 // Don't split the work load too small if the tensor has been reinterpreted as 1D.
1941*c217d954SCole Faust                 // This number is loosely chosen as threading overhead in each platform varies wildly.
1942*c217d954SCole Faust                 return default_mws_other_platforms_1d_tensor;
1943*c217d954SCole Faust             }
1944*c217d954SCole Faust             return default_mws;
1945*c217d954SCole Faust         }
1946*c217d954SCole Faust 
1947*c217d954SCole Faust         // tensor is 1D or was re-interpreted as 1D
1948*c217d954SCole Faust         if(this->window().shape().num_dimensions() == 1)
1949*c217d954SCole Faust         {
1950*c217d954SCole Faust             return mws;
1951*c217d954SCole Faust         }
1952*c217d954SCole Faust         else
1953*c217d954SCole Faust         {
1954*c217d954SCole Faust             // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
1955*c217d954SCole Faust             // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
1956*c217d954SCole Faust             // but the other sizes are large, which boosts performance.
1957*c217d954SCole Faust             mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
1958*c217d954SCole Faust             return std::max(static_cast<size_t>(1), mws);
1959*c217d954SCole Faust         }
1960*c217d954SCole Faust     }
1961*c217d954SCole Faust #else /* ENABLE_FP32_KERNELS */
1962*c217d954SCole Faust     ARM_COMPUTE_UNUSED(platform);
1963*c217d954SCole Faust #endif /* ENABLE_FP32_KERNELS */
1964*c217d954SCole Faust     if(_split_dimension == Window::DimX)
1965*c217d954SCole Faust     {
1966*c217d954SCole Faust         // Don't split the work load too small if the tensor has been reinterpreted as 1D.
1967*c217d954SCole Faust         // This number is loosely chosen as threading overhead in each platform varies wildly.
1968*c217d954SCole Faust         return default_mws_other_platforms_1d_tensor;
1969*c217d954SCole Faust     }
1970*c217d954SCole Faust     return default_mws;
1971*c217d954SCole Faust }
1972*c217d954SCole Faust 
validate(const ITensorInfo * src1,const ITensorInfo * src2,const ITensorInfo * dst,float scale,ConvertPolicy overflow_policy,RoundingPolicy rounding_policy)1973*c217d954SCole Faust Status CpuMulKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float scale, ConvertPolicy overflow_policy,
1974*c217d954SCole Faust                               RoundingPolicy rounding_policy)
1975*c217d954SCole Faust {
1976*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
1977*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src1, src2, dst, scale, overflow_policy, rounding_policy));
1978*c217d954SCole Faust 
1979*c217d954SCole Faust     return Status{};
1980*c217d954SCole Faust }
1981*c217d954SCole Faust 
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)1982*c217d954SCole Faust void CpuMulKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
1983*c217d954SCole Faust {
1984*c217d954SCole Faust     ARM_COMPUTE_UNUSED(info);
1985*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1986*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
1987*c217d954SCole Faust 
1988*c217d954SCole Faust     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1989*c217d954SCole Faust     auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1990*c217d954SCole Faust     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
1991*c217d954SCole Faust 
1992*c217d954SCole Faust     if(_func_quantized != nullptr)
1993*c217d954SCole Faust     {
1994*c217d954SCole Faust         (*_func_quantized)(src1, src2, dst, window, _scale);
1995*c217d954SCole Faust     }
1996*c217d954SCole Faust     else if(_func_int != nullptr)
1997*c217d954SCole Faust     {
1998*c217d954SCole Faust         (*_func_int)(src1, src2, dst, window, _scale_exponent);
1999*c217d954SCole Faust     }
2000*c217d954SCole Faust     else
2001*c217d954SCole Faust     {
2002*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_func_float == nullptr);
2003*c217d954SCole Faust         (*_func_float)(src1, src2, dst, window, _scale);
2004*c217d954SCole Faust     }
2005*c217d954SCole Faust }
2006*c217d954SCole Faust 
name() const2007*c217d954SCole Faust const char *CpuMulKernel::name() const
2008*c217d954SCole Faust {
2009*c217d954SCole Faust     return "CpuMulKernel";
2010*c217d954SCole Faust }
2011*c217d954SCole Faust 
2012*c217d954SCole Faust namespace
2013*c217d954SCole Faust {
validate_arguments_complex(const ITensorInfo * src1,const ITensorInfo * src2,const ITensorInfo * dst)2014*c217d954SCole Faust Status validate_arguments_complex(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst)
2015*c217d954SCole Faust {
2016*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 2, DataType::F32);
2017*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src2, 2, DataType::F32);
2018*c217d954SCole Faust 
2019*c217d954SCole Faust     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
2020*c217d954SCole Faust 
2021*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
2022*c217d954SCole Faust 
2023*c217d954SCole Faust     // Validate in case of configured dst
2024*c217d954SCole Faust     if(dst->total_size() > 0)
2025*c217d954SCole Faust     {
2026*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 2, DataType::F32);
2027*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst->tensor_shape(), 0), "Wrong shape for dst");
2028*c217d954SCole Faust     }
2029*c217d954SCole Faust 
2030*c217d954SCole Faust     return Status{};
2031*c217d954SCole Faust }
2032*c217d954SCole Faust } // namespace
2033*c217d954SCole Faust 
configure(ITensorInfo * src1,ITensorInfo * src2,ITensorInfo * dst)2034*c217d954SCole Faust void CpuComplexMulKernel::configure(ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst)
2035*c217d954SCole Faust {
2036*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
2037*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_complex(src1, src2, dst));
2038*c217d954SCole Faust 
2039*c217d954SCole Faust     const TensorShape &out_shape = TensorShape::broadcast_shape(src1->tensor_shape(), src2->tensor_shape());
2040*c217d954SCole Faust 
2041*c217d954SCole Faust     // Auto initialize dst if not initialized
2042*c217d954SCole Faust     const TensorInfo out_info(out_shape, src1->num_channels(), src1->data_type());
2043*c217d954SCole Faust     auto_init_if_empty(*dst, out_info);
2044*c217d954SCole Faust 
2045*c217d954SCole Faust     // Configure kernel window
2046*c217d954SCole Faust     Window win = calculate_max_window(out_shape);
2047*c217d954SCole Faust 
2048*c217d954SCole Faust     ICpuKernel::configure(win);
2049*c217d954SCole Faust }
2050*c217d954SCole Faust 
validate(const ITensorInfo * src1,const ITensorInfo * src2,const ITensorInfo * dst)2051*c217d954SCole Faust Status CpuComplexMulKernel::validate(const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst)
2052*c217d954SCole Faust {
2053*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src1, src2, dst);
2054*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_complex(src1, src2, dst));
2055*c217d954SCole Faust 
2056*c217d954SCole Faust     return Status{};
2057*c217d954SCole Faust }
2058*c217d954SCole Faust 
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)2059*c217d954SCole Faust void CpuComplexMulKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
2060*c217d954SCole Faust {
2061*c217d954SCole Faust     ARM_COMPUTE_UNUSED(info);
2062*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
2063*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
2064*c217d954SCole Faust 
2065*c217d954SCole Faust     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
2066*c217d954SCole Faust     auto src2 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
2067*c217d954SCole Faust     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
2068*c217d954SCole Faust 
2069*c217d954SCole Faust     c_mul_F32_F32_F32_n(src1, src2, dst, window);
2070*c217d954SCole Faust }
2071*c217d954SCole Faust 
name() const2072*c217d954SCole Faust const char *CpuComplexMulKernel::name() const
2073*c217d954SCole Faust {
2074*c217d954SCole Faust     return "CpuComplexMulKernel";
2075*c217d954SCole Faust }
2076*c217d954SCole Faust } // namespace kernels
2077*c217d954SCole Faust } // namespace cpu
2078*c217d954SCole Faust } // namespace arm_compute
2079