1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2017-2020 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "PixelWiseMultiplication.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "tests/validation/Helpers.h"
27*c217d954SCole Faust
28*c217d954SCole Faust namespace arm_compute
29*c217d954SCole Faust {
30*c217d954SCole Faust namespace test
31*c217d954SCole Faust {
32*c217d954SCole Faust namespace validation
33*c217d954SCole Faust {
34*c217d954SCole Faust namespace reference
35*c217d954SCole Faust {
36*c217d954SCole Faust template <class T>
37*c217d954SCole Faust struct is_floating_point
38*c217d954SCole Faust : std::integral_constant < bool,
39*c217d954SCole Faust std::is_same<float, typename std::remove_cv<T>::type>::value || std::is_same<half_float::half, typename std::remove_cv<T>::type>::value
40*c217d954SCole Faust || std::is_same<double, typename std::remove_cv<T>::type>::value || std::is_same<long double, typename std::remove_cv<T>::type>::value >
41*c217d954SCole Faust {
42*c217d954SCole Faust };
43*c217d954SCole Faust
44*c217d954SCole Faust namespace
45*c217d954SCole Faust {
/** Unit scale; a scale within a small tolerance of this value selects the exact integer path in mul<int32_t>. */
constexpr float scale1_constant{ 1.0f };
47*c217d954SCole Faust
48*c217d954SCole Faust /** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2.
49*c217d954SCole Faust *
50*c217d954SCole Faust * @param[in] src1 An input value. Data types supported: U8/S16/F16/F32.
51*c217d954SCole Faust * @param[in] src2 An input value. Data types supported: same as @p src1.
52*c217d954SCole Faust * @param[in] scale Scale to apply after multiplication.
53*c217d954SCole Faust * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
54*c217d954SCole Faust * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
55*c217d954SCole Faust * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
56*c217d954SCole Faust */
57*c217d954SCole Faust template <typename T1, typename T2, typename T3>
mul(const T1 src1,const T2 src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy)58*c217d954SCole Faust T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
59*c217d954SCole Faust {
60*c217d954SCole Faust using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type;
61*c217d954SCole Faust
62*c217d954SCole Faust const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
63*c217d954SCole Faust
64*c217d954SCole Faust if(is_floating_point<T3>::value)
65*c217d954SCole Faust {
66*c217d954SCole Faust const auto result = static_cast<T3>(val);
67*c217d954SCole Faust
68*c217d954SCole Faust return result;
69*c217d954SCole Faust }
70*c217d954SCole Faust else
71*c217d954SCole Faust {
72*c217d954SCole Faust double rounded_val = 0;
73*c217d954SCole Faust switch(rounding_policy)
74*c217d954SCole Faust {
75*c217d954SCole Faust case(RoundingPolicy::TO_ZERO):
76*c217d954SCole Faust rounded_val = support::cpp11::trunc(val);
77*c217d954SCole Faust break;
78*c217d954SCole Faust case(RoundingPolicy::TO_NEAREST_UP):
79*c217d954SCole Faust rounded_val = round_half_up(val);
80*c217d954SCole Faust break;
81*c217d954SCole Faust case(RoundingPolicy::TO_NEAREST_EVEN):
82*c217d954SCole Faust rounded_val = round_half_even(val);
83*c217d954SCole Faust break;
84*c217d954SCole Faust default:
85*c217d954SCole Faust ARM_COMPUTE_ERROR("Unsupported rounding policy");
86*c217d954SCole Faust }
87*c217d954SCole Faust
88*c217d954SCole Faust const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val);
89*c217d954SCole Faust
90*c217d954SCole Faust return result;
91*c217d954SCole Faust }
92*c217d954SCole Faust }
93*c217d954SCole Faust
94*c217d954SCole Faust template <>
mul(const int32_t src1,const int32_t src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy)95*c217d954SCole Faust int32_t mul(const int32_t src1, const int32_t src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
96*c217d954SCole Faust {
97*c217d954SCole Faust const int64_t intermediate_val = static_cast<int64_t>(src1) * static_cast<int64_t>(src2);
98*c217d954SCole Faust
99*c217d954SCole Faust if(std::abs(scale - scale1_constant) < 0.00001f)
100*c217d954SCole Faust {
101*c217d954SCole Faust // Use bit-accurate integer arithmetic for scale == 1
102*c217d954SCole Faust // Apply conversion
103*c217d954SCole Faust if(convert_policy == ConvertPolicy::SATURATE)
104*c217d954SCole Faust {
105*c217d954SCole Faust return saturate_cast<int32_t>(intermediate_val);
106*c217d954SCole Faust }
107*c217d954SCole Faust else
108*c217d954SCole Faust {
109*c217d954SCole Faust // Correct wrapping behaviour for int32_t
110*c217d954SCole Faust const auto i32_hi = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
111*c217d954SCole Faust const auto i32_lo = static_cast<int64_t>(std::numeric_limits<int32_t>::lowest());
112*c217d954SCole Faust const auto i32_wi = static_cast<int64_t>(1) << 32;
113*c217d954SCole Faust int64_t wrapped_rounded_val = intermediate_val - i32_wi * static_cast<int64_t>(support::cpp11::trunc(static_cast<double>(intermediate_val) / i32_wi));
114*c217d954SCole Faust if(wrapped_rounded_val <= i32_hi)
115*c217d954SCole Faust {
116*c217d954SCole Faust return static_cast<int32_t>(wrapped_rounded_val);
117*c217d954SCole Faust }
118*c217d954SCole Faust else
119*c217d954SCole Faust {
120*c217d954SCole Faust // Values beyond i32_hi wrap around to negatives
121*c217d954SCole Faust return static_cast<int32_t>((wrapped_rounded_val - i32_hi) + i32_lo - 1);
122*c217d954SCole Faust }
123*c217d954SCole Faust }
124*c217d954SCole Faust }
125*c217d954SCole Faust else
126*c217d954SCole Faust {
127*c217d954SCole Faust // Use double arithmetic for scale != 1; may not be bit-accurate
128*c217d954SCole Faust // Apply scaling
129*c217d954SCole Faust // scale == 1 / 2^scale_exponent
130*c217d954SCole Faust int scale_exponent = 0;
131*c217d954SCole Faust std::frexp(scale, &scale_exponent);
132*c217d954SCole Faust // Store the positive exponent. We know that we compute 1/2^n
133*c217d954SCole Faust // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5
134*c217d954SCole Faust scale_exponent = std::abs(scale_exponent - 1);
135*c217d954SCole Faust const double scale_inv = static_cast<int64_t>(1) << scale_exponent;
136*c217d954SCole Faust const double val = intermediate_val / scale_inv;
137*c217d954SCole Faust // Apply rounding
138*c217d954SCole Faust double rounded_val = 0;
139*c217d954SCole Faust switch(rounding_policy)
140*c217d954SCole Faust {
141*c217d954SCole Faust case(RoundingPolicy::TO_ZERO):
142*c217d954SCole Faust rounded_val = support::cpp11::trunc(val);
143*c217d954SCole Faust break;
144*c217d954SCole Faust case(RoundingPolicy::TO_NEAREST_UP):
145*c217d954SCole Faust rounded_val = round_half_up(val);
146*c217d954SCole Faust break;
147*c217d954SCole Faust case(RoundingPolicy::TO_NEAREST_EVEN):
148*c217d954SCole Faust rounded_val = round_half_even(val);
149*c217d954SCole Faust break;
150*c217d954SCole Faust default:
151*c217d954SCole Faust ARM_COMPUTE_ERROR("Unsupported rounding policy");
152*c217d954SCole Faust }
153*c217d954SCole Faust // Apply conversion
154*c217d954SCole Faust if(convert_policy == ConvertPolicy::SATURATE)
155*c217d954SCole Faust {
156*c217d954SCole Faust return saturate_cast<int32_t>(rounded_val);
157*c217d954SCole Faust }
158*c217d954SCole Faust else
159*c217d954SCole Faust {
160*c217d954SCole Faust // Correct wrapping behaviour for int32_t
161*c217d954SCole Faust const auto i32_hi = static_cast<double>(std::numeric_limits<int32_t>::max());
162*c217d954SCole Faust const auto i32_lo = static_cast<double>(std::numeric_limits<int32_t>::lowest());
163*c217d954SCole Faust const auto i32_wi = static_cast<double>(static_cast<int64_t>(1) << 32);
164*c217d954SCole Faust double wrapped_rounded_val = rounded_val - i32_wi * std::floor(rounded_val / i32_wi);
165*c217d954SCole Faust if(wrapped_rounded_val <= i32_hi)
166*c217d954SCole Faust {
167*c217d954SCole Faust return static_cast<int32_t>(wrapped_rounded_val);
168*c217d954SCole Faust }
169*c217d954SCole Faust else
170*c217d954SCole Faust {
171*c217d954SCole Faust // Values beyond i32_hi wrap around to negatives
172*c217d954SCole Faust return static_cast<int32_t>((wrapped_rounded_val - i32_hi) + i32_lo - 1);
173*c217d954SCole Faust }
174*c217d954SCole Faust }
175*c217d954SCole Faust }
176*c217d954SCole Faust }
177*c217d954SCole Faust
178*c217d954SCole Faust template <size_t dim>
179*c217d954SCole Faust struct BroadcastUnroll
180*c217d954SCole Faust {
181*c217d954SCole Faust template <typename T1, typename T2, typename T3>
unrollarm_compute::test::validation::reference::__anon60c5b57b0111::BroadcastUnroll182*c217d954SCole Faust static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
183*c217d954SCole Faust float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
184*c217d954SCole Faust Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
185*c217d954SCole Faust {
186*c217d954SCole Faust const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
187*c217d954SCole Faust const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
188*c217d954SCole Faust
189*c217d954SCole Faust id_src1.set(dim - 1, 0);
190*c217d954SCole Faust id_src2.set(dim - 1, 0);
191*c217d954SCole Faust id_dst.set(dim - 1, 0);
192*c217d954SCole Faust
193*c217d954SCole Faust for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
194*c217d954SCole Faust {
195*c217d954SCole Faust BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
196*c217d954SCole Faust
197*c217d954SCole Faust id_src1[dim - 1] += !src1_is_broadcast;
198*c217d954SCole Faust id_src2[dim - 1] += !src2_is_broadcast;
199*c217d954SCole Faust }
200*c217d954SCole Faust }
201*c217d954SCole Faust };
202*c217d954SCole Faust
203*c217d954SCole Faust template <>
204*c217d954SCole Faust struct BroadcastUnroll<0>
205*c217d954SCole Faust {
206*c217d954SCole Faust template <typename T1, typename T2, typename T3>
unrollarm_compute::test::validation::reference::__anon60c5b57b0111::BroadcastUnroll207*c217d954SCole Faust static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
208*c217d954SCole Faust float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
209*c217d954SCole Faust Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
210*c217d954SCole Faust {
211*c217d954SCole Faust dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
212*c217d954SCole Faust }
213*c217d954SCole Faust };
214*c217d954SCole Faust } // namespace
215*c217d954SCole Faust
216*c217d954SCole Faust template <typename T1, typename T2, typename T3>
pixel_wise_multiplication(const SimpleTensor<T1> & src1,const SimpleTensor<T2> & src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy,DataType dt_out,const QuantizationInfo & qout)217*c217d954SCole Faust SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
218*c217d954SCole Faust DataType dt_out, const QuantizationInfo &qout)
219*c217d954SCole Faust {
220*c217d954SCole Faust ARM_COMPUTE_UNUSED(qout);
221*c217d954SCole Faust
222*c217d954SCole Faust SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out);
223*c217d954SCole Faust
224*c217d954SCole Faust if(scale < 0)
225*c217d954SCole Faust {
226*c217d954SCole Faust ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
227*c217d954SCole Faust }
228*c217d954SCole Faust
229*c217d954SCole Faust Coordinates id_src1{};
230*c217d954SCole Faust Coordinates id_src2{};
231*c217d954SCole Faust Coordinates id_dst{};
232*c217d954SCole Faust
233*c217d954SCole Faust BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
234*c217d954SCole Faust
235*c217d954SCole Faust return dst;
236*c217d954SCole Faust }
237*c217d954SCole Faust
238*c217d954SCole Faust template <>
pixel_wise_multiplication(const SimpleTensor<uint8_t> & src1,const SimpleTensor<uint8_t> & src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy,DataType dt_out,const QuantizationInfo & qout)239*c217d954SCole Faust SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
240*c217d954SCole Faust DataType dt_out, const QuantizationInfo &qout)
241*c217d954SCole Faust {
242*c217d954SCole Faust SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
243*c217d954SCole Faust
244*c217d954SCole Faust if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
245*c217d954SCole Faust {
246*c217d954SCole Faust SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
247*c217d954SCole Faust SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
248*c217d954SCole Faust SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
249*c217d954SCole Faust dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
250*c217d954SCole Faust }
251*c217d954SCole Faust else
252*c217d954SCole Faust {
253*c217d954SCole Faust if(scale < 0)
254*c217d954SCole Faust {
255*c217d954SCole Faust ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
256*c217d954SCole Faust }
257*c217d954SCole Faust
258*c217d954SCole Faust Coordinates id_src1{};
259*c217d954SCole Faust Coordinates id_src2{};
260*c217d954SCole Faust Coordinates id_dst{};
261*c217d954SCole Faust BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
262*c217d954SCole Faust }
263*c217d954SCole Faust return dst;
264*c217d954SCole Faust }
265*c217d954SCole Faust
266*c217d954SCole Faust template <>
pixel_wise_multiplication(const SimpleTensor<uint8_t> & src1,const SimpleTensor<uint8_t> & src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy,DataType dt_out,const QuantizationInfo & qout)267*c217d954SCole Faust SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
268*c217d954SCole Faust DataType dt_out, const QuantizationInfo &qout)
269*c217d954SCole Faust {
270*c217d954SCole Faust SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
271*c217d954SCole Faust
272*c217d954SCole Faust if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
273*c217d954SCole Faust {
274*c217d954SCole Faust SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
275*c217d954SCole Faust SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
276*c217d954SCole Faust SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
277*c217d954SCole Faust dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
278*c217d954SCole Faust }
279*c217d954SCole Faust else
280*c217d954SCole Faust {
281*c217d954SCole Faust if(scale < 0)
282*c217d954SCole Faust {
283*c217d954SCole Faust ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
284*c217d954SCole Faust }
285*c217d954SCole Faust
286*c217d954SCole Faust Coordinates id_src1{};
287*c217d954SCole Faust Coordinates id_src2{};
288*c217d954SCole Faust Coordinates id_dst{};
289*c217d954SCole Faust BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
290*c217d954SCole Faust }
291*c217d954SCole Faust return dst;
292*c217d954SCole Faust }
293*c217d954SCole Faust
294*c217d954SCole Faust template <>
pixel_wise_multiplication(const SimpleTensor<int8_t> & src1,const SimpleTensor<int8_t> & src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy,DataType dt_out,const QuantizationInfo & qout)295*c217d954SCole Faust SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
296*c217d954SCole Faust DataType dt_out, const QuantizationInfo &qout)
297*c217d954SCole Faust {
298*c217d954SCole Faust SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
299*c217d954SCole Faust
300*c217d954SCole Faust if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED)
301*c217d954SCole Faust {
302*c217d954SCole Faust SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
303*c217d954SCole Faust SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
304*c217d954SCole Faust SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
305*c217d954SCole Faust dst = convert_to_asymmetric<int8_t>(dst_tmp, qout);
306*c217d954SCole Faust }
307*c217d954SCole Faust else
308*c217d954SCole Faust {
309*c217d954SCole Faust if(scale < 0)
310*c217d954SCole Faust {
311*c217d954SCole Faust ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
312*c217d954SCole Faust }
313*c217d954SCole Faust
314*c217d954SCole Faust Coordinates id_src1{};
315*c217d954SCole Faust Coordinates id_src2{};
316*c217d954SCole Faust Coordinates id_dst{};
317*c217d954SCole Faust BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
318*c217d954SCole Faust }
319*c217d954SCole Faust return dst;
320*c217d954SCole Faust }
321*c217d954SCole Faust
322*c217d954SCole Faust template <>
pixel_wise_multiplication(const SimpleTensor<int16_t> & src1,const SimpleTensor<int16_t> & src2,float scale,ConvertPolicy convert_policy,RoundingPolicy rounding_policy,DataType dt_out,const QuantizationInfo & qout)323*c217d954SCole Faust SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
324*c217d954SCole Faust DataType dt_out, const QuantizationInfo &qout)
325*c217d954SCole Faust {
326*c217d954SCole Faust SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
327*c217d954SCole Faust
328*c217d954SCole Faust if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
329*c217d954SCole Faust {
330*c217d954SCole Faust SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
331*c217d954SCole Faust SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
332*c217d954SCole Faust SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
333*c217d954SCole Faust dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
334*c217d954SCole Faust }
335*c217d954SCole Faust else
336*c217d954SCole Faust {
337*c217d954SCole Faust if(scale < 0)
338*c217d954SCole Faust {
339*c217d954SCole Faust ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
340*c217d954SCole Faust }
341*c217d954SCole Faust
342*c217d954SCole Faust Coordinates id_src1{};
343*c217d954SCole Faust Coordinates id_src2{};
344*c217d954SCole Faust Coordinates id_dst{};
345*c217d954SCole Faust BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
346*c217d954SCole Faust }
347*c217d954SCole Faust return dst;
348*c217d954SCole Faust }
349*c217d954SCole Faust // *INDENT-OFF*
350*c217d954SCole Faust // clang-format off
351*c217d954SCole Faust template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
352*c217d954SCole Faust template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
353*c217d954SCole Faust template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
354*c217d954SCole Faust template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
355*c217d954SCole Faust template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
356*c217d954SCole Faust // clang-format on
357*c217d954SCole Faust // *INDENT-ON*
358*c217d954SCole Faust } // namespace reference
359*c217d954SCole Faust } // namespace validation
360*c217d954SCole Faust } // namespace test
361*c217d954SCole Faust } // namespace arm_compute
362