xref: /aosp_15_r20/external/armnn/src/backends/aclCommon/ArmComputeUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Types.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/FunctionDescriptors.h>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
17*89c4ff92SAndroid Build Coastguard Worker #include "neon/workloads/NeonReduceWorkload.hpp"
18*89c4ff92SAndroid Build Coastguard Worker #endif
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
21*89c4ff92SAndroid Build Coastguard Worker #include "cl/workloads/ClReduceWorkload.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #endif
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker namespace armnn
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::NormalizationLayerInfo
CreateAclNormalizationLayerInfoForL2Normalization(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout)28*89c4ff92SAndroid Build Coastguard Worker CreateAclNormalizationLayerInfoForL2Normalization(const armnn::TensorInfo& tensorInfo,
29*89c4ff92SAndroid Build Coastguard Worker                                                   armnn::DataLayout dataLayout)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker     unsigned int depthDimension = dataLayout == armnn::DataLayout::NCHW ? 1 : 3;
32*89c4ff92SAndroid Build Coastguard Worker     const unsigned int depth = tensorInfo.GetShape()[depthDimension];
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     // At the time of writing, {CL|Neon}L2Normalization performs the reduction only along dimension 0. This version of
35*89c4ff92SAndroid Build Coastguard Worker     // L2 Normalization always performs the reduction along the depth axis, though. Thus, we repurpose
36*89c4ff92SAndroid Build Coastguard Worker     // {CL|Neon}NormalizationLayers to act as depthwise L2 normalizations by carefully chosing the normalization
37*89c4ff92SAndroid Build Coastguard Worker     // parameters.
38*89c4ff92SAndroid Build Coastguard Worker     //
39*89c4ff92SAndroid Build Coastguard Worker     // Please refer to both the reference implementation of the normalization layer and the implementation of
40*89c4ff92SAndroid Build Coastguard Worker     // {CL|Neon}NormalizationLayer when checking the derivations for the parameter values below.
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     // Make sure normalization covers the entire depth range. ACL requires the normalization size to be odd.
43*89c4ff92SAndroid Build Coastguard Worker     // CL: This does not result in extra kernel threads not doing any work: See usage of the RADIUS parameter in
44*89c4ff92SAndroid Build Coastguard Worker     // ACL's normalization_layer_cross_map() CL function.
45*89c4ff92SAndroid Build Coastguard Worker     const uint32_t normSize = depth * 2u + 1u;
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     // See ACL's NormalizationLayerInfo::scale_coeff() definition.
48*89c4ff92SAndroid Build Coastguard Worker     // For the reference implementation, to make alpha_ become 1, we'd have to use alpha = normSize instead.
49*89c4ff92SAndroid Build Coastguard Worker     const float alpha = 1.0f;
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     // Don't offset the reduction.
52*89c4ff92SAndroid Build Coastguard Worker     const float kappa = 0.0f;
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     // pow(reduction, -0.5) = 1 / sqrt(reduction)
55*89c4ff92SAndroid Build Coastguard Worker     const float beta = 0.5f;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::NormalizationLayerInfo(arm_compute::NormType::CROSS_MAP, normSize, alpha, beta, kappa, false);
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo::ActivationFunction
ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunction)61*89c4ff92SAndroid Build Coastguard Worker ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunction)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     using AclActivationFunction = arm_compute::ActivationLayerInfo::ActivationFunction;
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     switch (armnnFunction)
66*89c4ff92SAndroid Build Coastguard Worker     {
67*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Linear:        return AclActivationFunction::LINEAR;
68*89c4ff92SAndroid Build Coastguard Worker         // Arm compute's 'logistic' function is non-parameterized, so it is exactly a sigmoid function.
69*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Sigmoid:       return AclActivationFunction::LOGISTIC;
70*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::ReLu:          return AclActivationFunction::RELU;
71*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::BoundedReLu:   return AclActivationFunction::LU_BOUNDED_RELU;
72*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::SoftReLu:      return AclActivationFunction::SOFT_RELU;
73*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::LeakyReLu:     return AclActivationFunction::LEAKY_RELU;
74*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Abs:           return AclActivationFunction::ABS;
75*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Sqrt:          return AclActivationFunction::SQRT;
76*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Square:        return AclActivationFunction::SQUARE;
77*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::TanH:          return AclActivationFunction::TANH;
78*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Elu:           return AclActivationFunction::ELU;
79*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::HardSwish:     return AclActivationFunction::HARD_SWISH;
80*89c4ff92SAndroid Build Coastguard Worker         default:                                throw InvalidArgumentException("Unsupported activation function");
81*89c4ff92SAndroid Build Coastguard Worker     }
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor & actDesc)85*89c4ff92SAndroid Build Coastguard Worker ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor& actDesc)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::ActivationLayerInfo(ConvertActivationFunctionToAclActivationFunction(actDesc.m_Function),
88*89c4ff92SAndroid Build Coastguard Worker         actDesc.m_A, actDesc.m_B);
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor * activationDescPtr)92*89c4ff92SAndroid Build Coastguard Worker ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor* activationDescPtr)
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker     if (activationDescPtr != nullptr)
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         return ConvertActivationDescriptorToAclActivationLayerInfo(static_cast<ActivationDescriptor>(
97*89c4ff92SAndroid Build Coastguard Worker                                                                            *activationDescPtr));
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::ActivationLayerInfo();
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertAdditionalInfoToAclActivationLayerInfo(const QueueDescriptor & queueDescriptor)103*89c4ff92SAndroid Build Coastguard Worker ConvertAdditionalInfoToAclActivationLayerInfo(const QueueDescriptor& queueDescriptor)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker     const ActivationDescriptor* activationDescPtr = queueDescriptor.GetAdditionalInformation<ActivationDescriptor>();
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     if (activationDescPtr != nullptr)
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         return ConvertActivationDescriptorToAclActivationLayerInfo(static_cast<ActivationDescriptor>(
110*89c4ff92SAndroid Build Coastguard Worker                 *activationDescPtr));
111*89c4ff92SAndroid Build Coastguard Worker     }
112*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::ActivationLayerInfo();
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertLstmActivationFuncToAclLayerInfo(uint32_t activationFunction)116*89c4ff92SAndroid Build Coastguard Worker ConvertLstmActivationFuncToAclLayerInfo(uint32_t activationFunction)
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker     // For preparing the object for the class ActivationLayerInfo, we need to consider 5 situations.
119*89c4ff92SAndroid Build Coastguard Worker     switch (activationFunction)
120*89c4ff92SAndroid Build Coastguard Worker     {
121*89c4ff92SAndroid Build Coastguard Worker         case 0:
122*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::ActivationLayerInfo(); // no activation, do nothing
123*89c4ff92SAndroid Build Coastguard Worker         case 1:
124*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::ActivationLayerInfo(arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
125*89c4ff92SAndroid Build Coastguard Worker         case 3:
126*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::ActivationLayerInfo(
127*89c4ff92SAndroid Build Coastguard Worker                 arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
128*89c4ff92SAndroid Build Coastguard Worker         case 4:
129*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::ActivationLayerInfo(
130*89c4ff92SAndroid Build Coastguard Worker                 arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
131*89c4ff92SAndroid Build Coastguard Worker         case 6:
132*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::ActivationLayerInfo(
133*89c4ff92SAndroid Build Coastguard Worker                 arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
134*89c4ff92SAndroid Build Coastguard Worker         default:
135*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("Wrong Type of Activation Function!");
136*89c4ff92SAndroid Build Coastguard Worker     }
137*89c4ff92SAndroid Build Coastguard Worker }
138*89c4ff92SAndroid Build Coastguard Worker 
ConvertComparisonOperationToAcl(const ComparisonDescriptor & descriptor)139*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ComparisonOperation ConvertComparisonOperationToAcl(const ComparisonDescriptor& descriptor)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     switch (descriptor.m_Operation)
142*89c4ff92SAndroid Build Coastguard Worker     {
143*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::Greater:         return arm_compute::ComparisonOperation::Greater;
144*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::GreaterOrEqual:  return arm_compute::ComparisonOperation::GreaterEqual;
145*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::Less:            return arm_compute::ComparisonOperation::Less;
146*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::LessOrEqual:     return arm_compute::ComparisonOperation::LessEqual;
147*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::Equal:           return arm_compute::ComparisonOperation::Equal;
148*89c4ff92SAndroid Build Coastguard Worker         case ComparisonOperation::NotEqual:        return arm_compute::ComparisonOperation::NotEqual;
149*89c4ff92SAndroid Build Coastguard Worker         default:                                   throw InvalidArgumentException("Unsupported comparison function");
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm)153*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::PoolingType ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     using arm_compute::PoolingType;
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker     switch (poolingAlgorithm)
158*89c4ff92SAndroid Build Coastguard Worker     {
159*89c4ff92SAndroid Build Coastguard Worker         case PoolingAlgorithm::Max:             return PoolingType::MAX;
160*89c4ff92SAndroid Build Coastguard Worker         case PoolingAlgorithm::Average:         return PoolingType::AVG;
161*89c4ff92SAndroid Build Coastguard Worker         case PoolingAlgorithm::L2:              return PoolingType::L2;
162*89c4ff92SAndroid Build Coastguard Worker         default:                                throw InvalidArgumentException("Unsupported pooling algorithm");
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker 
ConvertOutputShapeRoundingToAclDimensionRoundingType(OutputShapeRounding rounding)166*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::DimensionRoundingType ConvertOutputShapeRoundingToAclDimensionRoundingType(OutputShapeRounding
167*89c4ff92SAndroid Build Coastguard Worker                                                                                                               rounding)
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker     using arm_compute::DimensionRoundingType;
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     switch (rounding)
172*89c4ff92SAndroid Build Coastguard Worker     {
173*89c4ff92SAndroid Build Coastguard Worker         case OutputShapeRounding::Ceiling:  return DimensionRoundingType::CEIL;
174*89c4ff92SAndroid Build Coastguard Worker         case OutputShapeRounding::Floor:    return DimensionRoundingType::FLOOR;
175*89c4ff92SAndroid Build Coastguard Worker         default:                            throw InvalidArgumentException("Unsupported Output Shape Rounding type");
176*89c4ff92SAndroid Build Coastguard Worker     }
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::NormType
ConvertNormalizationAlgorithmChannelToAclNormType(NormalizationAlgorithmChannel channelType)180*89c4ff92SAndroid Build Coastguard Worker ConvertNormalizationAlgorithmChannelToAclNormType(NormalizationAlgorithmChannel channelType)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker     using arm_compute::NormType;
183*89c4ff92SAndroid Build Coastguard Worker     switch (channelType)
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         case NormalizationAlgorithmChannel::Across: return NormType::CROSS_MAP;
186*89c4ff92SAndroid Build Coastguard Worker         case NormalizationAlgorithmChannel::Within: return NormType::IN_MAP_2D;
187*89c4ff92SAndroid Build Coastguard Worker         default:    throw InvalidArgumentException("Unsupported normalization algorithm channel type");
188*89c4ff92SAndroid Build Coastguard Worker     }
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::FullyConnectedLayerInfo
ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor & fullyConnectedDesc,const ActivationDescriptor * activationDesc)192*89c4ff92SAndroid Build Coastguard Worker ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor& fullyConnectedDesc,
193*89c4ff92SAndroid Build Coastguard Worker                                                             const ActivationDescriptor* activationDesc)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker     arm_compute::FullyConnectedLayerInfo fc_info;
196*89c4ff92SAndroid Build Coastguard Worker     fc_info.transpose_weights = fullyConnectedDesc.m_TransposeWeightMatrix;
197*89c4ff92SAndroid Build Coastguard Worker     fc_info.activation_info = ConvertActivationDescriptorToAclActivationLayerInfo(activationDesc);
198*89c4ff92SAndroid Build Coastguard Worker     return fc_info;
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::FullyConnectedLayerInfo
ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor & fullyConnectedDesc,arm_compute::ActivationLayerInfo activationLayerInfo)202*89c4ff92SAndroid Build Coastguard Worker ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor& fullyConnectedDesc,
203*89c4ff92SAndroid Build Coastguard Worker         arm_compute::ActivationLayerInfo activationLayerInfo)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker     arm_compute::FullyConnectedLayerInfo fc_info;
206*89c4ff92SAndroid Build Coastguard Worker     fc_info.transpose_weights = fullyConnectedDesc.m_TransposeWeightMatrix;
207*89c4ff92SAndroid Build Coastguard Worker     fc_info.activation_info = activationLayerInfo;
208*89c4ff92SAndroid Build Coastguard Worker     return fc_info;
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker 
ConvertResizeMethodToAclInterpolationPolicy(ResizeMethod resizeMethod)211*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::InterpolationPolicy ConvertResizeMethodToAclInterpolationPolicy(ResizeMethod resizeMethod)
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker     switch (resizeMethod)
214*89c4ff92SAndroid Build Coastguard Worker     {
215*89c4ff92SAndroid Build Coastguard Worker         case ResizeMethod::Bilinear:
216*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::InterpolationPolicy::BILINEAR;
217*89c4ff92SAndroid Build Coastguard Worker         case ResizeMethod::NearestNeighbor:
218*89c4ff92SAndroid Build Coastguard Worker             return arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR;
219*89c4ff92SAndroid Build Coastguard Worker         default:
220*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("Unsupported resize method");
221*89c4ff92SAndroid Build Coastguard Worker     }
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker template<typename T>
ComputeSoftmaxAclAxis(const SoftmaxDescriptor & softmaxDesc,const armnn::TensorInfo & tensor)225*89c4ff92SAndroid Build Coastguard Worker inline T ComputeSoftmaxAclAxis(const SoftmaxDescriptor& softmaxDesc, const armnn::TensorInfo& tensor)
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker     // Detect the Android default value of -1 and return the ACL default value of 0.
228*89c4ff92SAndroid Build Coastguard Worker     if (softmaxDesc.m_Axis == -1)
229*89c4ff92SAndroid Build Coastguard Worker     {
230*89c4ff92SAndroid Build Coastguard Worker         return 0;
231*89c4ff92SAndroid Build Coastguard Worker     }
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker     unsigned int dim = tensor.GetNumDimensions();
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(dim != 0);
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     // Currently ArmNN support axis 1.
238*89c4ff92SAndroid Build Coastguard Worker     auto aclAxis = (static_cast<T>(dim) - 1);
239*89c4ff92SAndroid Build Coastguard Worker     aclAxis = aclAxis > 0 ? aclAxis -1 : aclAxis;
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker     return aclAxis;
242*89c4ff92SAndroid Build Coastguard Worker }
243*89c4ff92SAndroid Build Coastguard Worker 
ComputeSplitAxis(const armnn::SplitterDescriptor & desc,const TensorShape & input)244*89c4ff92SAndroid Build Coastguard Worker inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker     unsigned int numSplit = desc.GetNumViews();
247*89c4ff92SAndroid Build Coastguard Worker     unsigned int numDimensions = desc.GetNumDimensions();
248*89c4ff92SAndroid Build Coastguard Worker     std::set<unsigned int> splitAxis;
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numSplit; ++i)
251*89c4ff92SAndroid Build Coastguard Worker     {
252*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
253*89c4ff92SAndroid Build Coastguard Worker         {
254*89c4ff92SAndroid Build Coastguard Worker             if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
255*89c4ff92SAndroid Build Coastguard Worker             {
256*89c4ff92SAndroid Build Coastguard Worker                 splitAxis.insert(dimIdx);
257*89c4ff92SAndroid Build Coastguard Worker             }
258*89c4ff92SAndroid Build Coastguard Worker         }
259*89c4ff92SAndroid Build Coastguard Worker     }
260*89c4ff92SAndroid Build Coastguard Worker     return splitAxis;
261*89c4ff92SAndroid Build Coastguard Worker }
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker /// Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank, rank)
ComputeAclAxis(const int & armnnAxis,const armnn::TensorInfo & tensor)264*89c4ff92SAndroid Build Coastguard Worker inline int ComputeAclAxis(const int& armnnAxis, const armnn::TensorInfo& tensor)
265*89c4ff92SAndroid Build Coastguard Worker {
266*89c4ff92SAndroid Build Coastguard Worker     int rank = static_cast<int>(tensor.GetNumDimensions());
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(rank != 0);
269*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT((-1 * rank) <= armnnAxis);
270*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(armnnAxis < rank);
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     int sign = (armnnAxis < 0) ? -1 : 1;
273*89c4ff92SAndroid Build Coastguard Worker     int aclAxis = sign * rank - 1  - armnnAxis;
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker     return aclAxis;
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker 
278*89c4ff92SAndroid Build Coastguard Worker /// Function to convert axis to its positive equivalent value.
279*89c4ff92SAndroid Build Coastguard Worker /// [-rank, rank) --> [0, rank)
ComputePositiveAxis(const int & axis,const armnn::TensorInfo & tensor)280*89c4ff92SAndroid Build Coastguard Worker inline unsigned int ComputePositiveAxis(const int& axis, const armnn::TensorInfo& tensor)
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker     int rank = static_cast<int>(tensor.GetNumDimensions());
283*89c4ff92SAndroid Build Coastguard Worker 
284*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(rank != 0);
285*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT((-1 * rank) <= axis);
286*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(axis < rank);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     int positiveAxis = (axis < 0) ? rank + axis : axis;
289*89c4ff92SAndroid Build Coastguard Worker     return static_cast<unsigned int>(positiveAxis);
290*89c4ff92SAndroid Build Coastguard Worker }
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::Conv3dInfo object from convolution3d descriptor.
ComputeConv3DInfo(const armnn::Convolution3dDescriptor descriptor,bool isFastMathEnabled,const ActivationDescriptor * activationDescriptor)293*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::Conv3dInfo ComputeConv3DInfo(const armnn::Convolution3dDescriptor descriptor,
294*89c4ff92SAndroid Build Coastguard Worker                                                  bool isFastMathEnabled,
295*89c4ff92SAndroid Build Coastguard Worker                                                  const ActivationDescriptor* activationDescriptor)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D    stride{descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_StrideZ};
298*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Padding3D padding{descriptor.m_PadLeft, descriptor.m_PadRight,
299*89c4ff92SAndroid Build Coastguard Worker                                          descriptor.m_PadTop, descriptor.m_PadBottom,
300*89c4ff92SAndroid Build Coastguard Worker                                          descriptor.m_PadFront, descriptor.m_PadBack};
301*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D    dilation{descriptor.m_DilationX, descriptor.m_DilationY, descriptor.m_DilationZ};
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::ActivationLayerInfo activationInfo =
304*89c4ff92SAndroid Build Coastguard Worker             ConvertActivationDescriptorToAclActivationLayerInfo(activationDescriptor);
305*89c4ff92SAndroid Build Coastguard Worker     const auto roundType = arm_compute::DimensionRoundingType::FLOOR;
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::Conv3dInfo{stride, padding, activationInfo, dilation, roundType, isFastMathEnabled};
308*89c4ff92SAndroid Build Coastguard Worker }
309*89c4ff92SAndroid Build Coastguard Worker 
ComputeConv3DInfo(const armnn::Convolution3dQueueDescriptor queueDescriptor,bool isFastMathEnabled)310*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::Conv3dInfo ComputeConv3DInfo(const armnn::Convolution3dQueueDescriptor queueDescriptor,
311*89c4ff92SAndroid Build Coastguard Worker                                                  bool isFastMathEnabled)
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker     auto descriptor = queueDescriptor.m_Parameters;
314*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D    stride{descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_StrideZ};
315*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Padding3D padding{descriptor.m_PadLeft, descriptor.m_PadRight,
316*89c4ff92SAndroid Build Coastguard Worker                                          descriptor.m_PadTop, descriptor.m_PadBottom,
317*89c4ff92SAndroid Build Coastguard Worker                                          descriptor.m_PadFront, descriptor.m_PadBack};
318*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::Size3D    dilation{descriptor.m_DilationX, descriptor.m_DilationY, descriptor.m_DilationZ};
319*89c4ff92SAndroid Build Coastguard Worker 
320*89c4ff92SAndroid Build Coastguard Worker     const arm_compute::ActivationLayerInfo activationInfo =
321*89c4ff92SAndroid Build Coastguard Worker             ConvertAdditionalInfoToAclActivationLayerInfo(queueDescriptor);
322*89c4ff92SAndroid Build Coastguard Worker     const auto roundType = arm_compute::DimensionRoundingType::FLOOR;
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker     return arm_compute::Conv3dInfo{stride, padding, activationInfo, dilation, roundType, isFastMathEnabled};
325*89c4ff92SAndroid Build Coastguard Worker }
326*89c4ff92SAndroid Build Coastguard Worker 
ConvertPaddingModeToAcl(const PaddingMode & paddingMode)327*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::PaddingMode ConvertPaddingModeToAcl(const PaddingMode& paddingMode)
328*89c4ff92SAndroid Build Coastguard Worker {
329*89c4ff92SAndroid Build Coastguard Worker     switch (paddingMode)
330*89c4ff92SAndroid Build Coastguard Worker     {
331*89c4ff92SAndroid Build Coastguard Worker         case PaddingMode::Constant:   return arm_compute::PaddingMode::CONSTANT;
332*89c4ff92SAndroid Build Coastguard Worker         case PaddingMode::Reflect:    return arm_compute::PaddingMode::REFLECT;
333*89c4ff92SAndroid Build Coastguard Worker         case PaddingMode::Symmetric:  return arm_compute::PaddingMode::SYMMETRIC;
334*89c4ff92SAndroid Build Coastguard Worker         default:                      throw InvalidArgumentException("Unsupported Padding Mode");
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker 
ConvertReductionOperationToAcl(const ReduceDescriptor & descriptor)338*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ReductionOperation ConvertReductionOperationToAcl(const ReduceDescriptor& descriptor)
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker     switch (descriptor.m_ReduceOperation)
341*89c4ff92SAndroid Build Coastguard Worker     {
342*89c4ff92SAndroid Build Coastguard Worker         case ReduceOperation::Sum:    return arm_compute::ReductionOperation::SUM;
343*89c4ff92SAndroid Build Coastguard Worker         case ReduceOperation::Mean:   return arm_compute::ReductionOperation::MEAN_SUM;
344*89c4ff92SAndroid Build Coastguard Worker         case ReduceOperation::Max:    return arm_compute::ReductionOperation::MAX;
345*89c4ff92SAndroid Build Coastguard Worker         case ReduceOperation::Min:    return arm_compute::ReductionOperation::MIN;
346*89c4ff92SAndroid Build Coastguard Worker         case ReduceOperation::Prod:   return arm_compute::ReductionOperation::PROD;
347*89c4ff92SAndroid Build Coastguard Worker         default:                      throw InvalidArgumentException("Unsupported Reduction operation");
348*89c4ff92SAndroid Build Coastguard Worker     }
349*89c4ff92SAndroid Build Coastguard Worker }
350*89c4ff92SAndroid Build Coastguard Worker 
351*89c4ff92SAndroid Build Coastguard Worker /// Function to compute the output tensor shape based on the axes and if keepDims is set.
ComputeReductionTensorShape(const armnn::TensorInfo & input,const std::vector<uint32_t> & vAxis,const bool keepDims)352*89c4ff92SAndroid Build Coastguard Worker inline const TensorInfo ComputeReductionTensorShape(const armnn::TensorInfo& input,
353*89c4ff92SAndroid Build Coastguard Worker                                                     const std::vector<uint32_t>& vAxis,
354*89c4ff92SAndroid Build Coastguard Worker                                                     const bool keepDims)
355*89c4ff92SAndroid Build Coastguard Worker {
356*89c4ff92SAndroid Build Coastguard Worker     auto reducedTensorInfo = input;
357*89c4ff92SAndroid Build Coastguard Worker     unsigned int rank = reducedTensorInfo.GetNumDimensions();
358*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputRank = 0;
359*89c4ff92SAndroid Build Coastguard Worker     // Calculate output dimension
360*89c4ff92SAndroid Build Coastguard Worker     if (keepDims)
361*89c4ff92SAndroid Build Coastguard Worker     {
362*89c4ff92SAndroid Build Coastguard Worker         outputRank = rank;
363*89c4ff92SAndroid Build Coastguard Worker     }
364*89c4ff92SAndroid Build Coastguard Worker     else if (vAxis.empty())
365*89c4ff92SAndroid Build Coastguard Worker     {
366*89c4ff92SAndroid Build Coastguard Worker         outputRank = 1;
367*89c4ff92SAndroid Build Coastguard Worker     }
368*89c4ff92SAndroid Build Coastguard Worker     else if (vAxis.size() > reducedTensorInfo.GetNumDimensions())
369*89c4ff92SAndroid Build Coastguard Worker     {
370*89c4ff92SAndroid Build Coastguard Worker         throw LayerValidationException("ReduceLayer: Dimensions to reduce can not be bigger than input dimensions");
371*89c4ff92SAndroid Build Coastguard Worker     }
372*89c4ff92SAndroid Build Coastguard Worker     else
373*89c4ff92SAndroid Build Coastguard Worker     {
374*89c4ff92SAndroid Build Coastguard Worker         outputRank = reducedTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(vAxis.size());
375*89c4ff92SAndroid Build Coastguard Worker         if (outputRank == 0)
376*89c4ff92SAndroid Build Coastguard Worker         {
377*89c4ff92SAndroid Build Coastguard Worker             outputRank = 1;
378*89c4ff92SAndroid Build Coastguard Worker         }
379*89c4ff92SAndroid Build Coastguard Worker     }
380*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> dimSizes(outputRank, 1);
381*89c4ff92SAndroid Build Coastguard Worker     if (!vAxis.empty())
382*89c4ff92SAndroid Build Coastguard Worker     {
383*89c4ff92SAndroid Build Coastguard Worker         // Skip the dimension that has been reduced unless keepDims is true.
384*89c4ff92SAndroid Build Coastguard Worker         unsigned int outputIndex = 0;
385*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < reducedTensorInfo.GetNumDimensions(); ++i)
386*89c4ff92SAndroid Build Coastguard Worker         {
387*89c4ff92SAndroid Build Coastguard Worker             if (std::find(vAxis.begin(), vAxis.end(), i) == vAxis.end())
388*89c4ff92SAndroid Build Coastguard Worker             {
389*89c4ff92SAndroid Build Coastguard Worker                 dimSizes[outputIndex] = armnn::numeric_cast<unsigned int>(reducedTensorInfo.GetShape()[i]);
390*89c4ff92SAndroid Build Coastguard Worker                 ++outputIndex;
391*89c4ff92SAndroid Build Coastguard Worker             }
392*89c4ff92SAndroid Build Coastguard Worker             else if (keepDims)
393*89c4ff92SAndroid Build Coastguard Worker             {
394*89c4ff92SAndroid Build Coastguard Worker                 dimSizes[outputIndex] = 1;
395*89c4ff92SAndroid Build Coastguard Worker                 ++outputIndex;
396*89c4ff92SAndroid Build Coastguard Worker             }
397*89c4ff92SAndroid Build Coastguard Worker         }
398*89c4ff92SAndroid Build Coastguard Worker     }
399*89c4ff92SAndroid Build Coastguard Worker     const TensorShape inferredShape = TensorShape(outputRank, dimSizes.data());
400*89c4ff92SAndroid Build Coastguard Worker     reducedTensorInfo.SetShape(inferredShape);
401*89c4ff92SAndroid Build Coastguard Worker     return reducedTensorInfo;
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker 
/// Macro function check if layer with multiple axes is supported on each backend.
///
/// The multi-axis reduction is validated as a chain of single-axis reductions, one per entry of
/// desc.m_vAxis, taken in the order given.
///
/// @param func   Validate function, called as func(stepInputInfo, stepOutputInfo, singleAxisDescriptor).
///               Its result is assigned to `status` and tested with `!`.
/// @param input  armnn::TensorInfo of the original input tensor.
/// @param desc   armnn::ReduceDescriptor containing the axes to reduce.
/// @param status Caller-declared variable receiving the result of the last validated step. The loop
///               stops at the first failing step. It is left untouched when desc.m_vAxis is empty.
///
/// The expansion is wrapped in do { ... } while (false), so its helper variables neither leak into
/// nor collide with the caller's scope. Use it as a statement followed by ';'.
///
/// NOTE(review): when m_KeepDims is false, each later axis is shifted down by the number of axes
/// already reduced. That is only correct if desc.m_vAxis is sorted in ascending order.
#define IS_MULTI_AXES_REDUCE_SUPPORTED(func, input, desc, status)                                   \
    do                                                                                              \
    {                                                                                               \
        /* Input of the current single-axis step; starts as the original input. */                  \
        armnn::TensorInfo reduceStepInput = (input);                                                \
        /* Dimensions removed by earlier steps (stays 0 when m_KeepDims is set). */                 \
        unsigned int reduceRemovedDims = 0;                                                         \
        /* Axes reduced so far, indexed against the ORIGINAL input. */                              \
        std::vector<uint32_t> reduceAxesSoFar;                                                      \
                                                                                                    \
        for (unsigned int reduceIdx = 0; reduceIdx != (desc).m_vAxis.size(); ++reduceIdx)           \
        {                                                                                           \
            const uint32_t reduceAxis = (desc).m_vAxis[reduceIdx];                                  \
            reduceAxesSoFar.emplace_back(reduceAxis);                                               \
                                                                                                    \
            /* Expected output of this step: the original input reduced over every axis so far. */ \
            const armnn::TensorInfo reduceStepOutput =                                              \
                armnn::ComputeReductionTensorShape((input), reduceAxesSoFar, (desc).m_KeepDims);    \
                                                                                                    \
            /* Single-axis descriptor, axis shifted to index the partially reduced input. */        \
            armnn::ReduceDescriptor reduceStepDesc = (desc);                                        \
            reduceStepDesc.m_vAxis.assign(1, reduceAxis - reduceRemovedDims);                       \
                                                                                                    \
            (status) = func(reduceStepInput, reduceStepOutput, reduceStepDesc);                     \
            if (!(status))                                                                          \
            {                                                                                       \
                break;                                                                              \
            }                                                                                       \
                                                                                                    \
            if (!(desc).m_KeepDims)                                                                 \
            {                                                                                       \
                ++reduceRemovedDims;                                                                \
            }                                                                                       \
                                                                                                    \
            reduceStepInput = reduceStepOutput;                                                     \
        }                                                                                           \
    } while (false)
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
437