//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Types.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/FunctionDescriptors.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
17*89c4ff92SAndroid Build Coastguard Worker #include "neon/workloads/NeonReduceWorkload.hpp"
18*89c4ff92SAndroid Build Coastguard Worker #endif
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
21*89c4ff92SAndroid Build Coastguard Worker #include "cl/workloads/ClReduceWorkload.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #endif
23*89c4ff92SAndroid Build Coastguard Worker
24*89c4ff92SAndroid Build Coastguard Worker namespace armnn
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::NormalizationLayerInfo
CreateAclNormalizationLayerInfoForL2Normalization(const armnn::TensorInfo & tensorInfo,armnn::DataLayout dataLayout)28*89c4ff92SAndroid Build Coastguard Worker CreateAclNormalizationLayerInfoForL2Normalization(const armnn::TensorInfo& tensorInfo,
29*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout dataLayout)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker unsigned int depthDimension = dataLayout == armnn::DataLayout::NCHW ? 1 : 3;
32*89c4ff92SAndroid Build Coastguard Worker const unsigned int depth = tensorInfo.GetShape()[depthDimension];
33*89c4ff92SAndroid Build Coastguard Worker
34*89c4ff92SAndroid Build Coastguard Worker // At the time of writing, {CL|Neon}L2Normalization performs the reduction only along dimension 0. This version of
35*89c4ff92SAndroid Build Coastguard Worker // L2 Normalization always performs the reduction along the depth axis, though. Thus, we repurpose
36*89c4ff92SAndroid Build Coastguard Worker // {CL|Neon}NormalizationLayers to act as depthwise L2 normalizations by carefully chosing the normalization
37*89c4ff92SAndroid Build Coastguard Worker // parameters.
38*89c4ff92SAndroid Build Coastguard Worker //
39*89c4ff92SAndroid Build Coastguard Worker // Please refer to both the reference implementation of the normalization layer and the implementation of
40*89c4ff92SAndroid Build Coastguard Worker // {CL|Neon}NormalizationLayer when checking the derivations for the parameter values below.
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker // Make sure normalization covers the entire depth range. ACL requires the normalization size to be odd.
43*89c4ff92SAndroid Build Coastguard Worker // CL: This does not result in extra kernel threads not doing any work: See usage of the RADIUS parameter in
44*89c4ff92SAndroid Build Coastguard Worker // ACL's normalization_layer_cross_map() CL function.
45*89c4ff92SAndroid Build Coastguard Worker const uint32_t normSize = depth * 2u + 1u;
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker // See ACL's NormalizationLayerInfo::scale_coeff() definition.
48*89c4ff92SAndroid Build Coastguard Worker // For the reference implementation, to make alpha_ become 1, we'd have to use alpha = normSize instead.
49*89c4ff92SAndroid Build Coastguard Worker const float alpha = 1.0f;
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker // Don't offset the reduction.
52*89c4ff92SAndroid Build Coastguard Worker const float kappa = 0.0f;
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker // pow(reduction, -0.5) = 1 / sqrt(reduction)
55*89c4ff92SAndroid Build Coastguard Worker const float beta = 0.5f;
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker return arm_compute::NormalizationLayerInfo(arm_compute::NormType::CROSS_MAP, normSize, alpha, beta, kappa, false);
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo::ActivationFunction
ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunction)61*89c4ff92SAndroid Build Coastguard Worker ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunction)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker using AclActivationFunction = arm_compute::ActivationLayerInfo::ActivationFunction;
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker switch (armnnFunction)
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Linear: return AclActivationFunction::LINEAR;
68*89c4ff92SAndroid Build Coastguard Worker // Arm compute's 'logistic' function is non-parameterized, so it is exactly a sigmoid function.
69*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Sigmoid: return AclActivationFunction::LOGISTIC;
70*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::ReLu: return AclActivationFunction::RELU;
71*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::BoundedReLu: return AclActivationFunction::LU_BOUNDED_RELU;
72*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::SoftReLu: return AclActivationFunction::SOFT_RELU;
73*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::LeakyReLu: return AclActivationFunction::LEAKY_RELU;
74*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Abs: return AclActivationFunction::ABS;
75*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Sqrt: return AclActivationFunction::SQRT;
76*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Square: return AclActivationFunction::SQUARE;
77*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::TanH: return AclActivationFunction::TANH;
78*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::Elu: return AclActivationFunction::ELU;
79*89c4ff92SAndroid Build Coastguard Worker case ActivationFunction::HardSwish: return AclActivationFunction::HARD_SWISH;
80*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported activation function");
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor & actDesc)85*89c4ff92SAndroid Build Coastguard Worker ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor& actDesc)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(ConvertActivationFunctionToAclActivationFunction(actDesc.m_Function),
88*89c4ff92SAndroid Build Coastguard Worker actDesc.m_A, actDesc.m_B);
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor * activationDescPtr)92*89c4ff92SAndroid Build Coastguard Worker ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor* activationDescPtr)
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker if (activationDescPtr != nullptr)
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker return ConvertActivationDescriptorToAclActivationLayerInfo(static_cast<ActivationDescriptor>(
97*89c4ff92SAndroid Build Coastguard Worker *activationDescPtr));
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo();
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertAdditionalInfoToAclActivationLayerInfo(const QueueDescriptor & queueDescriptor)103*89c4ff92SAndroid Build Coastguard Worker ConvertAdditionalInfoToAclActivationLayerInfo(const QueueDescriptor& queueDescriptor)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* activationDescPtr = queueDescriptor.GetAdditionalInformation<ActivationDescriptor>();
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker if (activationDescPtr != nullptr)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker return ConvertActivationDescriptorToAclActivationLayerInfo(static_cast<ActivationDescriptor>(
110*89c4ff92SAndroid Build Coastguard Worker *activationDescPtr));
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo();
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ActivationLayerInfo
ConvertLstmActivationFuncToAclLayerInfo(uint32_t activationFunction)116*89c4ff92SAndroid Build Coastguard Worker ConvertLstmActivationFuncToAclLayerInfo(uint32_t activationFunction)
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker // For preparing the object for the class ActivationLayerInfo, we need to consider 5 situations.
119*89c4ff92SAndroid Build Coastguard Worker switch (activationFunction)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker case 0:
122*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(); // no activation, do nothing
123*89c4ff92SAndroid Build Coastguard Worker case 1:
124*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
125*89c4ff92SAndroid Build Coastguard Worker case 3:
126*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(
127*89c4ff92SAndroid Build Coastguard Worker arm_compute::ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.0);
128*89c4ff92SAndroid Build Coastguard Worker case 4:
129*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(
130*89c4ff92SAndroid Build Coastguard Worker arm_compute::ActivationLayerInfo::ActivationFunction::TANH, 1.0, 1.0);
131*89c4ff92SAndroid Build Coastguard Worker case 6:
132*89c4ff92SAndroid Build Coastguard Worker return arm_compute::ActivationLayerInfo(
133*89c4ff92SAndroid Build Coastguard Worker arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC);
134*89c4ff92SAndroid Build Coastguard Worker default:
135*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Wrong Type of Activation Function!");
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker }
138*89c4ff92SAndroid Build Coastguard Worker
ConvertComparisonOperationToAcl(const ComparisonDescriptor & descriptor)139*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ComparisonOperation ConvertComparisonOperationToAcl(const ComparisonDescriptor& descriptor)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker switch (descriptor.m_Operation)
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::Greater: return arm_compute::ComparisonOperation::Greater;
144*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::GreaterOrEqual: return arm_compute::ComparisonOperation::GreaterEqual;
145*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::Less: return arm_compute::ComparisonOperation::Less;
146*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::LessOrEqual: return arm_compute::ComparisonOperation::LessEqual;
147*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::Equal: return arm_compute::ComparisonOperation::Equal;
148*89c4ff92SAndroid Build Coastguard Worker case ComparisonOperation::NotEqual: return arm_compute::ComparisonOperation::NotEqual;
149*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported comparison function");
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker
ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm)153*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::PoolingType ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker using arm_compute::PoolingType;
156*89c4ff92SAndroid Build Coastguard Worker
157*89c4ff92SAndroid Build Coastguard Worker switch (poolingAlgorithm)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm::Max: return PoolingType::MAX;
160*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm::Average: return PoolingType::AVG;
161*89c4ff92SAndroid Build Coastguard Worker case PoolingAlgorithm::L2: return PoolingType::L2;
162*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported pooling algorithm");
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker
ConvertOutputShapeRoundingToAclDimensionRoundingType(OutputShapeRounding rounding)166*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::DimensionRoundingType ConvertOutputShapeRoundingToAclDimensionRoundingType(OutputShapeRounding
167*89c4ff92SAndroid Build Coastguard Worker rounding)
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker using arm_compute::DimensionRoundingType;
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker switch (rounding)
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding::Ceiling: return DimensionRoundingType::CEIL;
174*89c4ff92SAndroid Build Coastguard Worker case OutputShapeRounding::Floor: return DimensionRoundingType::FLOOR;
175*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported Output Shape Rounding type");
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker }
178*89c4ff92SAndroid Build Coastguard Worker
179*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::NormType
ConvertNormalizationAlgorithmChannelToAclNormType(NormalizationAlgorithmChannel channelType)180*89c4ff92SAndroid Build Coastguard Worker ConvertNormalizationAlgorithmChannelToAclNormType(NormalizationAlgorithmChannel channelType)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker using arm_compute::NormType;
183*89c4ff92SAndroid Build Coastguard Worker switch (channelType)
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmChannel::Across: return NormType::CROSS_MAP;
186*89c4ff92SAndroid Build Coastguard Worker case NormalizationAlgorithmChannel::Within: return NormType::IN_MAP_2D;
187*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported normalization algorithm channel type");
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::FullyConnectedLayerInfo
ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor & fullyConnectedDesc,const ActivationDescriptor * activationDesc)192*89c4ff92SAndroid Build Coastguard Worker ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor& fullyConnectedDesc,
193*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* activationDesc)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker arm_compute::FullyConnectedLayerInfo fc_info;
196*89c4ff92SAndroid Build Coastguard Worker fc_info.transpose_weights = fullyConnectedDesc.m_TransposeWeightMatrix;
197*89c4ff92SAndroid Build Coastguard Worker fc_info.activation_info = ConvertActivationDescriptorToAclActivationLayerInfo(activationDesc);
198*89c4ff92SAndroid Build Coastguard Worker return fc_info;
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker
201*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::FullyConnectedLayerInfo
ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor & fullyConnectedDesc,arm_compute::ActivationLayerInfo activationLayerInfo)202*89c4ff92SAndroid Build Coastguard Worker ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor& fullyConnectedDesc,
203*89c4ff92SAndroid Build Coastguard Worker arm_compute::ActivationLayerInfo activationLayerInfo)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker arm_compute::FullyConnectedLayerInfo fc_info;
206*89c4ff92SAndroid Build Coastguard Worker fc_info.transpose_weights = fullyConnectedDesc.m_TransposeWeightMatrix;
207*89c4ff92SAndroid Build Coastguard Worker fc_info.activation_info = activationLayerInfo;
208*89c4ff92SAndroid Build Coastguard Worker return fc_info;
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker
ConvertResizeMethodToAclInterpolationPolicy(ResizeMethod resizeMethod)211*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::InterpolationPolicy ConvertResizeMethodToAclInterpolationPolicy(ResizeMethod resizeMethod)
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker switch (resizeMethod)
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker case ResizeMethod::Bilinear:
216*89c4ff92SAndroid Build Coastguard Worker return arm_compute::InterpolationPolicy::BILINEAR;
217*89c4ff92SAndroid Build Coastguard Worker case ResizeMethod::NearestNeighbor:
218*89c4ff92SAndroid Build Coastguard Worker return arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR;
219*89c4ff92SAndroid Build Coastguard Worker default:
220*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Unsupported resize method");
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker template<typename T>
ComputeSoftmaxAclAxis(const SoftmaxDescriptor & softmaxDesc,const armnn::TensorInfo & tensor)225*89c4ff92SAndroid Build Coastguard Worker inline T ComputeSoftmaxAclAxis(const SoftmaxDescriptor& softmaxDesc, const armnn::TensorInfo& tensor)
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker // Detect the Android default value of -1 and return the ACL default value of 0.
228*89c4ff92SAndroid Build Coastguard Worker if (softmaxDesc.m_Axis == -1)
229*89c4ff92SAndroid Build Coastguard Worker {
230*89c4ff92SAndroid Build Coastguard Worker return 0;
231*89c4ff92SAndroid Build Coastguard Worker }
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker unsigned int dim = tensor.GetNumDimensions();
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(dim != 0);
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker // Currently ArmNN support axis 1.
238*89c4ff92SAndroid Build Coastguard Worker auto aclAxis = (static_cast<T>(dim) - 1);
239*89c4ff92SAndroid Build Coastguard Worker aclAxis = aclAxis > 0 ? aclAxis -1 : aclAxis;
240*89c4ff92SAndroid Build Coastguard Worker
241*89c4ff92SAndroid Build Coastguard Worker return aclAxis;
242*89c4ff92SAndroid Build Coastguard Worker }
243*89c4ff92SAndroid Build Coastguard Worker
ComputeSplitAxis(const armnn::SplitterDescriptor & desc,const TensorShape & input)244*89c4ff92SAndroid Build Coastguard Worker inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplit = desc.GetNumViews();
247*89c4ff92SAndroid Build Coastguard Worker unsigned int numDimensions = desc.GetNumDimensions();
248*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> splitAxis;
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numSplit; ++i)
251*89c4ff92SAndroid Build Coastguard Worker {
252*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
253*89c4ff92SAndroid Build Coastguard Worker {
254*89c4ff92SAndroid Build Coastguard Worker if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker splitAxis.insert(dimIdx);
257*89c4ff92SAndroid Build Coastguard Worker }
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker return splitAxis;
261*89c4ff92SAndroid Build Coastguard Worker }
262*89c4ff92SAndroid Build Coastguard Worker
263*89c4ff92SAndroid Build Coastguard Worker /// Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank, rank)
ComputeAclAxis(const int & armnnAxis,const armnn::TensorInfo & tensor)264*89c4ff92SAndroid Build Coastguard Worker inline int ComputeAclAxis(const int& armnnAxis, const armnn::TensorInfo& tensor)
265*89c4ff92SAndroid Build Coastguard Worker {
266*89c4ff92SAndroid Build Coastguard Worker int rank = static_cast<int>(tensor.GetNumDimensions());
267*89c4ff92SAndroid Build Coastguard Worker
268*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(rank != 0);
269*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT((-1 * rank) <= armnnAxis);
270*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(armnnAxis < rank);
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker int sign = (armnnAxis < 0) ? -1 : 1;
273*89c4ff92SAndroid Build Coastguard Worker int aclAxis = sign * rank - 1 - armnnAxis;
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker return aclAxis;
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker
278*89c4ff92SAndroid Build Coastguard Worker /// Function to convert axis to its positive equivalent value.
279*89c4ff92SAndroid Build Coastguard Worker /// [-rank, rank) --> [0, rank)
ComputePositiveAxis(const int & axis,const armnn::TensorInfo & tensor)280*89c4ff92SAndroid Build Coastguard Worker inline unsigned int ComputePositiveAxis(const int& axis, const armnn::TensorInfo& tensor)
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker int rank = static_cast<int>(tensor.GetNumDimensions());
283*89c4ff92SAndroid Build Coastguard Worker
284*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(rank != 0);
285*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT((-1 * rank) <= axis);
286*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(axis < rank);
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker int positiveAxis = (axis < 0) ? rank + axis : axis;
289*89c4ff92SAndroid Build Coastguard Worker return static_cast<unsigned int>(positiveAxis);
290*89c4ff92SAndroid Build Coastguard Worker }
291*89c4ff92SAndroid Build Coastguard Worker
292*89c4ff92SAndroid Build Coastguard Worker /// Utility function used to setup an arm_compute::Conv3dInfo object from convolution3d descriptor.
ComputeConv3DInfo(const armnn::Convolution3dDescriptor descriptor,bool isFastMathEnabled,const ActivationDescriptor * activationDescriptor)293*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::Conv3dInfo ComputeConv3DInfo(const armnn::Convolution3dDescriptor descriptor,
294*89c4ff92SAndroid Build Coastguard Worker bool isFastMathEnabled,
295*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor* activationDescriptor)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D stride{descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_StrideZ};
298*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Padding3D padding{descriptor.m_PadLeft, descriptor.m_PadRight,
299*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom,
300*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront, descriptor.m_PadBack};
301*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D dilation{descriptor.m_DilationX, descriptor.m_DilationY, descriptor.m_DilationZ};
302*89c4ff92SAndroid Build Coastguard Worker
303*89c4ff92SAndroid Build Coastguard Worker const arm_compute::ActivationLayerInfo activationInfo =
304*89c4ff92SAndroid Build Coastguard Worker ConvertActivationDescriptorToAclActivationLayerInfo(activationDescriptor);
305*89c4ff92SAndroid Build Coastguard Worker const auto roundType = arm_compute::DimensionRoundingType::FLOOR;
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker return arm_compute::Conv3dInfo{stride, padding, activationInfo, dilation, roundType, isFastMathEnabled};
308*89c4ff92SAndroid Build Coastguard Worker }
309*89c4ff92SAndroid Build Coastguard Worker
ComputeConv3DInfo(const armnn::Convolution3dQueueDescriptor queueDescriptor,bool isFastMathEnabled)310*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::Conv3dInfo ComputeConv3DInfo(const armnn::Convolution3dQueueDescriptor queueDescriptor,
311*89c4ff92SAndroid Build Coastguard Worker bool isFastMathEnabled)
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker auto descriptor = queueDescriptor.m_Parameters;
314*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D stride{descriptor.m_StrideX, descriptor.m_StrideY, descriptor.m_StrideZ};
315*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Padding3D padding{descriptor.m_PadLeft, descriptor.m_PadRight,
316*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom,
317*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront, descriptor.m_PadBack};
318*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Size3D dilation{descriptor.m_DilationX, descriptor.m_DilationY, descriptor.m_DilationZ};
319*89c4ff92SAndroid Build Coastguard Worker
320*89c4ff92SAndroid Build Coastguard Worker const arm_compute::ActivationLayerInfo activationInfo =
321*89c4ff92SAndroid Build Coastguard Worker ConvertAdditionalInfoToAclActivationLayerInfo(queueDescriptor);
322*89c4ff92SAndroid Build Coastguard Worker const auto roundType = arm_compute::DimensionRoundingType::FLOOR;
323*89c4ff92SAndroid Build Coastguard Worker
324*89c4ff92SAndroid Build Coastguard Worker return arm_compute::Conv3dInfo{stride, padding, activationInfo, dilation, roundType, isFastMathEnabled};
325*89c4ff92SAndroid Build Coastguard Worker }
326*89c4ff92SAndroid Build Coastguard Worker
ConvertPaddingModeToAcl(const PaddingMode & paddingMode)327*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::PaddingMode ConvertPaddingModeToAcl(const PaddingMode& paddingMode)
328*89c4ff92SAndroid Build Coastguard Worker {
329*89c4ff92SAndroid Build Coastguard Worker switch (paddingMode)
330*89c4ff92SAndroid Build Coastguard Worker {
331*89c4ff92SAndroid Build Coastguard Worker case PaddingMode::Constant: return arm_compute::PaddingMode::CONSTANT;
332*89c4ff92SAndroid Build Coastguard Worker case PaddingMode::Reflect: return arm_compute::PaddingMode::REFLECT;
333*89c4ff92SAndroid Build Coastguard Worker case PaddingMode::Symmetric: return arm_compute::PaddingMode::SYMMETRIC;
334*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported Padding Mode");
335*89c4ff92SAndroid Build Coastguard Worker }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker
ConvertReductionOperationToAcl(const ReduceDescriptor & descriptor)338*89c4ff92SAndroid Build Coastguard Worker inline arm_compute::ReductionOperation ConvertReductionOperationToAcl(const ReduceDescriptor& descriptor)
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker switch (descriptor.m_ReduceOperation)
341*89c4ff92SAndroid Build Coastguard Worker {
342*89c4ff92SAndroid Build Coastguard Worker case ReduceOperation::Sum: return arm_compute::ReductionOperation::SUM;
343*89c4ff92SAndroid Build Coastguard Worker case ReduceOperation::Mean: return arm_compute::ReductionOperation::MEAN_SUM;
344*89c4ff92SAndroid Build Coastguard Worker case ReduceOperation::Max: return arm_compute::ReductionOperation::MAX;
345*89c4ff92SAndroid Build Coastguard Worker case ReduceOperation::Min: return arm_compute::ReductionOperation::MIN;
346*89c4ff92SAndroid Build Coastguard Worker case ReduceOperation::Prod: return arm_compute::ReductionOperation::PROD;
347*89c4ff92SAndroid Build Coastguard Worker default: throw InvalidArgumentException("Unsupported Reduction operation");
348*89c4ff92SAndroid Build Coastguard Worker }
349*89c4ff92SAndroid Build Coastguard Worker }
350*89c4ff92SAndroid Build Coastguard Worker
351*89c4ff92SAndroid Build Coastguard Worker /// Function to compute the output tensor shape based on the axes and if keepDims is set.
ComputeReductionTensorShape(const armnn::TensorInfo & input,const std::vector<uint32_t> & vAxis,const bool keepDims)352*89c4ff92SAndroid Build Coastguard Worker inline const TensorInfo ComputeReductionTensorShape(const armnn::TensorInfo& input,
353*89c4ff92SAndroid Build Coastguard Worker const std::vector<uint32_t>& vAxis,
354*89c4ff92SAndroid Build Coastguard Worker const bool keepDims)
355*89c4ff92SAndroid Build Coastguard Worker {
356*89c4ff92SAndroid Build Coastguard Worker auto reducedTensorInfo = input;
357*89c4ff92SAndroid Build Coastguard Worker unsigned int rank = reducedTensorInfo.GetNumDimensions();
358*89c4ff92SAndroid Build Coastguard Worker unsigned int outputRank = 0;
359*89c4ff92SAndroid Build Coastguard Worker // Calculate output dimension
360*89c4ff92SAndroid Build Coastguard Worker if (keepDims)
361*89c4ff92SAndroid Build Coastguard Worker {
362*89c4ff92SAndroid Build Coastguard Worker outputRank = rank;
363*89c4ff92SAndroid Build Coastguard Worker }
364*89c4ff92SAndroid Build Coastguard Worker else if (vAxis.empty())
365*89c4ff92SAndroid Build Coastguard Worker {
366*89c4ff92SAndroid Build Coastguard Worker outputRank = 1;
367*89c4ff92SAndroid Build Coastguard Worker }
368*89c4ff92SAndroid Build Coastguard Worker else if (vAxis.size() > reducedTensorInfo.GetNumDimensions())
369*89c4ff92SAndroid Build Coastguard Worker {
370*89c4ff92SAndroid Build Coastguard Worker throw LayerValidationException("ReduceLayer: Dimensions to reduce can not be bigger than input dimensions");
371*89c4ff92SAndroid Build Coastguard Worker }
372*89c4ff92SAndroid Build Coastguard Worker else
373*89c4ff92SAndroid Build Coastguard Worker {
374*89c4ff92SAndroid Build Coastguard Worker outputRank = reducedTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(vAxis.size());
375*89c4ff92SAndroid Build Coastguard Worker if (outputRank == 0)
376*89c4ff92SAndroid Build Coastguard Worker {
377*89c4ff92SAndroid Build Coastguard Worker outputRank = 1;
378*89c4ff92SAndroid Build Coastguard Worker }
379*89c4ff92SAndroid Build Coastguard Worker }
380*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> dimSizes(outputRank, 1);
381*89c4ff92SAndroid Build Coastguard Worker if (!vAxis.empty())
382*89c4ff92SAndroid Build Coastguard Worker {
383*89c4ff92SAndroid Build Coastguard Worker // Skip the dimension that has been reduced unless keepDims is true.
384*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = 0;
385*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < reducedTensorInfo.GetNumDimensions(); ++i)
386*89c4ff92SAndroid Build Coastguard Worker {
387*89c4ff92SAndroid Build Coastguard Worker if (std::find(vAxis.begin(), vAxis.end(), i) == vAxis.end())
388*89c4ff92SAndroid Build Coastguard Worker {
389*89c4ff92SAndroid Build Coastguard Worker dimSizes[outputIndex] = armnn::numeric_cast<unsigned int>(reducedTensorInfo.GetShape()[i]);
390*89c4ff92SAndroid Build Coastguard Worker ++outputIndex;
391*89c4ff92SAndroid Build Coastguard Worker }
392*89c4ff92SAndroid Build Coastguard Worker else if (keepDims)
393*89c4ff92SAndroid Build Coastguard Worker {
394*89c4ff92SAndroid Build Coastguard Worker dimSizes[outputIndex] = 1;
395*89c4ff92SAndroid Build Coastguard Worker ++outputIndex;
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker }
398*89c4ff92SAndroid Build Coastguard Worker }
399*89c4ff92SAndroid Build Coastguard Worker const TensorShape inferredShape = TensorShape(outputRank, dimSizes.data());
400*89c4ff92SAndroid Build Coastguard Worker reducedTensorInfo.SetShape(inferredShape);
401*89c4ff92SAndroid Build Coastguard Worker return reducedTensorInfo;
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker
/// Macro that checks whether a Reduce layer with multiple axes is supported on a backend by validating
/// it as a chain of single-axis reductions, one per entry of desc.m_vAxis, in order.
///
/// @param func   Callable invoked as func(inputInfo, outputInfo, reduceDescriptor) for every axis; its
///               result is assigned to @p status.
/// @param input  armnn::TensorInfo of the layer input.
/// @param desc   armnn::ReduceDescriptor of the layer. Its m_vAxis and m_KeepDims members are read, and
///               the argument expression is evaluated several times.
/// @param status Lvalue that receives the result of each func call. The loop stops at the first result
///               for which !status is true, so on exit it holds either the first failure or the result
///               for the last axis. It is left untouched when desc.m_vAxis is empty.
///
/// The macro expands to a sequence of statements, not an expression. It declares inputTensorInfo,
/// recalulatedAxis and axes in the enclosing scope, so it can be expanded at most once per scope and
/// its arguments must not use those names.
///
/// When m_KeepDims is false, each axis index is shifted down by the number of axes already processed
/// to account for the removed dimensions. NOTE(review): this presumes desc.m_vAxis is in ascending
/// order - confirm against callers.
///
/// Macro parameters are parenthesised wherever they are used as expressions so that arguments such as
/// `*pDescriptor` expand correctly.
#define IS_MULTI_AXES_REDUCE_SUPPORTED(func, input, desc, status)                         \
    armnn::TensorInfo inputTensorInfo = (input);                                          \
    unsigned int recalulatedAxis = 0;                                                     \
    std::vector<uint32_t> axes;                                                           \
                                                                                          \
    for (unsigned int i = 0; i != (desc).m_vAxis.size(); ++i)                             \
    {                                                                                     \
        axes.emplace_back((desc).m_vAxis[i]);                                             \
                                                                                          \
        const armnn::TensorInfo& reducedTensorInfo =                                      \
            ComputeReductionTensorShape((input), axes, (desc).m_KeepDims);                \
                                                                                          \
        std::vector<uint32_t> singleAxis(1, (desc).m_vAxis[i] - recalulatedAxis);         \
                                                                                          \
        armnn::ReduceDescriptor newReduceDescriptor = (desc);                             \
        newReduceDescriptor.m_vAxis.assign(singleAxis.begin(), singleAxis.end());         \
                                                                                          \
        (status) = func(inputTensorInfo, reducedTensorInfo, newReduceDescriptor);         \
        if (!(status))                                                                    \
        {                                                                                 \
            break;                                                                        \
        }                                                                                 \
                                                                                          \
        if (!(desc).m_KeepDims)                                                           \
        {                                                                                 \
            recalulatedAxis++;                                                            \
        }                                                                                 \
                                                                                          \
        inputTensorInfo = reducedTensorInfo;                                              \
    }
435*89c4ff92SAndroid Build Coastguard Worker
436*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
437