xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuAddMulAddKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2023 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "src/cpu/kernels/CpuAddMulAddKernel.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/ITensor.h"
27*c217d954SCole Faust #include "arm_compute/core/TensorInfo.h"
28*c217d954SCole Faust #include "arm_compute/core/Validate.h"
29*c217d954SCole Faust 
30*c217d954SCole Faust #include "src/core/CPP/Validate.h"
31*c217d954SCole Faust #include "src/core/common/Registrars.h"
32*c217d954SCole Faust #include "src/core/helpers/AutoConfiguration.h"
33*c217d954SCole Faust #include "src/core/helpers/WindowHelpers.h"
34*c217d954SCole Faust #include "src/cpu/kernels/addmuladd/list.h"
35*c217d954SCole Faust 
36*c217d954SCole Faust namespace arm_compute
37*c217d954SCole Faust {
38*c217d954SCole Faust namespace cpu
39*c217d954SCole Faust {
40*c217d954SCole Faust namespace kernels
41*c217d954SCole Faust {
42*c217d954SCole Faust namespace
43*c217d954SCole Faust {
44*c217d954SCole Faust static const std::vector<CpuAddMulAddKernel::AddMulAddKernel> available_kernels =
45*c217d954SCole Faust {
46*c217d954SCole Faust #ifdef __aarch64__
47*c217d954SCole Faust     {
48*c217d954SCole Faust         "neon_fp32_add_mul_add",
__anonb68491330202() 49*c217d954SCole Faust         [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
50*c217d954SCole Faust         REGISTER_FP32_NEON(arm_compute::cpu::add_mul_add_fp32_neon)
51*c217d954SCole Faust     },
52*c217d954SCole Faust     {
53*c217d954SCole Faust         "neon_fp16_add_mul_add",
__anonb68491330302() 54*c217d954SCole Faust         [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F16); },
55*c217d954SCole Faust         REGISTER_FP16_NEON(arm_compute::cpu::add_mul_add_fp16_neon)
56*c217d954SCole Faust     },
57*c217d954SCole Faust     {
58*c217d954SCole Faust         "neon_qasymm8_add_mul_add",
__anonb68491330402() 59*c217d954SCole Faust         [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
60*c217d954SCole Faust         REGISTER_QASYMM8_NEON(arm_compute::cpu::add_mul_add_u8_neon)
61*c217d954SCole Faust     },
62*c217d954SCole Faust     {
63*c217d954SCole Faust         "neon_qasymm8_signed_add_mul_add",
__anonb68491330502() 64*c217d954SCole Faust         [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
65*c217d954SCole Faust         REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_mul_add_s8_neon)
66*c217d954SCole Faust     }
67*c217d954SCole Faust #endif // __aarch64__
68*c217d954SCole Faust };
69*c217d954SCole Faust 
validate_arguments(const ITensorInfo * input1,const ITensorInfo * input2,const ITensorInfo * bn_mul,const ITensorInfo * bn_add,const ITensorInfo * add_output,const ITensorInfo * final_output,ConvertPolicy policy,const ActivationLayerInfo & act_info)70*c217d954SCole Faust Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
71*c217d954SCole Faust                           const ITensorInfo *bn_mul, const ITensorInfo *bn_add,
72*c217d954SCole Faust                           const ITensorInfo *add_output, const ITensorInfo *final_output,
73*c217d954SCole Faust                           ConvertPolicy policy, const ActivationLayerInfo &act_info)
74*c217d954SCole Faust {
75*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, bn_mul, bn_add, final_output);
76*c217d954SCole Faust 
77*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(policy != ConvertPolicy::SATURATE, "Only Saturate Policy is supported");
78*c217d954SCole Faust 
79*c217d954SCole Faust     using ActFunction          = ActivationLayerInfo::ActivationFunction;
80*c217d954SCole Faust     const ActFunction act_func = act_info.activation();
81*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(
82*c217d954SCole Faust         (act_func != ActFunction::BOUNDED_RELU && act_func != ActFunction::RELU && act_func != ActFunction::LU_BOUNDED_RELU && act_func != ActFunction::IDENTITY),
83*c217d954SCole Faust         "Only RELU Family activations, or no activation, is supported");
84*c217d954SCole Faust 
85*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
86*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
87*c217d954SCole Faust                                                          DataType::F16, DataType::F32);
88*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
89*c217d954SCole Faust 
90*c217d954SCole Faust     if(is_data_type_quantized(input1->data_type()))
91*c217d954SCole Faust     {
92*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bn_mul, 1, DataType::F32);
93*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bn_add, 1, DataType::F32);
94*c217d954SCole Faust     }
95*c217d954SCole Faust     else
96*c217d954SCole Faust     {
97*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, bn_mul);
98*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, bn_add);
99*c217d954SCole Faust     }
100*c217d954SCole Faust 
101*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); // No broadcasting
102*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(bn_mul, bn_add);
103*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(bn_mul->num_dimensions() != 1, "BatchNorm coefficients should be 1D array");
104*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_MSG(bn_mul->tensor_shape()[0] != input1->tensor_shape()[0], "First dimensions of inputs and batchNorm coefs should match");
105*c217d954SCole Faust 
106*c217d954SCole Faust     // Validate in case we have add layer's output (intermediate) initialized
107*c217d954SCole Faust     if(add_output != nullptr && add_output->total_size() > 0)
108*c217d954SCole Faust     {
109*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, add_output);
110*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, add_output);
111*c217d954SCole Faust     }
112*c217d954SCole Faust 
113*c217d954SCole Faust     // Validate in case final output has been initialized
114*c217d954SCole Faust     if(final_output->total_size() > 0)
115*c217d954SCole Faust     {
116*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, final_output);
117*c217d954SCole Faust         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, final_output);
118*c217d954SCole Faust     }
119*c217d954SCole Faust 
120*c217d954SCole Faust     const auto uk = CpuAddMulAddKernel::get_implementation<DataTypeISASelectorData>(DataTypeISASelectorData{ input1->data_type(), CPUInfo::get().get_isa() });
121*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
122*c217d954SCole Faust 
123*c217d954SCole Faust     return Status{};
124*c217d954SCole Faust }
125*c217d954SCole Faust } // namespace
126*c217d954SCole Faust 
configure(const ITensorInfo * input1,const ITensorInfo * input2,const ITensorInfo * bn_mul,const ITensorInfo * bn_add,ITensorInfo * add_output,ITensorInfo * final_output,ConvertPolicy policy,const ActivationLayerInfo & act_info)127*c217d954SCole Faust void CpuAddMulAddKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2,
128*c217d954SCole Faust                                    const ITensorInfo *bn_mul, const ITensorInfo *bn_add,
129*c217d954SCole Faust                                    ITensorInfo *add_output, ITensorInfo *final_output,
130*c217d954SCole Faust                                    ConvertPolicy policy, const ActivationLayerInfo &act_info)
131*c217d954SCole Faust {
132*c217d954SCole Faust     ARM_COMPUTE_UNUSED(bn_mul, bn_add, input2);
133*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, bn_add, bn_mul, final_output);
134*c217d954SCole Faust     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1, input2, bn_mul, bn_add, add_output, final_output, policy, act_info));
135*c217d954SCole Faust 
136*c217d954SCole Faust     const auto uk = CpuAddMulAddKernel::get_implementation<DataTypeISASelectorData>(DataTypeISASelectorData{ input1->data_type(), CPUInfo::get().get_isa() });
137*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
138*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(uk->ukernel == nullptr);
139*c217d954SCole Faust 
140*c217d954SCole Faust     _policy     = policy;
141*c217d954SCole Faust     _act_info   = act_info;
142*c217d954SCole Faust     _run_method = uk->ukernel;
143*c217d954SCole Faust     _name       = std::string("CpuAddMulAddKernel/").append(uk->name);
144*c217d954SCole Faust 
145*c217d954SCole Faust     // Auto initialize outputs if not initialized
146*c217d954SCole Faust     set_shape_if_empty(*final_output, input1->tensor_shape());
147*c217d954SCole Faust     set_data_type_if_unknown(*final_output, input1->data_type());
148*c217d954SCole Faust 
149*c217d954SCole Faust     if(add_output != nullptr)
150*c217d954SCole Faust     {
151*c217d954SCole Faust         set_shape_if_empty(*add_output, input1->tensor_shape());
152*c217d954SCole Faust         set_data_type_if_unknown(*add_output, input1->data_type());
153*c217d954SCole Faust     }
154*c217d954SCole Faust 
155*c217d954SCole Faust     // Configure kernel window
156*c217d954SCole Faust     Window win;
157*c217d954SCole Faust     win = calculate_max_window(*final_output, Steps());
158*c217d954SCole Faust     ICpuKernel::configure(win);
159*c217d954SCole Faust }
160*c217d954SCole Faust 
validate(const ITensorInfo * input1,const ITensorInfo * input2,const ITensorInfo * bn_mul,const ITensorInfo * bn_add,const ITensorInfo * add_output,const ITensorInfo * final_output,ConvertPolicy policy,const ActivationLayerInfo & act_info)161*c217d954SCole Faust Status CpuAddMulAddKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2,
162*c217d954SCole Faust                                     const ITensorInfo *bn_mul, const ITensorInfo *bn_add,
163*c217d954SCole Faust                                     const ITensorInfo *add_output, const ITensorInfo *final_output,
164*c217d954SCole Faust                                     ConvertPolicy policy, const ActivationLayerInfo &act_info)
165*c217d954SCole Faust {
166*c217d954SCole Faust     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, bn_mul, bn_add, final_output);
167*c217d954SCole Faust 
168*c217d954SCole Faust     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, bn_mul, bn_add, add_output, final_output, policy, act_info));
169*c217d954SCole Faust 
170*c217d954SCole Faust     return Status{};
171*c217d954SCole Faust }
172*c217d954SCole Faust 
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)173*c217d954SCole Faust void CpuAddMulAddKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
174*c217d954SCole Faust {
175*c217d954SCole Faust     ARM_COMPUTE_UNUSED(info);
176*c217d954SCole Faust 
177*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
178*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
179*c217d954SCole Faust 
180*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(tensors.empty());
181*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(_run_method == nullptr);
182*c217d954SCole Faust 
183*c217d954SCole Faust     const ITensor *input1       = tensors.get_const_tensor(TensorType::ACL_SRC_0);
184*c217d954SCole Faust     const ITensor *input2       = tensors.get_const_tensor(TensorType::ACL_SRC_1);
185*c217d954SCole Faust     const ITensor *bn_mul       = tensors.get_const_tensor(TensorType::ACL_SRC_2);
186*c217d954SCole Faust     const ITensor *bn_add       = tensors.get_const_tensor(TensorType::ACL_SRC_3);
187*c217d954SCole Faust     ITensor       *add_output   = tensors.get_tensor(TensorType::ACL_DST_0);
188*c217d954SCole Faust     ITensor       *final_output = tensors.get_tensor(TensorType::ACL_DST_1);
189*c217d954SCole Faust 
190*c217d954SCole Faust     _run_method(input1, input2, bn_mul, bn_add, add_output, final_output, _policy, _act_info, window);
191*c217d954SCole Faust }
192*c217d954SCole Faust 
name() const193*c217d954SCole Faust const char *CpuAddMulAddKernel::name() const
194*c217d954SCole Faust {
195*c217d954SCole Faust     return _name.c_str();
196*c217d954SCole Faust }
197*c217d954SCole Faust 
get_available_kernels()198*c217d954SCole Faust const std::vector<CpuAddMulAddKernel::AddMulAddKernel> &CpuAddMulAddKernel::get_available_kernels()
199*c217d954SCole Faust {
200*c217d954SCole Faust     return available_kernels;
201*c217d954SCole Faust }
202*c217d954SCole Faust } // namespace kernels
203*c217d954SCole Faust } // namespace cpu
204*c217d954SCole Faust } // namespace arm_compute
205