// xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuElementwiseKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/CpuElementwiseKernel.h"
25 
26 #include "arm_compute/core/Helpers.h"
27 #include "src/core/CPP/Validate.h"
28 #include "src/core/common/Registrars.h"
29 #include "src/core/helpers/AutoConfiguration.h"
30 #include "src/core/helpers/WindowHelpers.h"
31 #include "src/cpu/kernels/elementwise_binary/list.h"
32 
33 #include <arm_neon.h>
34 
#if defined(ENABLE_FP32_KERNELS)
namespace
{
    // Minimum-workload-size (MWS) tuning points consumed by
    // CpuArithmeticKernel::get_mws() / CpuDivisionKernel::get_mws() below.
    // N1/V1 correspond to the CPUModel::N1 / CPUModel::V1 cases checked there.
    // Values are presumably benchmark-derived for the NEON fp32 kernels —
    // no derivation is visible in this file.
    static constexpr size_t default_min_max_mws_N1_fp32_neon = 25308;
    static constexpr size_t default_min_max_mws_V1_fp32_neon = 34772;
    static constexpr size_t default_div_mws_N1_fp32_neon = 19043;
    static constexpr size_t default_div_mws_V1_fp32_neon = 25511;
}
#endif /* ENABLE_FP32_KERNELS */
44 
45 namespace arm_compute
46 {
47 namespace cpu
48 {
49 namespace kernels
50 {
51 namespace
52 {
53 template <ArithmeticOperation                                                   op>
54 const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels_arithmetic =
55 {
56     {
57         "sve2_qu8_arithmetic",
58         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50302() 59         {
60             return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
61         },
62         REGISTER_QASYMM8_SVE2(sve2_qasymm8_elementwise_binary<op>)
63     },
64     {
65         "sve2_qs8_arithmetic",
66         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50402() 67         {
68             return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ArithmeticOperation>(data.op) == op;
69         },
70         REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_elementwise_binary<op>)
71     },
72     {
73         "sve_fp32_arithmetic",
74         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50502() 75         {
76             return data.dt == DataType::F32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
77         },
78         REGISTER_FP32_SVE(sve_fp32_elementwise_binary<op>)
79     },
80     {
81         "sve_s32_arithmetic",
82         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50602() 83         {
84             return data.dt == DataType::S32 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
85         },
86         REGISTER_INTEGER_SVE(sve_s32_elementwise_binary<op>)
87     },
88     {
89         "sve_s16_arithmetic",
90         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50702() 91         {
92             return data.dt == DataType::S16 && data.isa.sve && static_cast<ArithmeticOperation>(data.op) == op;
93         },
94         REGISTER_INTEGER_SVE(sve_s16_elementwise_binary<op>)
95     },
96     {
97         "sve_fp16_arithmetic",
98         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50802() 99         {
100             return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
101         },
102         REGISTER_FP16_SVE(sve_fp16_elementwise_binary<op>)
103     },
104     {
105         "neon_fp32_arithmetic",
106 
107         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50902() 108         {
109             return data.dt == DataType::F32 && static_cast<ArithmeticOperation>(data.op) == op;
110         },
111         REGISTER_FP32_NEON(neon_fp32_elementwise_binary<op>)
112     },
113     {
114         "neon_s32_arithmetic",
115         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50a02() 116         {
117             return data.dt == DataType::S32 && static_cast<ArithmeticOperation>(data.op) == op;
118         },
119         REGISTER_INTEGER_NEON(neon_s32_elementwise_binary<op>)
120     },
121     {
122         "neon_fp16_arithmetic",
123         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50b02() 124         {
125             return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ArithmeticOperation>(data.op) == op;
126         },
127         REGISTER_FP16_NEON(neon_fp16_elementwise_binary<op>)
128     },
129     {
130         "neon_s16_arithmetic",
131         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50c02() 132         {
133             return data.dt == DataType::S16 && static_cast<ArithmeticOperation>(data.op) == op;
134         },
135         REGISTER_INTEGER_NEON(neon_s16_elementwise_binary<op>)
136     },
137     {
138         "neon_qu8_arithmetic",
139         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50d02() 140         {
141             return data.dt == DataType::QASYMM8 && static_cast<ArithmeticOperation>(data.op) == op;
142         },
143         REGISTER_QASYMM8_NEON(neon_qasymm8_elementwise_binary<op>)
144     },
145     {
146         "neon_qs8_arithmetic",
147         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50e02() 148         {
149             return data.dt == DataType::QASYMM8_SIGNED && static_cast<ArithmeticOperation>(data.op) == op;
150         },
151         REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_elementwise_binary<op>)
152     },
153 };
154 template <ComparisonOperation                                                   op>
155 const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels_comperison =
156 {
157     {
158         "sve2_qu8_comparison",
159         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b50f02() 160         {
161             return data.dt == DataType::QASYMM8 && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
162         },
163         REGISTER_QASYMM8_SVE2(sve2_qasymm8_comparison_elementwise_binary<op>)
164     },
165     {
166         "sve2_qs8_comparison",
167         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51002() 168         {
169             return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && static_cast<ComparisonOperation>(data.op) == op;
170         },
171         REGISTER_QASYMM8_SIGNED_SVE2(sve2_qasymm8_signed_comparison_elementwise_binary<op>)
172     },
173     {
174         "sve_u8_comparison",
175         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51102() 176         {
177             return data.dt == DataType::U8 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
178         },
179         REGISTER_INTEGER_SVE(sve_u8_comparison_elementwise_binary<op>)
180     },
181     {
182         "sve_fp32_comparison",
183         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51202() 184         {
185             return data.dt == DataType::F32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
186         },
187         REGISTER_FP32_SVE(sve_fp32_comparison_elementwise_binary<op>)
188     },
189     {
190         "sve_s16_comparison",
191         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51302() 192         {
193             return data.dt == DataType::S16 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
194         },
195         REGISTER_INTEGER_SVE(sve_s16_comparison_elementwise_binary<op>)
196     },
197     {
198         "sve_s32_comparison",
199         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51402() 200         {
201             return data.dt == DataType::S32 && data.isa.sve && static_cast<ComparisonOperation>(data.op) == op;
202         },
203         REGISTER_INTEGER_SVE(sve_s32_comparison_elementwise_binary<op>)
204     },
205     {
206         "sve_fp16_comparison",
207         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51502() 208         {
209             return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
210         },
211         REGISTER_FP16_SVE(sve_fp16_comparison_elementwise_binary<op>)
212     },
213     {
214         "neon_u8_comparison",
215         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51602() 216         {
217             return data.dt == DataType::U8 && static_cast<ComparisonOperation>(data.op) == op;
218         },
219         REGISTER_INTEGER_NEON(neon_u8_comparison_elementwise_binary<op>)
220     },
221     {
222         "neon_fp32_comparison",
223         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51702() 224         {
225             return data.dt == DataType::F32 && static_cast<ComparisonOperation>(data.op) == op;
226         },
227         REGISTER_FP32_NEON(neon_fp32_comparison_elementwise_binary<op>)
228     },
229     {
230         "neon_s16_comparison",
231         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51802() 232         {
233             return data.dt == DataType::S16 && static_cast<ComparisonOperation>(data.op) == op;
234         },
235         REGISTER_INTEGER_NEON(neon_s16_comparison_elementwise_binary<op>)
236     },
237     {
238         "neon_s32_comparison",
239         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51902() 240         {
241             return data.dt == DataType::S32 && static_cast<ComparisonOperation>(data.op) == op;
242         },
243         REGISTER_INTEGER_NEON(neon_s32_comparison_elementwise_binary<op>)
244     },
245     {
246         "neon_qu8_comparison",
247         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51a02() 248         {
249             return data.dt == DataType::QASYMM8 && static_cast<ComparisonOperation>(data.op) == op;
250         },
251         REGISTER_QASYMM8_NEON(neon_qasymm8_comparison_elementwise_binary<op>)
252     },
253     {
254         "neon_qs8_comparison",
255         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51b02() 256         {
257             return data.dt == DataType::QASYMM8_SIGNED && static_cast<ComparisonOperation>(data.op) == op;
258         },
259         REGISTER_QASYMM8_SIGNED_NEON(neon_qasymm8_signed_comparison_elementwise_binary<op>)
260     },
261     {
262         "neon_fp16_comparison",
263         [](const ElementwiseDataTypeISASelectorData & data)
__anonf1c139b51c02() 264         {
265             return data.dt == DataType::F16 && data.isa.fp16 && static_cast<ComparisonOperation>(data.op) == op;
266         },
267         REGISTER_FP16_NEON(neon_fp16_comparison_elementwise_binary<op>)
268     },
269 };
270 } // namespace
271 
get_available_kernels()272 const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &CpuArithmeticKernel::get_available_kernels()
273 {
274     static std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> available_kernels;
275     std::move(available_kernels_arithmetic<ArithmeticOperation::ADD>.begin(), available_kernels_arithmetic<ArithmeticOperation::ADD>.end(), std::back_inserter(available_kernels));
276     std::move(available_kernels_arithmetic<ArithmeticOperation::SUB>.begin(), available_kernels_arithmetic<ArithmeticOperation::SUB>.end(), std::back_inserter(available_kernels));
277     std::move(available_kernels_arithmetic<ArithmeticOperation::DIV>.begin(), available_kernels_arithmetic<ArithmeticOperation::DIV>.end(), std::back_inserter(available_kernels));
278     std::move(available_kernels_arithmetic<ArithmeticOperation::MIN>.begin(), available_kernels_arithmetic<ArithmeticOperation::MIN>.end(), std::back_inserter(available_kernels));
279     std::move(available_kernels_arithmetic<ArithmeticOperation::MAX>.begin(), available_kernels_arithmetic<ArithmeticOperation::MAX>.end(), std::back_inserter(available_kernels));
280     std::move(available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.begin(), available_kernels_arithmetic<ArithmeticOperation::SQUARED_DIFF>.end(), std::back_inserter(available_kernels));
281     std::move(available_kernels_arithmetic<ArithmeticOperation::POWER>.begin(), available_kernels_arithmetic<ArithmeticOperation::POWER>.end(), std::back_inserter(available_kernels));
282     std::move(available_kernels_arithmetic<ArithmeticOperation::PRELU>.begin(), available_kernels_arithmetic<ArithmeticOperation::PRELU>.end(), std::back_inserter(available_kernels));
283 
284     return available_kernels;
285 }
286 
get_available_kernels()287 const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &CpuComparisonKernel::get_available_kernels()
288 {
289     static std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> available_kernels;
290     std::move(available_kernels_comperison<ComparisonOperation::Equal>.begin(), available_kernels_comperison<ComparisonOperation::Equal>.end(), std::back_inserter(available_kernels));
291     std::move(available_kernels_comperison<ComparisonOperation::NotEqual>.begin(), available_kernels_comperison<ComparisonOperation::NotEqual>.end(), std::back_inserter(available_kernels));
292     std::move(available_kernels_comperison<ComparisonOperation::Greater>.begin(), available_kernels_comperison<ComparisonOperation::Greater>.end(), std::back_inserter(available_kernels));
293     std::move(available_kernels_comperison<ComparisonOperation::GreaterEqual>.begin(), available_kernels_comperison<ComparisonOperation::GreaterEqual>.end(), std::back_inserter(available_kernels));
294     std::move(available_kernels_comperison<ComparisonOperation::Less>.begin(), available_kernels_comperison<ComparisonOperation::Less>.end(), std::back_inserter(available_kernels));
295     std::move(available_kernels_comperison<ComparisonOperation::LessEqual>.begin(), available_kernels_comperison<ComparisonOperation::LessEqual>.end(), std::back_inserter(available_kernels));
296 
297     return available_kernels;
298 }
299 
// Validation shared by arithmetic and comparison kernels: fp16 availability,
// matching input data types, broadcast compatibility of the input shapes and
// — when dst is already configured — agreement of dst with the broadcast shape.
template <class Derived>
Status CpuElementwiseKernel<Derived>::validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);

    // A zero-sized broadcast shape signals that the inputs cannot be broadcast
    // together (checked by the message below).
    const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");

    // Validate in case of configured dst
    if(dst.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
                                        "Wrong shape for output");
    }

    return Status{};
}
319 
configure_common(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)320 void CpuArithmeticKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
321 {
322     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
323 
324     const auto *uk = CpuArithmeticKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
325 
326     ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
327 
328     _run_method = uk->ukernel;
329     _name       = std::string("CpuArithmeticKernel").append("/").append(uk->name);
330 
331     // If any of shapes is dynamic, expect a configured window and dst at run-time.
332     if(src0->is_dynamic() || src1->is_dynamic())
333     {
334         return;
335     }
336 
337     auto shape_and_window = compute_output_shape_and_window(src0->tensor_shape(), src1->tensor_shape());
338     auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
339     ICpuKernel::configure(shape_and_window.second);
340 }
341 
configure_common(const ITensorInfo * src0,const ITensorInfo * src1,ITensorInfo * dst)342 void CpuComparisonKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
343 {
344     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
345 
346     const auto *uk = CpuComparisonKernel::get_implementation(ElementwiseDataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa(), static_cast<int>(_op) });
347 
348     ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
349 
350     _run_method = uk->ukernel;
351     _name       = std::string("CpuComparisonKernel").append("/").append(uk->name);
352 
353     // If any of shapes is dynamic, expect a configured window and dst at run-time.
354     if(src0->is_dynamic() || src1->is_dynamic())
355     {
356         return;
357     }
358 
359     auto shape_and_window = compute_output_shape_and_window(src0->tensor_shape(), src1->tensor_shape());
360     auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
361     ICpuKernel::configure(shape_and_window.second);
362 }
363 
364 template <class Derived>
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)365 void CpuElementwiseKernel<Derived>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
366 {
367     ARM_COMPUTE_UNUSED(info);
368     ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
369 
370     auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
371     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
372     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
373 
374     _run_method(src0, src1, dst, window);
375 }
376 template void CpuElementwiseKernel<CpuArithmeticKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
377 template void CpuElementwiseKernel<CpuComparisonKernel>::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info);
378 
template <class Derived>
const char *CpuElementwiseKernel<Derived>::name() const
{
    // _name is assembled in configure_common() as "<KernelClass>/<ukernel name>".
    return _name.c_str();
}
template const char *CpuElementwiseKernel<CpuArithmeticKernel>::name() const;
template const char *CpuElementwiseKernel<CpuComparisonKernel>::name() const;
386 
387 /** Arithmetic operators (min, max, squared_diff) */
// Validate the tensor infos, record the requested operation and delegate the
// kernel selection / window setup to configure_common().
void CpuArithmeticKernel::configure(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
    // _op must be set before configure_common(): the micro-kernel selector reads it.
    _op = op;
    CpuArithmeticKernel::configure_common(src0, src1, dst);
}
394 
// Arithmetic-specific validation: restrict inputs to the data types for which
// arithmetic kernels are registered, require dst (when configured) to match the
// input type, then run the shared elementwise checks.
Status CpuArithmeticKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    // Validate in case of configured dst
    if(dst.total_size() > 0)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
    }
    return validate_arguments_common(src0, src1, dst);
}
405 
validate(ArithmeticOperation op,const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)406 Status CpuArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
407 {
408     ARM_COMPUTE_UNUSED(op);
409     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
410     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
411     return Status{};
412 }
413 
// Minimum workload size used to bound how finely the scheduler may split this
// kernel across threads. Tuned values exist only for the NEON fp32 MIN/MAX
// micro-kernels on known CPU models (N1/V1); every other combination falls
// back to the generic default.
size_t CpuArithmeticKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
{
    ARM_COMPUTE_UNUSED(thread_count);

#if defined(ENABLE_FP32_KERNELS)
    // Only the NEON fp32 MIN/MAX binary kernels carry tuned thresholds.
    if(this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::MIN>
    || this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::MAX>)
    {
        size_t mws = ICPPKernel::default_mws;
        if(platform.get_cpu_model() == CPUModel::N1)
        {
            mws = default_min_max_mws_N1_fp32_neon;
        }
        else if(platform.get_cpu_model() == CPUModel::V1)
        {
            mws = default_min_max_mws_V1_fp32_neon;
        }
        else
        {
            // No tuning data for this CPU model: use the generic default.
            return ICPPKernel::default_mws;
        }

        // tensor is 1D or was re-interpreted as 1D
        if(this->window().shape().num_dimensions() == 1)
        {
            return mws;
        }
        else
        {
            // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
            // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
            // but the other sizes are large, which boosts performance.
            mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
            return std::max(static_cast<size_t>(1), mws);
        }
    }
#else /* ENABLE_FP32_KERNELS */
    ARM_COMPUTE_UNUSED(platform);
#endif /* ENABLE_FP32_KERNELS */
    return ICPPKernel::default_mws;
}
455 
456 /** The division operator */
457 
// Configure the division kernel: validate, pin _op to DIV (the micro-kernel
// selector in configure_common() reads it), then delegate to the base class.
void CpuDivisionKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
    _op = ArithmeticOperation::DIV;
    CpuArithmeticKernel::configure_common(src0, src1, dst);
}
464 
// Division-specific minimum workload size: same shape-scaling scheme as
// CpuArithmeticKernel::get_mws(), but keyed on the NEON fp32 DIV kernel and
// its own N1/V1 tuning constants.
size_t CpuDivisionKernel::get_mws(const CPUInfo &platform, size_t thread_count) const
{
    ARM_COMPUTE_UNUSED(thread_count);

#if defined(ENABLE_FP32_KERNELS)
    // Only the NEON fp32 DIV binary kernel carries tuned thresholds.
    if(this->_run_method == &neon_fp32_elementwise_binary<ArithmeticOperation::DIV>)
    {
        size_t mws = ICPPKernel::default_mws;
        if(platform.get_cpu_model() == CPUModel::N1)
        {
            mws = default_div_mws_N1_fp32_neon;
        }
        else if(platform.get_cpu_model() == CPUModel::V1)
        {
            mws = default_div_mws_V1_fp32_neon;
        }
        else
        {
            // No tuning data for this CPU model: use the generic default.
            return ICPPKernel::default_mws;
        }

        // tensor is 1D or was re-interpreted as 1D
        if(this->window().shape().num_dimensions() == 1)
        {
            return mws;
        }
        else
        {
            // scale mws down by the number of elements along all the dimensions (x, z, w, etc) except the one
            // that we parallelize along (the y dimension). This allows for parallelization when the Y_SIZE is small
            // but the other sizes are large, which boosts performance.
            mws = static_cast<size_t>(mws / (this->window().num_iterations_total() / this->window().num_iterations(1)));
            return std::max(static_cast<size_t>(1), mws);
        }
    }
#else /* ENABLE_FP32_KERNELS */
    ARM_COMPUTE_UNUSED(platform);
#endif /* ENABLE_FP32_KERNELS */
    return ICPPKernel::default_mws;
}
505 
// Division narrows the allowed input types to S32/F16/F32 before running the
// shared arithmetic checks.
Status CpuDivisionKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::S32, DataType::F16, DataType::F32);
    return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
}
511 
validate(const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)512 Status CpuDivisionKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
513 {
514     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
515     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
516     return Status{};
517 }
518 
519 /** The power operator */
/** The power operator */
// Configure the power kernel: validate, pin _op to POWER (the micro-kernel
// selector in configure_common() reads it), then delegate to the base class.
void CpuPowerKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
    _op = ArithmeticOperation::POWER;
    CpuArithmeticKernel::configure_common(src0, src1, dst);
}
526 
// Power narrows the allowed input types to floating point (F16/F32) before
// running the shared arithmetic checks.
Status CpuPowerKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::F16, DataType::F32);
    return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
}
532 
validate(const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)533 Status CpuPowerKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
534 {
535     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
536     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
537     return Status{};
538 }
539 
540 /** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
// Configure a comparison kernel: validate, record the requested operation
// (the micro-kernel selector in configure_common() reads _op), then delegate.
void CpuComparisonKernel::configure(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
    _op = op;
    CpuComparisonKernel::configure_common(src0, src1, dst);
}
547 
// Comparison-specific validation: restrict inputs to the data types for which
// comparison kernels are registered, require dst (when configured) to be U8,
// then run the shared elementwise checks.
Status CpuComparisonKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
    // Validate in case of configured dst
    if(dst.total_size() > 0)
    {
        // Comparison results are produced as U8.
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8);
    }
    return validate_arguments_common(src0, src1, dst);
}
558 
validate(ComparisonOperation op,const ITensorInfo * src0,const ITensorInfo * src1,const ITensorInfo * dst)559 Status CpuComparisonKernel::validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
560 {
561     ARM_COMPUTE_UNUSED(op);
562     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
563     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
564     return Status{};
565 }
566 } // namespace kernels
567 } // namespace cpu
568 } // namespace arm_compute
569