xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuElementwiseKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2022 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "src/cpu/kernels/CpuElementwiseKernel.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/Helpers.h"
27*c217d954SCole Faust #include "src/core/CPP/Validate.h"
28*c217d954SCole Faust #include "src/core/common/Registrars.h"
29*c217d954SCole Faust #include "src/core/helpers/AutoConfiguration.h"
30*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
31*c217d954SCole Faust #include "src/cpu/kernels/elementwise_binary/list.h"
32*c217d954SCole Faust 
33*c217d954SCole Faust #include <arm_neon.h>
34*c217d954SCole Faust 
#if defined(ENABLE_FP32_KERNELS)
namespace
{
// Tuned minimum workload sizes (MWS) returned by the get_mws() overrides for the Neon FP32
// micro-kernels on specific CPU models. Variables in an unnamed namespace already have
// internal linkage, so no `static` is needed.
constexpr size_t default_min_max_mws_N1_fp32_neon = 25308; // MIN/MAX, CPUModel::N1
constexpr size_t default_min_max_mws_V1_fp32_neon = 34772; // MIN/MAX, CPUModel::V1
constexpr size_t default_div_mws_N1_fp32_neon     = 19043; // DIV,     CPUModel::N1
constexpr size_t default_div_mws_V1_fp32_neon     = 25511; // DIV,     CPUModel::V1
} // namespace
#endif /* ENABLE_FP32_KERNELS */
44*c217d954SCole Faust 
45*c217d954SCole Faust namespace arm_compute
46*c217d954SCole Faust {
47*c217d954SCole Faust namespace cpu
48*c217d954SCole Faust {
49*c217d954SCole Faust namespace kernels
50*c217d954SCole Faust {
51*c217d954SCole Faust namespace
52*c217d954SCole Faust {
53*c217d954SCole Faust template <ArithmeticOperation                                                   op>
54*c217d954SCole Faust const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels_arithmetic =
55*c217d954SCole Faust {
56*c217d954SCole Faust     {
57*c217d954SCole Faust         "sve2_qu8_arithmetic",
58*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50302() 59*c217d954SCole Faust         {
60*c217d954SCole Faust             return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
61*c217d954SCole Faust         },
62*c217d954SCole Faust         REGISTER_QASYMM8_SVE2(sve2_qasymm8_elementwise_binary<op>)
63*c217d954SCole Faust     },
64*c217d954SCole Faust     {
65*c217d954SCole Faust         "sve2_qs8_arithmetic",
66*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50402() 67*c217d954SCole Faust         {
68*c217d954SCole Faust             return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
69*c217d954SCole Faust         },
70*c217d954SCole Faust         REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_elementwise_binary<op>)
71*c217d954SCole Faust     },
72*c217d954SCole Faust     {
73*c217d954SCole Faust         "sve_fp32_arithmetic",
74*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50502() 75*c217d954SCole Faust         {
76*c217d954SCole Faust             return data.dt == DataType::F32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
77*c217d954SCole Faust         },
78*c217d954SCole Faust         REGISTER_FP32_SVE(sve_fp32_elementwise_binary<op>)
79*c217d954SCole Faust     },
80*c217d954SCole Faust     {
81*c217d954SCole Faust         "sve_s32_arithmetic",
82*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50602() 83*c217d954SCole Faust         {
84*c217d954SCole Faust             return data.dt == DataType::S32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
85*c217d954SCole Faust         },
86*c217d954SCole Faust         REGISTER_INTEGER_SVE(sve_s32_elementwise_binary<op>)
87*c217d954SCole Faust     },
88*c217d954SCole Faust     {
89*c217d954SCole Faust         "sve_s16_arithmetic",
90*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50702() 91*c217d954SCole Faust         {
92*c217d954SCole Faust             return data.dt == DataType::S16 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
93*c217d954SCole Faust         },
94*c217d954SCole Faust         REGISTER_INTEGER_SVE(sve_s16_elementwise_binary<op>)
95*c217d954SCole Faust     },
96*c217d954SCole Faust     {
97*c217d954SCole Faust         "sve_fp16_arithmetic",
98*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50802() 99*c217d954SCole Faust         {
100*c217d954SCole Faust             return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
101*c217d954SCole Faust         },
102*c217d954SCole Faust         REGISTER_FP16_SVE(sve_fp16_elementwise_binary<op>)
103*c217d954SCole Faust     },
104*c217d954SCole Faust     {
105*c217d954SCole Faust         "neon_fp32_arithmetic",
106*c217d954SCole Faust 
107*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50902() 108*c217d954SCole Faust         {
109*c217d954SCole Faust             return data.dt == DataType::F32 && static_cast<ArithmeticOperation>(data.op) == op;
110*c217d954SCole Faust         },
111*c217d954SCole Faust         REGISTER_FP32_NEON(neon_fp32_elementwise_binary<op>)
112*c217d954SCole Faust     },
113*c217d954SCole Faust     {
114*c217d954SCole Faust         "neon_s32_arithmetic",
115*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50a02() 116*c217d954SCole Faust         {
117*c217d954SCole Faust             return data.dt == DataType::S32 && static_cast<ArithmeticOperation>(data.op) == op;
118*c217d954SCole Faust         },
119*c217d954SCole Faust         REGISTER_INTEGER_NEON(neon_s32_elementwise_binary<op>)
120*c217d954SCole Faust     },
121*c217d954SCole Faust     {
122*c217d954SCole Faust         "neon_fp16_arithmetic",
123*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50b02() 124*c217d954SCole Faust         {
125*c217d954SCole Faust             return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
126*c217d954SCole Faust         },
127*c217d954SCole Faust         REGISTER_FP16_NEON(neon_fp16_elementwise_binary<op>)
128*c217d954SCole Faust     },
129*c217d954SCole Faust     {
130*c217d954SCole Faust         "neon_s16_arithmetic",
131*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50c02() 132*c217d954SCole Faust         {
133*c217d954SCole Faust             return data.dt == DataType::S16 && static_cast<ArithmeticOperation>(data.op) == op;
134*c217d954SCole Faust         },
135*c217d954SCole Faust         REGISTER_INTEGER_NEON(neon_s16_elementwise_binary<op>)
136*c217d954SCole Faust     },
137*c217d954SCole Faust     {
138*c217d954SCole Faust         "neon_qu8_arithmetic",
139*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50d02() 140*c217d954SCole Faust         {
141*c217d954SCole Faust             return data.dt == DataType::QASYMM8 && static_cast<ArithmeticOperation>(data.op) == op;
142*c217d954SCole Faust         },
143*c217d954SCole Faust         REGISTER_QASYMM8_NEON(neon_qasymm8_elementwise_binary<op>)
144*c217d954SCole Faust     },
145*c217d954SCole Faust     {
146*c217d954SCole Faust         "neon_qs8_arithmetic",
147*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50e02() 148*c217d954SCole Faust         {
149*c217d954SCole Faust             return data.dt == DataType::QASYMM8_SIGNED && static_cast<ArithmeticOperation>(data.op) == op;
150*c217d954SCole Faust         },
151*c217d954SCole Faust         REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_elementwise_binary<op>)
152*c217d954SCole Faust     },
153*c217d954SCole Faust };
154*c217d954SCole Faust template <ComparisonOperation                                                   op>
155*c217d954SCole Faust const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels_comperison =
156*c217d954SCole Faust {
157*c217d954SCole Faust     {
158*c217d954SCole Faust         "sve2_qu8_comparison",
159*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50f02() 160*c217d954SCole Faust         {
161*c217d954SCole Faust             return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
162*c217d954SCole Faust         },
163*c217d954SCole Faust         REGISTER_QASYMM8_SVE2(sve2_qasymm8_comparison_elementwise_binary<op>)
164*c217d954SCole Faust     },
165*c217d954SCole Faust     {
166*c217d954SCole Faust         "sve2_qs8_comparison",
167*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51002() 168*c217d954SCole Faust         {
169*c217d954SCole Faust             return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
170*c217d954SCole Faust         },
171*c217d954SCole Faust         REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_comparison_elementwise_binary<op>)
172*c217d954SCole Faust     },
173*c217d954SCole Faust     {
174*c217d954SCole Faust         "sve_u8_comparison",
175*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51102() 176*c217d954SCole Faust         {
177*c217d954SCole Faust             return data.dt == DataType::U8 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
178*c217d954SCole Faust         },
179*c217d954SCole Faust         REGISTER_INTEGER_SVE(sve_u8_comparison_elementwise_binary<op>)
180*c217d954SCole Faust     },
181*c217d954SCole Faust     {
182*c217d954SCole Faust         "sve_fp32_comparison",
183*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51202() 184*c217d954SCole Faust         {
185*c217d954SCole Faust             return data.dt == DataType::F32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
186*c217d954SCole Faust         },
187*c217d954SCole Faust         REGISTER_FP32_SVE(sve_fp32_comparison_elementwise_binary<op>)
188*c217d954SCole Faust     },
189*c217d954SCole Faust     {
190*c217d954SCole Faust         "sve_s16_comparison",
191*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51302() 192*c217d954SCole Faust         {
193*c217d954SCole Faust             return data.dt == DataType::S16 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
194*c217d954SCole Faust         },
195*c217d954SCole Faust         REGISTER_INTEGER_SVE(sve_s16_comparison_elementwise_binary<op>)
196*c217d954SCole Faust     },
197*c217d954SCole Faust     {
198*c217d954SCole Faust         "sve_s32_comparison",
199*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51402() 200*c217d954SCole Faust         {
201*c217d954SCole Faust             return data.dt == DataType::S32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
202*c217d954SCole Faust         },
203*c217d954SCole Faust         REGISTER_INTEGER_SVE(sve_s32_comparison_elementwise_binary<op>)
204*c217d954SCole Faust     },
205*c217d954SCole Faust     {
206*c217d954SCole Faust         "sve_fp16_comparison",
207*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51502() 208*c217d954SCole Faust         {
209*c217d954SCole Faust             return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
210*c217d954SCole Faust         },
211*c217d954SCole Faust         REGISTER_FP16_SVE(sve_fp16_comparison_elementwise_binary<op>)
212*c217d954SCole Faust     },
213*c217d954SCole Faust     {
214*c217d954SCole Faust         "neon_u8_comparison",
215*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51602() 216*c217d954SCole Faust         {
217*c217d954SCole Faust             return data.dt == DataType::U8 && static_cast<ComparisonOperation>(data.op) == op;
218*c217d954SCole Faust         },
219*c217d954SCole Faust         REGISTER_INTEGER_NEON(neon_u8_comparison_elementwise_binary<op>)
220*c217d954SCole Faust     },
221*c217d954SCole Faust     {
222*c217d954SCole Faust         "neon_fp32_comparison",
223*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51702() 224*c217d954SCole Faust         {
225*c217d954SCole Faust             return data.dt == DataType::F32 && static_cast<ComparisonOperation>(data.op) == op;
226*c217d954SCole Faust         },
227*c217d954SCole Faust         REGISTER_FP32_NEON(neon_fp32_comparison_elementwise_binary<op>)
228*c217d954SCole Faust     },
229*c217d954SCole Faust     {
230*c217d954SCole Faust         "neon_s16_comparison",
231*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51802() 232*c217d954SCole Faust         {
233*c217d954SCole Faust             return data.dt == DataType::S16 && static_cast<ComparisonOperation>(data.op) == op;
234*c217d954SCole Faust         },
235*c217d954SCole Faust         REGISTER_INTEGER_NEON(neon_s16_comparison_elementwise_binary<op>)
236*c217d954SCole Faust     },
237*c217d954SCole Faust     {
238*c217d954SCole Faust         "neon_s32_comparison",
239*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51902() 240*c217d954SCole Faust         {
241*c217d954SCole Faust             return data.dt == DataType::S32 && static_cast<ComparisonOperation>(data.op) == op;
242*c217d954SCole Faust         },
243*c217d954SCole Faust         REGISTER_INTEGER_NEON(neon_s32_comparison_elementwise_binary<op>)
244*c217d954SCole Faust     },
245*c217d954SCole Faust     {
246*c217d954SCole Faust         "neon_qu8_comparison",
247*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51a02() 248*c217d954SCole Faust         {
249*c217d954SCole Faust             return data.dt == DataType::QASYMM8 && static_cast<ComparisonOperation>(data.op) == op;
250*c217d954SCole Faust         },
251*c217d954SCole Faust         REGISTER_QASYMM8_NEON(neon_qasymm8_comparison_elementwise_binary<op>)
252*c217d954SCole Faust     },
253*c217d954SCole Faust     {
254*c217d954SCole Faust         "neon_qs8_comparison",
255*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51b02() 256*c217d954SCole Faust         {
257*c217d954SCole Faust             return data.dt == DataType::QASYMM8_SIGNED && static_cast<ComparisonOperation>(data.op) == op;
258*c217d954SCole Faust         },
259*c217d954SCole Faust         REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_comparison_elementwise_binary<op>)
260*c217d954SCole Faust     },
261*c217d954SCole Faust     {
262*c217d954SCole Faust         "neon_fp16_comparison",
263*c217d954SCole Faust         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51c02() 264*c217d954SCole Faust         {
265*c217d954SCole Faust             return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
266*c217d954SCole Faust         },
267*c217d954SCole Faust         REGISTER_FP16_NEON(neon_fp16_comparison_elementwise_binary<op>)
268*c217d954SCole Faust     },
269*c217d954SCole Faust };
270*c217d954SCole Faust } // namespace
271*c217d954SCole Faust 
get_available_kernels()272*c217d954SCole Faust const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &CpuArithmeticKernel::get_available_kernels()
273*c217d954SCole Faust {
274*c217d954SCole Faust     static std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels;
275*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::ADD>.begin(), available_kernels_arithmetic<ArithmeticOperation::ADD>.end(), std::back_inserter(available_kernels));
276*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::SUB>.begin(), available_kernels_arithmetic<ArithmeticOperation::SUB>.end(), std::back_inserter(available_kernels));
277*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::DIV>.begin(), available_kernels_arithmetic<ArithmeticOperation::DIV>.end(), std::back_inserter(available_kernels));
278*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::MIN>.begin(), available_kernels_arithmetic<ArithmeticOperation::MIN>.end(), std::back_inserter(available_kernels));
279*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::MAX>.begin(), available_kernels_arithmetic<ArithmeticOperation::MAX>.end(), std::back_inserter(available_kernels));
280*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.begin(), available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.end(), std::back_inserter(available_kernels));
281*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::POWER>.begin(), available_kernels_arithmetic<ArithmeticOperation::POWER>.end(), std::back_inserter(available_kernels));
282*c217d954SCole Faust     std::move(available_kernels_arithmetic<ArithmeticOperation::PRELU>.begin(), available_kernels_arithmetic<ArithmeticOperation::PRELU>.end(), std::back_inserter(available_kernels));
283*c217d954SCole Faust 
284*c217d954SCole Faust     return available_kernels;
285*c217d954SCole Faust }
286*c217d954SCole Faust 
get_available_kernels()287*c217d954SCole Faust const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &CpuComparisonKernel::get_available_kernels()
288*c217d954SCole Faust {
289*c217d954SCole Faust     static std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels;
290*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::Equal>.begin(), available_kernels_comperison<ComparisonOperation::Equal>.end(), std::back_inserter(available_kernels));
291*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::NotEqual>.begin(), available_kernels_comperison<ComparisonOperation::NotEqual>.end(), std::back_inserter(available_kernels));
292*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::Greater>.begin(), available_kernels_comperison<ComparisonOperation::Greater>.end(), std::back_inserter(available_kernels));
293*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::GreaterEqual>.begin(), available_kernels_comperison<ComparisonOperation::GreaterEqual>.end(), std::back_inserter(available_kernels));
294*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::Less>.begin(), available_kernels_comperison<ComparisonOperation::Less>.end(), std::back_inserter(available_kernels));
295*c217d954SCole Faust     std::move(available_kernels_comperison<ComparisonOperation::LessEqual>.begin(), available_kernels_comperison<ComparisonOperation::LessEqual>.end(), std::back_inserter(available_kernels));
296*c217d954SCole Faust 
297*c217d954SCole Faust     return available_kernels;
298*c217d954SCole Faust }
299*c217d954SCole Faust 
300*c217d954SCole Faust template <class Derived>
validate_arguments_common(const ITensorInfo & src0,const ITensorInfo & src1,const ITensorInfo & dst)301*c217d954SCole Faust Status CpuElementwiseKernel<Derived>::validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
302*c217d954SCole Faust {
303*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
304*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
305*c217d954SCole Faust 
306*c217d954SCole Faust     const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
307*c217d954SCole Faust 
308*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
309*c217d954SCole Faust 
310*c217d954SCole Faust     // Validate in case of configured dst
311*c217d954SCole Faust     if(dst.total_size() > 0)
312*c217d954SCole Faust     {
313*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
314*c217d954SCole Faust                                         "Wrong shape for output");
315*c217d954SCole Faust     }
316*c217d954SCole Faust 
317*c217d954SCole Faust     return Status{};
318*c217d954SCole Faust }
319*c217d954SCole Faust 
configure_common(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)320*c217d954SCole Faust void CpuArithmeticKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
321*c217d954SCole Faust {
322*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
323*c217d954SCole Faust 
324*c217d954SCole Faust     const auto *uk = CpuArithmeticKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
325*c217d954SCole Faust 
326*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
327*c217d954SCole Faust 
328*c217d954SCole Faust     _run_method = uk->ukernel;
329*c217d954SCole Faust     _name       = std::string("CpuArithmeticKernel").append("/").append(uk->name);
330*c217d954SCole Faust 
331*c217d954SCole Faust     // If any of shapes is dynamic, expect a configured window and dst at run-time.
332*c217d954SCole Faust     if(src0->is_dynamic() || src1->is_dynamic())
333*c217d954SCole Faust     {
334*c217d954SCole Faust         return;
335*c217d954SCole Faust     }
336*c217d954SCole Faust 
337*c217d954SCole Faust     auto shape_and_window = compute_output_shape_and_window(src0->tensor_shape(), src1->tensor_shape());
338*c217d954SCole Faust     auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
339*c217d954SCole Faust     ICpuKernel::configure(shape_and_window.second);
340*c217d954SCole Faust }
341*c217d954SCole Faust 
configure_common(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)342*c217d954SCole Faust void CpuComparisonKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
343*c217d954SCole Faust {
344*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
345*c217d954SCole Faust 
346*c217d954SCole Faust     const auto *uk = CpuComparisonKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
347*c217d954SCole Faust 
348*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
349*c217d954SCole Faust 
350*c217d954SCole Faust     _run_method = uk->ukernel;
351*c217d954SCole Faust     _name       = std::string("CpuComparisonKernel").append("/").append(uk->name);
352*c217d954SCole Faust 
353*c217d954SCole Faust     // If any of shapes is dynamic, expect a configured window and dst at run-time.
354*c217d954SCole Faust     if(src0->is_dynamic() || src1->is_dynamic())
355*c217d954SCole Faust     {
356*c217d954SCole Faust         return;
357*c217d954SCole Faust     }
358*c217d954SCole Faust 
359*c217d954SCole Faust     auto shape_and_window = compute_output_shape_and_window(src0->tensor_shape(), src1->tensor_shape());
360*c217d954SCole Faust     auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
361*c217d954SCole Faust     ICpuKernel::configure(shape_and_window.second);
362*c217d954SCole Faust }
363*c217d954SCole Faust 
364*c217d954SCole Faust template <class Derived>
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)365*c217d954SCole Faust void CpuElementwiseKernel<Derived>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
366*c217d954SCole Faust {
367*c217d954SCole Faust     ARM_COMPUTE_UNUSED(info);
368*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
369*c217d954SCole Faust 
370*c217d954SCole Faust     auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
371*c217d954SCole Faust     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
372*c217d954SCole Faust     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
373*c217d954SCole Faust 
374*c217d954SCole Faust     _run_method(src0, src1, dst, window);
375*c217d954SCole Faust }
376*c217d954SCole Faust template void CpuElementwiseKernel<CpuArithmeticKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
377*c217d954SCole Faust template void CpuElementwiseKernel<CpuComparisonKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
378*c217d954SCole Faust 
379*c217d954SCole Faust template <class Derived>
name() const380*c217d954SCole Faust const char *CpuElementwiseKernel<Derived>::name() const
381*c217d954SCole Faust {
382*c217d954SCole Faust     return _name.c_str();
383*c217d954SCole Faust }
384*c217d954SCole Faust template const char *CpuElementwiseKernel<CpuArithmeticKernel>::name() const;
385*c217d954SCole Faust template const char *CpuElementwiseKernel<CpuComparisonKernel>::name() const;
386*c217d954SCole Faust 
387*c217d954SCole Faust /** Arithmetic operators (min, max, squared_diff) */
configure(ArithmeticOperation op,const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)388*c217d954SCole Faust void CpuArithmeticKernel::configure(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
389*c217d954SCole Faust {
390*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
391*c217d954SCole Faust     _op = op;
392*c217d954SCole Faust     CpuArithmeticKernel::configure_common(src0, src1, dst);
393*c217d954SCole Faust }
394*c217d954SCole Faust 
validate_arguments(const ITensorInfo & src0,const ITensorInfo & src1,const ITensorInfo & dst)395*c217d954SCole Faust Status CpuArithmeticKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
396*c217d954SCole Faust {
397*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
398*c217d954SCole Faust     // Validate in case of configured dst
399*c217d954SCole Faust     if(dst.total_size() > 0)
400*c217d954SCole Faust     {
401*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
402*c217d954SCole Faust     }
403*c217d954SCole Faust     return validate_arguments_common(src0, src1, dst);
404*c217d954SCole Faust }
405*c217d954SCole Faust 
validate(ArithmeticOperation op,const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)406*c217d954SCole Faust Status CpuArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
407*c217d954SCole Faust {
408*c217d954SCole Faust     ARM_COMPUTE_UNUSED(op);
409*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
410*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
411*c217d954SCole Faust     return Status{};
412*c217d954SCole Faust }
413*c217d954SCole Faust 
get_mws(const CPUInfo & platform,size_t thread_count) const414*c217d954SCole Faust size_t CpuArithmeticKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
415*c217d954SCole Faust {
416*c217d954SCole Faust     ARM_COMPUTE_UNUSED(thread_count);
417*c217d954SCole Faust 
418*c217d954SCole Faust #if defined(ENABLE_FP32_KERNELS)
419*c217d954SCole Faust     if(this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::MIN>
420*c217d954SCole Faust     || this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::MAX>)
421*c217d954SCole Faust     {
422*c217d954SCole Faust         size_t mws = ICPPKernel::default_mws;
423*c217d954SCole Faust         if(platform.get_cpu_model() == CPUModel::N1)
424*c217d954SCole Faust         {
425*c217d954SCole Faust             mws = default_min_max_mws_N1_fp32_neon;
426*c217d954SCole Faust         }
427*c217d954SCole Faust         else if(platform.get_cpu_model() == CPUModel::V1)
428*c217d954SCole Faust         {
429*c217d954SCole Faust             mws = default_min_max_mws_V1_fp32_neon;
430*c217d954SCole Faust         }
431*c217d954SCole Faust         else
432*c217d954SCole Faust         {
433*c217d954SCole Faust             return ICPPKernel::default_mws;
434*c217d954SCole Faust         }
435*c217d954SCole Faust 
436*c217d954SCole Faust         // tensor is 1D or was re-interpreted as 1D
437*c217d954SCole Faust         if(this->window().shape().num_dimensions() == 1)
438*c217d954SCole Faust         {
439*c217d954SCole Faust             return mws;
440*c217d954SCole Faust         }
441*c217d954SCole Faust         else
442*c217d954SCole Faust         {
443*c217d954SCole Faust             // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
444*c217d954SCole Faust             // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
445*c217d954SCole Faust             // but the other sizes are large, which boosts performance.
446*c217d954SCole Faust             mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
447*c217d954SCole Faust             return std::max(static_cast<size_t>(1), mws);
448*c217d954SCole Faust         }
449*c217d954SCole Faust     }
450*c217d954SCole Faust #else /* ENABLE_FP32_KERNELS */
451*c217d954SCole Faust     ARM_COMPUTE_UNUSED(platform);
452*c217d954SCole Faust #endif /* ENABLE_FP32_KERNELS */
453*c217d954SCole Faust     return ICPPKernel::default_mws;
454*c217d954SCole Faust }
455*c217d954SCole Faust 
456*c217d954SCole Faust /** The division operator */
457*c217d954SCole Faust 
configure(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)458*c217d954SCole Faust void CpuDivisionKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
459*c217d954SCole Faust {
460*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
461*c217d954SCole Faust     _op = ArithmeticOperation::DIV;
462*c217d954SCole Faust     CpuArithmeticKernel::configure_common(src0, src1, dst);
463*c217d954SCole Faust }
464*c217d954SCole Faust 
get_mws(const CPUInfo & platform,size_t thread_count) const465*c217d954SCole Faust size_t CpuDivisionKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
466*c217d954SCole Faust {
467*c217d954SCole Faust     ARM_COMPUTE_UNUSED(thread_count);
468*c217d954SCole Faust 
469*c217d954SCole Faust #if defined(ENABLE_FP32_KERNELS)
470*c217d954SCole Faust     if(this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::DIV>)
471*c217d954SCole Faust     {
472*c217d954SCole Faust         size_t mws = ICPPKernel::default_mws;
473*c217d954SCole Faust         if(platform.get_cpu_model() == CPUModel::N1)
474*c217d954SCole Faust         {
475*c217d954SCole Faust             mws = default_div_mws_N1_fp32_neon;
476*c217d954SCole Faust         }
477*c217d954SCole Faust         else if(platform.get_cpu_model() == CPUModel::V1)
478*c217d954SCole Faust         {
479*c217d954SCole Faust             mws = default_div_mws_V1_fp32_neon;
480*c217d954SCole Faust         }
481*c217d954SCole Faust         else
482*c217d954SCole Faust         {
483*c217d954SCole Faust             return ICPPKernel::default_mws;
484*c217d954SCole Faust         }
485*c217d954SCole Faust 
486*c217d954SCole Faust         // tensor is 1D or was re-interpreted as 1D
487*c217d954SCole Faust         if(this->window().shape().num_dimensions() == 1)
488*c217d954SCole Faust         {
489*c217d954SCole Faust             return mws;
490*c217d954SCole Faust         }
491*c217d954SCole Faust         else
492*c217d954SCole Faust         {
493*c217d954SCole Faust             // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
494*c217d954SCole Faust             // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
495*c217d954SCole Faust             // but the other sizes are large, which boosts performance.
496*c217d954SCole Faust             mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
497*c217d954SCole Faust             return std::max(static_cast<size_t>(1), mws);
498*c217d954SCole Faust         }
499*c217d954SCole Faust     }
500*c217d954SCole Faust #else /* ENABLE_FP32_KERNELS */
501*c217d954SCole Faust     ARM_COMPUTE_UNUSED(platform);
502*c217d954SCole Faust #endif /* ENABLE_FP32_KERNELS */
503*c217d954SCole Faust     return ICPPKernel::default_mws;
504*c217d954SCole Faust }
505*c217d954SCole Faust 
validate_arguments(const ITensorInfo & src0,const ITensorInfo & src1,const ITensorInfo & dst)506*c217d954SCole Faust Status CpuDivisionKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
507*c217d954SCole Faust {
508*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::S32, DataType::F16, DataType::F32);
509*c217d954SCole Faust     return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
510*c217d954SCole Faust }
511*c217d954SCole Faust 
validate(const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)512*c217d954SCole Faust Status CpuDivisionKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
513*c217d954SCole Faust {
514*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
515*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
516*c217d954SCole Faust     return Status{};
517*c217d954SCole Faust }
518*c217d954SCole Faust 
519*c217d954SCole Faust /** The power operator */
configure(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)520*c217d954SCole Faust void CpuPowerKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
521*c217d954SCole Faust {
522*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
523*c217d954SCole Faust     _op = ArithmeticOperation::POWER;
524*c217d954SCole Faust     CpuArithmeticKernel::configure_common(src0, src1, dst);
525*c217d954SCole Faust }
526*c217d954SCole Faust 
validate_arguments(const ITensorInfo & src0,const ITensorInfo & src1,const ITensorInfo & dst)527*c217d954SCole Faust Status CpuPowerKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
528*c217d954SCole Faust {
529*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::F16, DataType::F32);
530*c217d954SCole Faust     return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
531*c217d954SCole Faust }
532*c217d954SCole Faust 
validate(const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)533*c217d954SCole Faust Status CpuPowerKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
534*c217d954SCole Faust {
535*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
536*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
537*c217d954SCole Faust     return Status{};
538*c217d954SCole Faust }
539*c217d954SCole Faust 
540*c217d954SCole Faust /** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
configure(ComparisonOperation op,const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)541*c217d954SCole Faust void CpuComparisonKernel::configure(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
542*c217d954SCole Faust {
543*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
544*c217d954SCole Faust     _op = op;
545*c217d954SCole Faust     CpuComparisonKernel::configure_common(src0, src1, dst);
546*c217d954SCole Faust }
547*c217d954SCole Faust 
validate_arguments(const ITensorInfo & src0,const ITensorInfo & src1,const ITensorInfo & dst)548*c217d954SCole Faust Status CpuComparisonKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
549*c217d954SCole Faust {
550*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
551*c217d954SCole Faust     // Validate in case of configured dst
552*c217d954SCole Faust     if(dst.total_size() > 0)
553*c217d954SCole Faust     {
554*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8);
555*c217d954SCole Faust     }
556*c217d954SCole Faust     return validate_arguments_common(src0, src1, dst);
557*c217d954SCole Faust }
558*c217d954SCole Faust 
validate(ComparisonOperation op,const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)559*c217d954SCole Faust Status CpuComparisonKernel::validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
560*c217d954SCole Faust {
561*c217d954SCole Faust     ARM_COMPUTE_UNUSED(op);
562*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
563*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
564*c217d954SCole Faust     return Status{};
565*c217d954SCole Faust }
566*c217d954SCole Faust } // namespace kernels
567*c217d954SCole Faust } // namespace cpu
568*c217d954SCole Faust } // namespace arm_compute
569