xref: /aosp_15_r20/external/armnn/src/backends/neon/NeonLayerSupport.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "NeonLayerSupport.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "NeonBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "NeonBackendModelContext.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <InternalTypes.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <LayerSupportCommon.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
21*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
22*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonAbsWorkload.hpp"
24*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonAdditionWorkload.hpp"
25*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonActivationWorkload.hpp"
26*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonArgMinMaxWorkload.hpp"
27*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonBatchMatMulWorkload.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonBatchNormalizationWorkload.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonBatchToSpaceNdWorkload.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonCastWorkload.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonChannelShuffleWorkload.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonComparisonWorkload.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonConcatWorkload.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonConstantWorkload.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonConvolution2dWorkload.hpp"
36*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonConvolution3dWorkload.hpp"
37*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonDepthToSpaceWorkload.hpp"
38*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonDepthwiseConvolutionWorkload.hpp"
39*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonDequantizeWorkload.hpp"
40*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonExpWorkload.hpp"
41*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonInstanceNormalizationWorkload.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonL2NormalizationFloatWorkload.hpp"
43*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLogWorkload.hpp"
44*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLogSoftmaxWorkload.hpp"
45*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLogicalAndWorkload.hpp"
46*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLogicalNotWorkload.hpp"
47*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLogicalOrWorkload.hpp"
48*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonLstmFloatWorkload.hpp"
49*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonMaximumWorkload.hpp"
50*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonMeanWorkload.hpp"
51*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonMinimumWorkload.hpp"
52*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonMultiplicationWorkload.hpp"
53*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonDivisionWorkload.hpp"
54*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonNegWorkload.hpp"
55*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonNormalizationFloatWorkload.hpp"
56*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonFullyConnectedWorkload.hpp"
57*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonGatherWorkload.hpp"
58*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonGatherNdWorkload.hpp"
59*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonPadWorkload.hpp"
60*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonPermuteWorkload.hpp"
61*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonPooling2dWorkload.hpp"
62*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonPooling3dWorkload.hpp"
63*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonPreluWorkload.hpp"
64*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonQLstmWorkload.hpp"
65*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonQuantizeWorkload.hpp"
66*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonQuantizedLstmWorkload.hpp"
67*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonReduceWorkload.hpp"
68*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonReshapeWorkload.hpp"
69*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonResizeWorkload.hpp"
70*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonRsqrtWorkload.hpp"
71*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSinWorkload.hpp"
72*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSliceWorkload.hpp"
73*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSoftmaxWorkload.hpp"
74*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSpaceToBatchNdWorkload.hpp"
75*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSpaceToDepthWorkload.hpp"
76*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSplitterWorkload.hpp"
77*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSqrtWorkload.hpp"
78*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonStackWorkload.hpp"
79*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonStridedSliceWorkload.hpp"
80*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonSubtractionWorkload.hpp"
81*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonTransposeConvolution2dWorkload.hpp"
82*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonTransposeWorkload.hpp"
83*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonUnidirectionalSequenceLstmFloatWorkload.hpp"
84*89c4ff92SAndroid Build Coastguard Worker #include "workloads/NeonUnidirectionalSequenceLstmWorkload.hpp"
85*89c4ff92SAndroid Build Coastguard Worker #endif
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker namespace armnn
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker namespace
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker 
/// Returns a copy of @p info, optionally with its data type replaced.
///
/// @param info  Tensor description to copy.
/// @param type  When set, the data type to use instead of info's own; when empty,
///              @p info is returned unchanged.
/// @return A TensorInfo carrying info's shape, (single) quantization scale, quantization
///         offset and constness, with the data type taken from @p type when present.
// Returned by non-const value: a const-qualified by-value return type prevents callers
// from moving out of the returned temporary (C++ Core Guidelines F.49). Dropping the
// qualifier is source-compatible for every caller.
// NOTE(review): only GetQuantizationScale() is forwarded when overriding, so any per-axis
// quantization information on `info` is not carried over - confirm this is intended.
TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
{
    if (!type)
    {
        return info;
    }
    return TensorInfo(info.GetShape(),
                      type.value(),
                      info.GetQuantizationScale(),
                      info.GetQuantizationOffset(),
                      info.IsConstant());
}
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker template< typename ... Args>
IsNeonBackendSupported(Optional<std::string &> reasonIfUnsupported,Args...args)107*89c4ff92SAndroid Build Coastguard Worker bool IsNeonBackendSupported(Optional<std::string&> reasonIfUnsupported, Args... args)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(reasonIfUnsupported, (args)...);
110*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
111*89c4ff92SAndroid Build Coastguard Worker     return true;
112*89c4ff92SAndroid Build Coastguard Worker #else
113*89c4ff92SAndroid Build Coastguard Worker     SetValueChecked(reasonIfUnsupported, "The armnn library has been built without NEON support");
114*89c4ff92SAndroid Build Coastguard Worker     return false;
115*89c4ff92SAndroid Build Coastguard Worker #endif
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker template<typename FloatFunc, typename Uint8Func, typename ... Params>
IsSupportedForDataTypeNeon(Optional<std::string &> reasonIfUnsupported,DataType dataType,FloatFunc floatFuncPtr,Uint8Func uint8FuncPtr,Params &&...params)119*89c4ff92SAndroid Build Coastguard Worker bool IsSupportedForDataTypeNeon(Optional<std::string&> reasonIfUnsupported,
120*89c4ff92SAndroid Build Coastguard Worker                                 DataType dataType,
121*89c4ff92SAndroid Build Coastguard Worker                                 FloatFunc floatFuncPtr,
122*89c4ff92SAndroid Build Coastguard Worker                                 Uint8Func uint8FuncPtr,
123*89c4ff92SAndroid Build Coastguard Worker                                 Params&&... params)
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker     return IsNeonBackendSupported(reasonIfUnsupported) &&
126*89c4ff92SAndroid Build Coastguard Worker         IsSupportedForDataTypeGeneric(reasonIfUnsupported,
127*89c4ff92SAndroid Build Coastguard Worker                                          dataType,
128*89c4ff92SAndroid Build Coastguard Worker                                          floatFuncPtr,
129*89c4ff92SAndroid Build Coastguard Worker                                          floatFuncPtr,
130*89c4ff92SAndroid Build Coastguard Worker                                          uint8FuncPtr,
131*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFunc<>,
132*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFunc<>,
133*89c4ff92SAndroid Build Coastguard Worker                                          std::forward<Params>(params)...);
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker 
#if defined(ARMCOMPUTENEON_ENABLED)
/// Runs an Arm Compute Library validate function and turns its status into a bool.
///
/// @param func                 Callable returning arm_compute::Status.
/// @param reasonIfUnsupported  If present and validation fails, set to the ACL error description.
/// @param args                 Arguments perfectly forwarded to @p func.
/// @return true iff @p func reported arm_compute::ErrorCode::OK.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType& func, Optional<std::string&> reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status validateStatus = func(std::forward<Args>(args)...);
    if (validateStatus.error_code() == arm_compute::ErrorCode::OK)
    {
        return true;
    }

    if (reasonIfUnsupported)
    {
        reasonIfUnsupported.value() = validateStatus.error_description();
    }
    return false;
}

// Expands to a complete `return` statement in the calling function. With Neon compiled in,
// the ACL validate function decides; without it, IsNeonBackendSupported rejects the layer.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsNeonBackendSupported(reasonIfUnsupported, __VA_ARGS__);
#endif
155*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
156*89c4ff92SAndroid Build Coastguard Worker 
NeonLayerSupport(const IBackendInternal::IBackendSpecificModelContextPtr & modelContextPtr)157*89c4ff92SAndroid Build Coastguard Worker NeonLayerSupport::NeonLayerSupport(const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr)
158*89c4ff92SAndroid Build Coastguard Worker     : m_ModelContextPtr(modelContextPtr)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker 
NeonLayerSupport()162*89c4ff92SAndroid Build Coastguard Worker NeonLayerSupport::NeonLayerSupport()
163*89c4ff92SAndroid Build Coastguard Worker     : m_ModelContextPtr(nullptr)
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker 
IsLayerTypeSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmParamsInfo,Optional<std::string &> reasonIfUnsupported,const NeonLayerSupport & support)167*89c4ff92SAndroid Build Coastguard Worker bool IsLayerTypeSupported(const LayerType& type,
168*89c4ff92SAndroid Build Coastguard Worker                           const std::vector<TensorInfo>& infos,
169*89c4ff92SAndroid Build Coastguard Worker                           const BaseDescriptor& descriptor,
170*89c4ff92SAndroid Build Coastguard Worker                           const Optional<LstmInputParamsInfo>& lstmParamsInfo,
171*89c4ff92SAndroid Build Coastguard Worker                           const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmParamsInfo,
172*89c4ff92SAndroid Build Coastguard Worker                           Optional<std::string&> reasonIfUnsupported,
173*89c4ff92SAndroid Build Coastguard Worker                           const NeonLayerSupport& support)
174*89c4ff92SAndroid Build Coastguard Worker {
175*89c4ff92SAndroid Build Coastguard Worker     switch (type)
176*89c4ff92SAndroid Build Coastguard Worker     {
177*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Activation:
178*89c4ff92SAndroid Build Coastguard Worker             return support.IsActivationSupported(infos[0],
179*89c4ff92SAndroid Build Coastguard Worker                                                  infos[1],
180*89c4ff92SAndroid Build Coastguard Worker                                                  *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
181*89c4ff92SAndroid Build Coastguard Worker                                                  reasonIfUnsupported);
182*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Addition:
183*89c4ff92SAndroid Build Coastguard Worker             return support.IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
184*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ArgMinMax:
185*89c4ff92SAndroid Build Coastguard Worker             return support.IsArgMinMaxSupported(infos[0],
186*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
187*89c4ff92SAndroid Build Coastguard Worker                                                 *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
188*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
189*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchMatMul:
190*89c4ff92SAndroid Build Coastguard Worker             return support.IsBatchMatMulSupported(infos[0],
191*89c4ff92SAndroid Build Coastguard Worker                                                   infos[1],
192*89c4ff92SAndroid Build Coastguard Worker                                                   infos[2],
193*89c4ff92SAndroid Build Coastguard Worker                                                   *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
194*89c4ff92SAndroid Build Coastguard Worker                                                   reasonIfUnsupported);
195*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchNormalization:
196*89c4ff92SAndroid Build Coastguard Worker             return support.IsBatchNormalizationSupported(infos[0],
197*89c4ff92SAndroid Build Coastguard Worker                                                          infos[1],
198*89c4ff92SAndroid Build Coastguard Worker                                                          infos[2],
199*89c4ff92SAndroid Build Coastguard Worker                                                          infos[3],
200*89c4ff92SAndroid Build Coastguard Worker                                                          infos[4],
201*89c4ff92SAndroid Build Coastguard Worker                                                          infos[5],
202*89c4ff92SAndroid Build Coastguard Worker                                                          *(PolymorphicDowncast<const
203*89c4ff92SAndroid Build Coastguard Worker                                                              BatchNormalizationDescriptor*>(&descriptor)),
204*89c4ff92SAndroid Build Coastguard Worker                                                          reasonIfUnsupported);
205*89c4ff92SAndroid Build Coastguard Worker         case LayerType::BatchToSpaceNd:
206*89c4ff92SAndroid Build Coastguard Worker             return support.IsBatchToSpaceNdSupported(infos[0],
207*89c4ff92SAndroid Build Coastguard Worker                                                      infos[1],
208*89c4ff92SAndroid Build Coastguard Worker                                                      *(PolymorphicDowncast<const
209*89c4ff92SAndroid Build Coastguard Worker                                                         BatchToSpaceNdDescriptor*>(&descriptor)),
210*89c4ff92SAndroid Build Coastguard Worker                                                      reasonIfUnsupported);
211*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Cast:
212*89c4ff92SAndroid Build Coastguard Worker             return support.IsCastSupported(infos[0], infos[1], reasonIfUnsupported);
213*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ChannelShuffle:
214*89c4ff92SAndroid Build Coastguard Worker             return support.IsChannelShuffleSupported(infos[0],
215*89c4ff92SAndroid Build Coastguard Worker                                                      infos[1],
216*89c4ff92SAndroid Build Coastguard Worker                                                      *(PolymorphicDowncast<const
217*89c4ff92SAndroid Build Coastguard Worker                                                          ChannelShuffleDescriptor*>(&descriptor)),
218*89c4ff92SAndroid Build Coastguard Worker                                                      reasonIfUnsupported);
219*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Comparison:
220*89c4ff92SAndroid Build Coastguard Worker             return support.IsComparisonSupported(infos[0],
221*89c4ff92SAndroid Build Coastguard Worker                                                  infos[1],
222*89c4ff92SAndroid Build Coastguard Worker                                                  infos[2],
223*89c4ff92SAndroid Build Coastguard Worker                                                  *(PolymorphicDowncast<const ComparisonDescriptor*>(&descriptor)),
224*89c4ff92SAndroid Build Coastguard Worker                                                  reasonIfUnsupported);
225*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Concat:
226*89c4ff92SAndroid Build Coastguard Worker         {
227*89c4ff92SAndroid Build Coastguard Worker             std::vector<const TensorInfo*> inputInfos;
228*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < (infos.size() - 1); i++)
229*89c4ff92SAndroid Build Coastguard Worker             {
230*89c4ff92SAndroid Build Coastguard Worker                 inputInfos.push_back(&infos[i]);
231*89c4ff92SAndroid Build Coastguard Worker             }
232*89c4ff92SAndroid Build Coastguard Worker             return support.IsConcatSupported(inputInfos,
233*89c4ff92SAndroid Build Coastguard Worker                                              infos[infos.size() - 1],
234*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const OriginsDescriptor*>(&descriptor)),
235*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
236*89c4ff92SAndroid Build Coastguard Worker         }
237*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Constant:
238*89c4ff92SAndroid Build Coastguard Worker             return support.IsConstantSupported(infos[0], reasonIfUnsupported);
239*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ConvertFp16ToFp32:
240*89c4ff92SAndroid Build Coastguard Worker             return support.IsConvertFp16ToFp32Supported(infos[0], infos[1], reasonIfUnsupported);
241*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ConvertFp32ToFp16:
242*89c4ff92SAndroid Build Coastguard Worker             return support.IsConvertFp32ToFp16Supported(infos[0], infos[1], reasonIfUnsupported);
243*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Convolution2d:
244*89c4ff92SAndroid Build Coastguard Worker         {
245*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
246*89c4ff92SAndroid Build Coastguard Worker             {
247*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
248*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
249*89c4ff92SAndroid Build Coastguard Worker             }
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
252*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
253*89c4ff92SAndroid Build Coastguard Worker             {
254*89c4ff92SAndroid Build Coastguard Worker                 return support.IsConvolution2dSupported(infos[0],
255*89c4ff92SAndroid Build Coastguard Worker                                                         infos[1],
256*89c4ff92SAndroid Build Coastguard Worker                                                         desc,
257*89c4ff92SAndroid Build Coastguard Worker                                                         infos[2],
258*89c4ff92SAndroid Build Coastguard Worker                                                         EmptyOptional(),
259*89c4ff92SAndroid Build Coastguard Worker                                                         reasonIfUnsupported);
260*89c4ff92SAndroid Build Coastguard Worker             }
261*89c4ff92SAndroid Build Coastguard Worker             else
262*89c4ff92SAndroid Build Coastguard Worker             {
263*89c4ff92SAndroid Build Coastguard Worker                 return support.IsConvolution2dSupported(infos[0],
264*89c4ff92SAndroid Build Coastguard Worker                                                         infos[1],
265*89c4ff92SAndroid Build Coastguard Worker                                                         desc,
266*89c4ff92SAndroid Build Coastguard Worker                                                         infos[2],
267*89c4ff92SAndroid Build Coastguard Worker                                                         infos[3],
268*89c4ff92SAndroid Build Coastguard Worker                                                         reasonIfUnsupported);
269*89c4ff92SAndroid Build Coastguard Worker             }
270*89c4ff92SAndroid Build Coastguard Worker         }
271*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Convolution3d:
272*89c4ff92SAndroid Build Coastguard Worker         {
273*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
274*89c4ff92SAndroid Build Coastguard Worker             {
275*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of Convolution3d TensorInfos. "
276*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
277*89c4ff92SAndroid Build Coastguard Worker             }
278*89c4ff92SAndroid Build Coastguard Worker 
279*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const Convolution3dDescriptor*>(&descriptor));
280*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
281*89c4ff92SAndroid Build Coastguard Worker             {
282*89c4ff92SAndroid Build Coastguard Worker                 return support.IsConvolution3dSupported(infos[0],
283*89c4ff92SAndroid Build Coastguard Worker                                                         infos[1],
284*89c4ff92SAndroid Build Coastguard Worker                                                         desc,
285*89c4ff92SAndroid Build Coastguard Worker                                                         infos[2],
286*89c4ff92SAndroid Build Coastguard Worker                                                         EmptyOptional(),
287*89c4ff92SAndroid Build Coastguard Worker                                                         reasonIfUnsupported);
288*89c4ff92SAndroid Build Coastguard Worker             }
289*89c4ff92SAndroid Build Coastguard Worker             else
290*89c4ff92SAndroid Build Coastguard Worker             {
291*89c4ff92SAndroid Build Coastguard Worker                 return support.IsConvolution3dSupported(infos[0],
292*89c4ff92SAndroid Build Coastguard Worker                                                         infos[1],
293*89c4ff92SAndroid Build Coastguard Worker                                                         desc,
294*89c4ff92SAndroid Build Coastguard Worker                                                         infos[2],
295*89c4ff92SAndroid Build Coastguard Worker                                                         infos[3],
296*89c4ff92SAndroid Build Coastguard Worker                                                         reasonIfUnsupported);
297*89c4ff92SAndroid Build Coastguard Worker             }
298*89c4ff92SAndroid Build Coastguard Worker         }
299*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DepthToSpace:
300*89c4ff92SAndroid Build Coastguard Worker             return support.IsDepthToSpaceSupported(infos[0],
301*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
302*89c4ff92SAndroid Build Coastguard Worker                                                    *(PolymorphicDowncast<const DepthToSpaceDescriptor*>(&descriptor)),
303*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported);
304*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DepthwiseConvolution2d:
305*89c4ff92SAndroid Build Coastguard Worker         {
306*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
307*89c4ff92SAndroid Build Coastguard Worker             {
308*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of DepthwiseConvolution2d TensorInfos. "
309*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
310*89c4ff92SAndroid Build Coastguard Worker             }
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
313*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
314*89c4ff92SAndroid Build Coastguard Worker             {
315*89c4ff92SAndroid Build Coastguard Worker                 return support.IsDepthwiseConvolutionSupported(infos[0],
316*89c4ff92SAndroid Build Coastguard Worker                                                                infos[1],
317*89c4ff92SAndroid Build Coastguard Worker                                                                desc,
318*89c4ff92SAndroid Build Coastguard Worker                                                                infos[2],
319*89c4ff92SAndroid Build Coastguard Worker                                                                EmptyOptional(),
320*89c4ff92SAndroid Build Coastguard Worker                                                                reasonIfUnsupported);
321*89c4ff92SAndroid Build Coastguard Worker             }
322*89c4ff92SAndroid Build Coastguard Worker             else
323*89c4ff92SAndroid Build Coastguard Worker             {
324*89c4ff92SAndroid Build Coastguard Worker                 return support.IsDepthwiseConvolutionSupported(infos[0],
325*89c4ff92SAndroid Build Coastguard Worker                                                                infos[1],
326*89c4ff92SAndroid Build Coastguard Worker                                                                desc,
327*89c4ff92SAndroid Build Coastguard Worker                                                                infos[2],
328*89c4ff92SAndroid Build Coastguard Worker                                                                infos[3],
329*89c4ff92SAndroid Build Coastguard Worker                                                                reasonIfUnsupported);
330*89c4ff92SAndroid Build Coastguard Worker             }
331*89c4ff92SAndroid Build Coastguard Worker         }
332*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Dequantize:
333*89c4ff92SAndroid Build Coastguard Worker             return support.IsDequantizeSupported(infos[0], infos[1], reasonIfUnsupported);
334*89c4ff92SAndroid Build Coastguard Worker         case LayerType::DetectionPostProcess:
335*89c4ff92SAndroid Build Coastguard Worker         {
336*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const DetectionPostProcessDescriptor*>(&descriptor));
337*89c4ff92SAndroid Build Coastguard Worker             return support.IsDetectionPostProcessSupported(infos[0],
338*89c4ff92SAndroid Build Coastguard Worker                                                            infos[1],
339*89c4ff92SAndroid Build Coastguard Worker                                                            infos[2],
340*89c4ff92SAndroid Build Coastguard Worker                                                            infos[3],
341*89c4ff92SAndroid Build Coastguard Worker                                                            infos[4],
342*89c4ff92SAndroid Build Coastguard Worker                                                            infos[5],
343*89c4ff92SAndroid Build Coastguard Worker                                                            infos[6],
344*89c4ff92SAndroid Build Coastguard Worker                                                            desc,
345*89c4ff92SAndroid Build Coastguard Worker                                                            reasonIfUnsupported);
346*89c4ff92SAndroid Build Coastguard Worker         }
347*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Division:
348*89c4ff92SAndroid Build Coastguard Worker             return support.IsDivisionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
349*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ElementwiseBinary:
350*89c4ff92SAndroid Build Coastguard Worker         {
351*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const ElementwiseBinaryDescriptor *>(&descriptor));
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker             switch (desc.m_Operation)
354*89c4ff92SAndroid Build Coastguard Worker             {
355*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Add:
356*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonAdditionWorkloadValidate,
357*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
358*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
359*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
360*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2],
361*89c4ff92SAndroid Build Coastguard Worker                                                    nullptr);
362*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Div:
363*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDivisionWorkloadValidate,
364*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
365*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
366*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
367*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2],
368*89c4ff92SAndroid Build Coastguard Worker                                                    nullptr);
369*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Maximum:
370*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMaximumWorkloadValidate,
371*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
372*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
373*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
374*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2]);
375*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Minimum:
376*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMinimumWorkloadValidate,
377*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
378*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
379*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
380*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2]);
381*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Mul:
382*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMultiplicationWorkloadValidate,
383*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
384*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
385*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
386*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2],
387*89c4ff92SAndroid Build Coastguard Worker                                                    nullptr);
388*89c4ff92SAndroid Build Coastguard Worker                 case BinaryOperation::Sub:
389*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSubtractionWorkloadValidate,
390*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported,
391*89c4ff92SAndroid Build Coastguard Worker                                                    infos[0],
392*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
393*89c4ff92SAndroid Build Coastguard Worker                                                    infos[2],
394*89c4ff92SAndroid Build Coastguard Worker                                                    nullptr);
395*89c4ff92SAndroid Build Coastguard Worker                 default:
396*89c4ff92SAndroid Build Coastguard Worker                     return false;
397*89c4ff92SAndroid Build Coastguard Worker             }
398*89c4ff92SAndroid Build Coastguard Worker         }
399*89c4ff92SAndroid Build Coastguard Worker         case LayerType::ElementwiseUnary:
400*89c4ff92SAndroid Build Coastguard Worker             return support.IsElementwiseUnarySupported(infos[0],
401*89c4ff92SAndroid Build Coastguard Worker                                                        infos[1],
402*89c4ff92SAndroid Build Coastguard Worker                                                        *(PolymorphicDowncast<const
403*89c4ff92SAndroid Build Coastguard Worker                                                            ElementwiseUnaryDescriptor*>(&descriptor)),
404*89c4ff92SAndroid Build Coastguard Worker                                                        reasonIfUnsupported);
405*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Fill:
406*89c4ff92SAndroid Build Coastguard Worker             return support.IsFillSupported(infos[0],
407*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
408*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const FillDescriptor*>(&descriptor)),
409*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
410*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Floor:
411*89c4ff92SAndroid Build Coastguard Worker             return support.IsFloorSupported(infos[0], infos[1], reasonIfUnsupported);
412*89c4ff92SAndroid Build Coastguard Worker         case LayerType::FullyConnected:
413*89c4ff92SAndroid Build Coastguard Worker             return support.IsFullyConnectedSupported(infos[0],
414*89c4ff92SAndroid Build Coastguard Worker                                                      infos[1],
415*89c4ff92SAndroid Build Coastguard Worker                                                      infos[2],
416*89c4ff92SAndroid Build Coastguard Worker                                                      infos[3],
417*89c4ff92SAndroid Build Coastguard Worker                                                      *(PolymorphicDowncast<const
418*89c4ff92SAndroid Build Coastguard Worker                                                          FullyConnectedDescriptor*>(&descriptor)),
419*89c4ff92SAndroid Build Coastguard Worker                                                      reasonIfUnsupported);
420*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Gather:
421*89c4ff92SAndroid Build Coastguard Worker             return support.IsGatherSupported(infos[0],
422*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
423*89c4ff92SAndroid Build Coastguard Worker                                              infos[2],
424*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const GatherDescriptor*>(&descriptor)),
425*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
426*89c4ff92SAndroid Build Coastguard Worker         case LayerType::GatherNd:
427*89c4ff92SAndroid Build Coastguard Worker             return support.IsGatherNdSupported(infos[0],
428*89c4ff92SAndroid Build Coastguard Worker                                                infos[1],
429*89c4ff92SAndroid Build Coastguard Worker                                                infos[2],
430*89c4ff92SAndroid Build Coastguard Worker                                                reasonIfUnsupported);
431*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Input:
432*89c4ff92SAndroid Build Coastguard Worker             return support.IsInputSupported(infos[0], reasonIfUnsupported);
433*89c4ff92SAndroid Build Coastguard Worker         case LayerType::InstanceNormalization:
434*89c4ff92SAndroid Build Coastguard Worker             return support.IsInstanceNormalizationSupported(infos[0],
435*89c4ff92SAndroid Build Coastguard Worker                                                             infos[1],
436*89c4ff92SAndroid Build Coastguard Worker                                                             *(PolymorphicDowncast<const
437*89c4ff92SAndroid Build Coastguard Worker                                                                 InstanceNormalizationDescriptor*>(&descriptor)),
438*89c4ff92SAndroid Build Coastguard Worker                                                             reasonIfUnsupported);
439*89c4ff92SAndroid Build Coastguard Worker         case LayerType::L2Normalization:
440*89c4ff92SAndroid Build Coastguard Worker             return support.IsL2NormalizationSupported(infos[0],
441*89c4ff92SAndroid Build Coastguard Worker                                                       infos[1],
442*89c4ff92SAndroid Build Coastguard Worker                                                       *(PolymorphicDowncast<const
443*89c4ff92SAndroid Build Coastguard Worker                                                           L2NormalizationDescriptor*>(&descriptor)),
444*89c4ff92SAndroid Build Coastguard Worker                                                       reasonIfUnsupported);
445*89c4ff92SAndroid Build Coastguard Worker         case LayerType::LogicalBinary:
446*89c4ff92SAndroid Build Coastguard Worker             return support.IsLogicalBinarySupported(infos[0],
447*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
448*89c4ff92SAndroid Build Coastguard Worker                                                     infos[2],
449*89c4ff92SAndroid Build Coastguard Worker                                                     *(PolymorphicDowncast<const
450*89c4ff92SAndroid Build Coastguard Worker                                                         LogicalBinaryDescriptor*>(&descriptor)),
451*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
452*89c4ff92SAndroid Build Coastguard Worker         case LayerType::LogSoftmax:
453*89c4ff92SAndroid Build Coastguard Worker             return support.IsLogSoftmaxSupported(infos[0],
454*89c4ff92SAndroid Build Coastguard Worker                                                  infos[1],
455*89c4ff92SAndroid Build Coastguard Worker                                                  *(PolymorphicDowncast<const LogSoftmaxDescriptor*>(&descriptor)),
456*89c4ff92SAndroid Build Coastguard Worker                                                  reasonIfUnsupported);
457*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Lstm:
458*89c4ff92SAndroid Build Coastguard Worker             return support.IsLstmSupported(infos[0],
459*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
460*89c4ff92SAndroid Build Coastguard Worker                                            infos[2],
461*89c4ff92SAndroid Build Coastguard Worker                                            infos[3],
462*89c4ff92SAndroid Build Coastguard Worker                                            infos[4],
463*89c4ff92SAndroid Build Coastguard Worker                                            infos[5],
464*89c4ff92SAndroid Build Coastguard Worker                                            infos[6],
465*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const LstmDescriptor*>(&descriptor)),
466*89c4ff92SAndroid Build Coastguard Worker                                            lstmParamsInfo.value(),
467*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
468*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Map:
469*89c4ff92SAndroid Build Coastguard Worker             return true;
470*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Maximum:
471*89c4ff92SAndroid Build Coastguard Worker             return support.IsMaximumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
472*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Mean:
473*89c4ff92SAndroid Build Coastguard Worker             return support.IsMeanSupported(infos[0],
474*89c4ff92SAndroid Build Coastguard Worker                                            infos[1],
475*89c4ff92SAndroid Build Coastguard Worker                                            *(PolymorphicDowncast<const MeanDescriptor*>(&descriptor)),
476*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported);
477*89c4ff92SAndroid Build Coastguard Worker         case LayerType::MemCopy:
478*89c4ff92SAndroid Build Coastguard Worker             return support.IsMemCopySupported(infos[0], infos[1], reasonIfUnsupported);
479*89c4ff92SAndroid Build Coastguard Worker         case LayerType::MemImport:
480*89c4ff92SAndroid Build Coastguard Worker             return support.IsMemImportSupported(infos[0], infos[1], reasonIfUnsupported);
481*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Merge:
482*89c4ff92SAndroid Build Coastguard Worker             return support.IsMergeSupported(infos[0],
483*89c4ff92SAndroid Build Coastguard Worker                                                       infos[1],
484*89c4ff92SAndroid Build Coastguard Worker                                                       infos[2],
485*89c4ff92SAndroid Build Coastguard Worker                                                       reasonIfUnsupported);
486*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Minimum:
487*89c4ff92SAndroid Build Coastguard Worker             return support.IsMinimumSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
488*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Multiplication:
489*89c4ff92SAndroid Build Coastguard Worker             return support.IsMultiplicationSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
490*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Normalization:
491*89c4ff92SAndroid Build Coastguard Worker             return support.IsNormalizationSupported(infos[0],
492*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
493*89c4ff92SAndroid Build Coastguard Worker                                                     *(PolymorphicDowncast<const
494*89c4ff92SAndroid Build Coastguard Worker                                                         NormalizationDescriptor*>(&descriptor)),
495*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
496*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Output:
497*89c4ff92SAndroid Build Coastguard Worker             return support.IsOutputSupported(infos[0], reasonIfUnsupported);
498*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pad:
499*89c4ff92SAndroid Build Coastguard Worker             return support.IsPadSupported(infos[0],
500*89c4ff92SAndroid Build Coastguard Worker                                           infos[1],
501*89c4ff92SAndroid Build Coastguard Worker                                           *(PolymorphicDowncast<const PadDescriptor*>(&descriptor)),
502*89c4ff92SAndroid Build Coastguard Worker                                           reasonIfUnsupported);
503*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Permute:
504*89c4ff92SAndroid Build Coastguard Worker             return support.IsPermuteSupported(infos[0],
505*89c4ff92SAndroid Build Coastguard Worker                                               infos[1],
506*89c4ff92SAndroid Build Coastguard Worker                                               *(PolymorphicDowncast<const PermuteDescriptor*>(&descriptor)),
507*89c4ff92SAndroid Build Coastguard Worker                                               reasonIfUnsupported);
508*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pooling2d:
509*89c4ff92SAndroid Build Coastguard Worker             return support.IsPooling2dSupported(infos[0],
510*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
511*89c4ff92SAndroid Build Coastguard Worker                                                 *(PolymorphicDowncast<const Pooling2dDescriptor*>(&descriptor)),
512*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
513*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Pooling3d:
514*89c4ff92SAndroid Build Coastguard Worker             return support.IsPooling3dSupported(infos[0],
515*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
516*89c4ff92SAndroid Build Coastguard Worker                                                 *(PolymorphicDowncast<const Pooling3dDescriptor*>(&descriptor)),
517*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
518*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Prelu:
519*89c4ff92SAndroid Build Coastguard Worker             return support.IsPreluSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
520*89c4ff92SAndroid Build Coastguard Worker         case LayerType::QLstm:
521*89c4ff92SAndroid Build Coastguard Worker             return support.IsQLstmSupported(infos[0],
522*89c4ff92SAndroid Build Coastguard Worker                                             infos[1],
523*89c4ff92SAndroid Build Coastguard Worker                                             infos[2],
524*89c4ff92SAndroid Build Coastguard Worker                                             infos[3],
525*89c4ff92SAndroid Build Coastguard Worker                                             infos[4],
526*89c4ff92SAndroid Build Coastguard Worker                                             infos[5],
527*89c4ff92SAndroid Build Coastguard Worker                                             *(PolymorphicDowncast<const QLstmDescriptor*>(&descriptor)),
528*89c4ff92SAndroid Build Coastguard Worker                                             lstmParamsInfo.value(),
529*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
530*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Quantize:
531*89c4ff92SAndroid Build Coastguard Worker             return support.IsQuantizeSupported(infos[0], infos[1], reasonIfUnsupported);
532*89c4ff92SAndroid Build Coastguard Worker         case LayerType::QuantizedLstm:
533*89c4ff92SAndroid Build Coastguard Worker             return support.IsQuantizedLstmSupported(infos[0],
534*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
535*89c4ff92SAndroid Build Coastguard Worker                                                     infos[2],
536*89c4ff92SAndroid Build Coastguard Worker                                                     infos[3],
537*89c4ff92SAndroid Build Coastguard Worker                                                     infos[4],
538*89c4ff92SAndroid Build Coastguard Worker                                                     quantizedLstmParamsInfo.value(),
539*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
540*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Rank:
541*89c4ff92SAndroid Build Coastguard Worker             return true;
542*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Reshape:
543*89c4ff92SAndroid Build Coastguard Worker             return support.IsReshapeSupported(infos[0],
544*89c4ff92SAndroid Build Coastguard Worker                                               infos[1],
545*89c4ff92SAndroid Build Coastguard Worker                                               *(PolymorphicDowncast<const ReshapeDescriptor*>(&descriptor)),
546*89c4ff92SAndroid Build Coastguard Worker                                               reasonIfUnsupported);
547*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Resize:
548*89c4ff92SAndroid Build Coastguard Worker             return support.IsResizeSupported(infos[0],
549*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
550*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const ResizeDescriptor*>(&descriptor)),
551*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
552*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Reduce:
553*89c4ff92SAndroid Build Coastguard Worker             return support.IsReduceSupported(infos[0],
554*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
555*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const ReduceDescriptor*>(&descriptor)),
556*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
557*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Shape:
558*89c4ff92SAndroid Build Coastguard Worker             return support.IsShapeSupported(infos[0],
559*89c4ff92SAndroid Build Coastguard Worker                                             infos[1],
560*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
561*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Slice:
562*89c4ff92SAndroid Build Coastguard Worker             return support.IsSliceSupported(infos[0],
563*89c4ff92SAndroid Build Coastguard Worker                                             infos[1],
564*89c4ff92SAndroid Build Coastguard Worker                                             *(PolymorphicDowncast<const SliceDescriptor*>(&descriptor)),
565*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
566*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Softmax:
567*89c4ff92SAndroid Build Coastguard Worker             return support.IsSoftmaxSupported(infos[0],
568*89c4ff92SAndroid Build Coastguard Worker                                               infos[1],
569*89c4ff92SAndroid Build Coastguard Worker                                               *(PolymorphicDowncast<const SoftmaxDescriptor*>(&descriptor)),
570*89c4ff92SAndroid Build Coastguard Worker                                               reasonIfUnsupported);
571*89c4ff92SAndroid Build Coastguard Worker         case LayerType::SpaceToBatchNd:
572*89c4ff92SAndroid Build Coastguard Worker             return support.IsSpaceToBatchNdSupported(infos[0],
573*89c4ff92SAndroid Build Coastguard Worker                                                      infos[1],
574*89c4ff92SAndroid Build Coastguard Worker                                                      *(PolymorphicDowncast<const
575*89c4ff92SAndroid Build Coastguard Worker                                                         SpaceToBatchNdDescriptor*>(&descriptor)),
576*89c4ff92SAndroid Build Coastguard Worker                                                      reasonIfUnsupported);
577*89c4ff92SAndroid Build Coastguard Worker         case LayerType::SpaceToDepth:
578*89c4ff92SAndroid Build Coastguard Worker             return support.IsSpaceToDepthSupported(infos[0],
579*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
580*89c4ff92SAndroid Build Coastguard Worker                                                    *(PolymorphicDowncast<const SpaceToDepthDescriptor*>(&descriptor)),
581*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported);
582*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Splitter:
583*89c4ff92SAndroid Build Coastguard Worker         {
584*89c4ff92SAndroid Build Coastguard Worker             std::vector<TensorInfo> outputInfos;
585*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 1; i < infos.size(); i++)
586*89c4ff92SAndroid Build Coastguard Worker             {
587*89c4ff92SAndroid Build Coastguard Worker                 outputInfos.push_back(infos[i]);
588*89c4ff92SAndroid Build Coastguard Worker             }
589*89c4ff92SAndroid Build Coastguard Worker             return support.IsSplitterSupported(infos[0],
590*89c4ff92SAndroid Build Coastguard Worker                                                {outputInfos.begin(), outputInfos.end()},
591*89c4ff92SAndroid Build Coastguard Worker                                                *(PolymorphicDowncast<const ViewsDescriptor*>(&descriptor)),
592*89c4ff92SAndroid Build Coastguard Worker                                                reasonIfUnsupported);
593*89c4ff92SAndroid Build Coastguard Worker         }
594*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Stack:
595*89c4ff92SAndroid Build Coastguard Worker         {
596*89c4ff92SAndroid Build Coastguard Worker             std::vector<const TensorInfo*> inputInfos;
597*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < infos.size() - 1; i++)
598*89c4ff92SAndroid Build Coastguard Worker             {
599*89c4ff92SAndroid Build Coastguard Worker                 inputInfos.push_back(&infos[i]);
600*89c4ff92SAndroid Build Coastguard Worker             }
601*89c4ff92SAndroid Build Coastguard Worker             return support.IsStackSupported(inputInfos,
602*89c4ff92SAndroid Build Coastguard Worker                                             infos[infos.size() - 1],
603*89c4ff92SAndroid Build Coastguard Worker                                             *(PolymorphicDowncast<const StackDescriptor*>(&descriptor)),
604*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported);
605*89c4ff92SAndroid Build Coastguard Worker         }
606*89c4ff92SAndroid Build Coastguard Worker         case LayerType::StridedSlice:
607*89c4ff92SAndroid Build Coastguard Worker             return support.IsStridedSliceSupported(infos[0],
608*89c4ff92SAndroid Build Coastguard Worker                                                    infos[1],
609*89c4ff92SAndroid Build Coastguard Worker                                                    *(PolymorphicDowncast<const StridedSliceDescriptor*>(&descriptor)),
610*89c4ff92SAndroid Build Coastguard Worker                                                    reasonIfUnsupported);
611*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Subtraction:
612*89c4ff92SAndroid Build Coastguard Worker             return support.IsSubtractionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
613*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Transpose:
614*89c4ff92SAndroid Build Coastguard Worker             return support.IsTransposeSupported(infos[0],
615*89c4ff92SAndroid Build Coastguard Worker                                                 infos[1],
616*89c4ff92SAndroid Build Coastguard Worker                                                 *(PolymorphicDowncast<const TransposeDescriptor*>(&descriptor)),
617*89c4ff92SAndroid Build Coastguard Worker                                                 reasonIfUnsupported);
618*89c4ff92SAndroid Build Coastguard Worker         case LayerType::TransposeConvolution2d:
619*89c4ff92SAndroid Build Coastguard Worker         {
620*89c4ff92SAndroid Build Coastguard Worker             if (infos.size() != 4)
621*89c4ff92SAndroid Build Coastguard Worker             {
622*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("Invalid number of TransposeConvolution2d TensorInfos. "
623*89c4ff92SAndroid Build Coastguard Worker                                                "TensorInfos should be of format: {input, output, weights, biases}.");
624*89c4ff92SAndroid Build Coastguard Worker             }
625*89c4ff92SAndroid Build Coastguard Worker 
626*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const TransposeConvolution2dDescriptor*>(&descriptor));
627*89c4ff92SAndroid Build Coastguard Worker             if (infos[3] == TensorInfo())
628*89c4ff92SAndroid Build Coastguard Worker             {
629*89c4ff92SAndroid Build Coastguard Worker                 return support.IsTransposeConvolution2dSupported(infos[0],
630*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[1],
631*89c4ff92SAndroid Build Coastguard Worker                                                                  desc,
632*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[2],
633*89c4ff92SAndroid Build Coastguard Worker                                                                  EmptyOptional(),
634*89c4ff92SAndroid Build Coastguard Worker                                                                  reasonIfUnsupported);
635*89c4ff92SAndroid Build Coastguard Worker             }
636*89c4ff92SAndroid Build Coastguard Worker             else
637*89c4ff92SAndroid Build Coastguard Worker             {
638*89c4ff92SAndroid Build Coastguard Worker                 return support.IsTransposeConvolution2dSupported(infos[0],
639*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[1],
640*89c4ff92SAndroid Build Coastguard Worker                                                                  desc,
641*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[2],
642*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[3],
643*89c4ff92SAndroid Build Coastguard Worker                                                                  reasonIfUnsupported);
644*89c4ff92SAndroid Build Coastguard Worker             }
645*89c4ff92SAndroid Build Coastguard Worker         }
646*89c4ff92SAndroid Build Coastguard Worker         case LayerType::UnidirectionalSequenceLstm:
647*89c4ff92SAndroid Build Coastguard Worker         {
648*89c4ff92SAndroid Build Coastguard Worker             auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
649*89c4ff92SAndroid Build Coastguard Worker             return support.IsUnidirectionalSequenceLstmSupported(infos[0],
650*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[1],
651*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[2],
652*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[3],
653*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[4],
654*89c4ff92SAndroid Build Coastguard Worker                                                                  infos[5],
655*89c4ff92SAndroid Build Coastguard Worker                                                                  desc,
656*89c4ff92SAndroid Build Coastguard Worker                                                                  lstmParamsInfo.value(),
657*89c4ff92SAndroid Build Coastguard Worker                                                                  reasonIfUnsupported);
658*89c4ff92SAndroid Build Coastguard Worker         }
659*89c4ff92SAndroid Build Coastguard Worker         case LayerType::Unmap:
660*89c4ff92SAndroid Build Coastguard Worker             return true;
661*89c4ff92SAndroid Build Coastguard Worker         default:
662*89c4ff92SAndroid Build Coastguard Worker             // layers not supported in neon by default:
663*89c4ff92SAndroid Build Coastguard Worker             // debug, fakequantization, precompiled,
664*89c4ff92SAndroid Build Coastguard Worker             // standin, switch
665*89c4ff92SAndroid Build Coastguard Worker             return false;
666*89c4ff92SAndroid Build Coastguard Worker     }
667*89c4ff92SAndroid Build Coastguard Worker }
668*89c4ff92SAndroid Build Coastguard Worker 
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> & lstmParamsInfo,const Optional<QuantizedLstmInputParamsInfo> & quantizedLstmParamsInfo,Optional<std::string &> reasonIfUnsupported) const669*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsLayerSupported(const LayerType& type,
670*89c4ff92SAndroid Build Coastguard Worker                                         const std::vector<TensorInfo>& infos,
671*89c4ff92SAndroid Build Coastguard Worker                                         const BaseDescriptor& descriptor,
672*89c4ff92SAndroid Build Coastguard Worker                                         const Optional<LstmInputParamsInfo>& lstmParamsInfo,
673*89c4ff92SAndroid Build Coastguard Worker                                         const Optional<QuantizedLstmInputParamsInfo>& quantizedLstmParamsInfo,
674*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
675*89c4ff92SAndroid Build Coastguard Worker {
676*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = IsLayerTypeSupported(type,
677*89c4ff92SAndroid Build Coastguard Worker                                             infos,
678*89c4ff92SAndroid Build Coastguard Worker                                             descriptor,
679*89c4ff92SAndroid Build Coastguard Worker                                             lstmParamsInfo,
680*89c4ff92SAndroid Build Coastguard Worker                                             quantizedLstmParamsInfo,
681*89c4ff92SAndroid Build Coastguard Worker                                             reasonIfUnsupported,
682*89c4ff92SAndroid Build Coastguard Worker                                             *this);
683*89c4ff92SAndroid Build Coastguard Worker 
684*89c4ff92SAndroid Build Coastguard Worker     // For android-nn-driver and support library, to run FP16 operations on CpuAcc we need at least v8.2
685*89c4ff92SAndroid Build Coastguard Worker     // architecture. If the available architecture is older than v8.2, we can check if the operator is
686*89c4ff92SAndroid Build Coastguard Worker     // supported by changing operator inputs & outputs to be FP32.
687*89c4ff92SAndroid Build Coastguard Worker     // This does not change the operator datatype in the above parsers to be FP32. We are simply reporting
688*89c4ff92SAndroid Build Coastguard Worker     // to the parsers if the operator can supported in ArmNN. We will then re-enter ArmNN (Network.cpp)
689*89c4ff92SAndroid Build Coastguard Worker     // where we will recheck IsLayerSupported() on the FP16 datatype, update the operator to be FP32,
690*89c4ff92SAndroid Build Coastguard Worker     // and, insert convert layers around the FP32 operator.
691*89c4ff92SAndroid Build Coastguard Worker     if (reasonIfUnsupported.has_value())
692*89c4ff92SAndroid Build Coastguard Worker     {
693*89c4ff92SAndroid Build Coastguard Worker         std::string checkStr = "This CPU architecture does not support F16 data type, you need v8.2 or above";
694*89c4ff92SAndroid Build Coastguard Worker         if (!isSupported
695*89c4ff92SAndroid Build Coastguard Worker             && reasonIfUnsupported.value().find(checkStr) != std::string::npos)
696*89c4ff92SAndroid Build Coastguard Worker         {
697*89c4ff92SAndroid Build Coastguard Worker             std::vector<TensorInfo> newInfos;
698*89c4ff92SAndroid Build Coastguard Worker             for (auto               info: infos)
699*89c4ff92SAndroid Build Coastguard Worker             {
700*89c4ff92SAndroid Build Coastguard Worker                 newInfos.emplace_back(OverrideDataType(info, DataType::Float32));
701*89c4ff92SAndroid Build Coastguard Worker             }
702*89c4ff92SAndroid Build Coastguard Worker 
703*89c4ff92SAndroid Build Coastguard Worker             std::string tmpString;
704*89c4ff92SAndroid Build Coastguard Worker             return IsLayerTypeSupported(type,
705*89c4ff92SAndroid Build Coastguard Worker                                         newInfos,
706*89c4ff92SAndroid Build Coastguard Worker                                         descriptor,
707*89c4ff92SAndroid Build Coastguard Worker                                         lstmParamsInfo,
708*89c4ff92SAndroid Build Coastguard Worker                                         quantizedLstmParamsInfo,
709*89c4ff92SAndroid Build Coastguard Worker                                         tmpString,
710*89c4ff92SAndroid Build Coastguard Worker                                         *this);
711*89c4ff92SAndroid Build Coastguard Worker         }
712*89c4ff92SAndroid Build Coastguard Worker     }
713*89c4ff92SAndroid Build Coastguard Worker 
714*89c4ff92SAndroid Build Coastguard Worker     return isSupported;
715*89c4ff92SAndroid Build Coastguard Worker }
716*89c4ff92SAndroid Build Coastguard Worker 
IsActivationSupported(const TensorInfo & input,const TensorInfo & output,const ActivationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const717*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsActivationSupported(const TensorInfo& input,
718*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
719*89c4ff92SAndroid Build Coastguard Worker                                              const ActivationDescriptor& descriptor,
720*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string&> reasonIfUnsupported) const
721*89c4ff92SAndroid Build Coastguard Worker {
722*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
723*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonActivationWorkloadValidate,
724*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
725*89c4ff92SAndroid Build Coastguard Worker                                    input,
726*89c4ff92SAndroid Build Coastguard Worker                                    output,
727*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
728*89c4ff92SAndroid Build Coastguard Worker }
729*89c4ff92SAndroid Build Coastguard Worker 
IsAdditionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const730*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsAdditionSupported(const TensorInfo& input0,
731*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& input1,
732*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
733*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
734*89c4ff92SAndroid Build Coastguard Worker {
735*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonAdditionWorkloadValidate,
736*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
737*89c4ff92SAndroid Build Coastguard Worker                                    input0,
738*89c4ff92SAndroid Build Coastguard Worker                                    input1,
739*89c4ff92SAndroid Build Coastguard Worker                                    output,
740*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
741*89c4ff92SAndroid Build Coastguard Worker }
742*89c4ff92SAndroid Build Coastguard Worker 
IsArgMinMaxSupported(const TensorInfo & input,const TensorInfo & output,const ArgMinMaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const743*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsArgMinMaxSupported(const TensorInfo& input,
744*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
745*89c4ff92SAndroid Build Coastguard Worker                                             const ArgMinMaxDescriptor& descriptor,
746*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
747*89c4ff92SAndroid Build Coastguard Worker {
748*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonArgMinMaxWorkloadValidate,
749*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
750*89c4ff92SAndroid Build Coastguard Worker                                    input,
751*89c4ff92SAndroid Build Coastguard Worker                                    output,
752*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
753*89c4ff92SAndroid Build Coastguard Worker }
754*89c4ff92SAndroid Build Coastguard Worker 
IsBatchMatMulSupported(const TensorInfo & inputX,const TensorInfo & inputY,const TensorInfo & output,const BatchMatMulDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const755*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
756*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& inputY,
757*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& output,
758*89c4ff92SAndroid Build Coastguard Worker                                               const BatchMatMulDescriptor& descriptor,
759*89c4ff92SAndroid Build Coastguard Worker                                               Optional<std::string&> reasonIfUnsupported) const
760*89c4ff92SAndroid Build Coastguard Worker {
761*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonBatchMatMulValidate,
762*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
763*89c4ff92SAndroid Build Coastguard Worker                                    inputX,
764*89c4ff92SAndroid Build Coastguard Worker                                    inputY,
765*89c4ff92SAndroid Build Coastguard Worker                                    output,
766*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
767*89c4ff92SAndroid Build Coastguard Worker }
768*89c4ff92SAndroid Build Coastguard Worker 
IsBatchNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & var,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const769*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
770*89c4ff92SAndroid Build Coastguard Worker                                                      const TensorInfo& output,
771*89c4ff92SAndroid Build Coastguard Worker                                                      const TensorInfo& mean,
772*89c4ff92SAndroid Build Coastguard Worker                                                      const TensorInfo& var,
773*89c4ff92SAndroid Build Coastguard Worker                                                      const TensorInfo& beta,
774*89c4ff92SAndroid Build Coastguard Worker                                                      const TensorInfo& gamma,
775*89c4ff92SAndroid Build Coastguard Worker                                                      const BatchNormalizationDescriptor& descriptor,
776*89c4ff92SAndroid Build Coastguard Worker                                                      Optional<std::string&> reasonIfUnsupported) const
777*89c4ff92SAndroid Build Coastguard Worker {
778*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonBatchNormalizationValidate,
779*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
780*89c4ff92SAndroid Build Coastguard Worker                                    input,
781*89c4ff92SAndroid Build Coastguard Worker                                    output,
782*89c4ff92SAndroid Build Coastguard Worker                                    mean,
783*89c4ff92SAndroid Build Coastguard Worker                                    var,
784*89c4ff92SAndroid Build Coastguard Worker                                    beta,
785*89c4ff92SAndroid Build Coastguard Worker                                    gamma,
786*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
787*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
788*89c4ff92SAndroid Build Coastguard Worker }
789*89c4ff92SAndroid Build Coastguard Worker 
IsBatchToSpaceNdSupported(const TensorInfo & input,const TensorInfo & output,const BatchToSpaceNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const790*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
791*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
792*89c4ff92SAndroid Build Coastguard Worker                                                  const BatchToSpaceNdDescriptor& descriptor,
793*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
794*89c4ff92SAndroid Build Coastguard Worker {
795*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonBatchToSpaceNdWorkloadValidate,
796*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
797*89c4ff92SAndroid Build Coastguard Worker                                    input,
798*89c4ff92SAndroid Build Coastguard Worker                                    output,
799*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
800*89c4ff92SAndroid Build Coastguard Worker }
801*89c4ff92SAndroid Build Coastguard Worker 
IsCastSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const802*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsCastSupported(const TensorInfo& input,
803*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
804*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
805*89c4ff92SAndroid Build Coastguard Worker {
806*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonCastValidate,
807*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
808*89c4ff92SAndroid Build Coastguard Worker                                    input,
809*89c4ff92SAndroid Build Coastguard Worker                                    output);
810*89c4ff92SAndroid Build Coastguard Worker }
811*89c4ff92SAndroid Build Coastguard Worker 
IsChannelShuffleSupported(const TensorInfo & input,const TensorInfo & output,const ChannelShuffleDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const812*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsChannelShuffleSupported(const TensorInfo& input,
813*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
814*89c4ff92SAndroid Build Coastguard Worker                                                  const ChannelShuffleDescriptor& descriptor,
815*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
816*89c4ff92SAndroid Build Coastguard Worker {
817*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonChannelShuffleValidate,
818*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
819*89c4ff92SAndroid Build Coastguard Worker                                    input,
820*89c4ff92SAndroid Build Coastguard Worker                                    output,
821*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
822*89c4ff92SAndroid Build Coastguard Worker }
823*89c4ff92SAndroid Build Coastguard Worker 
IsComparisonSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const ComparisonDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const824*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsComparisonSupported(const TensorInfo& input0,
825*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& input1,
826*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
827*89c4ff92SAndroid Build Coastguard Worker                                              const ComparisonDescriptor& descriptor,
828*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string&> reasonIfUnsupported) const
829*89c4ff92SAndroid Build Coastguard Worker {
830*89c4ff92SAndroid Build Coastguard Worker 
831*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonComparisonWorkloadValidate,
832*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
833*89c4ff92SAndroid Build Coastguard Worker                                    input0,
834*89c4ff92SAndroid Build Coastguard Worker                                    input1,
835*89c4ff92SAndroid Build Coastguard Worker                                    output,
836*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
837*89c4ff92SAndroid Build Coastguard Worker }
838*89c4ff92SAndroid Build Coastguard Worker 
IsConcatSupported(const std::vector<const TensorInfo * > inputs,const TensorInfo & output,const OriginsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const839*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
840*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
841*89c4ff92SAndroid Build Coastguard Worker                                          const OriginsDescriptor& descriptor,
842*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
843*89c4ff92SAndroid Build Coastguard Worker {
844*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.GetNumDimensions() <= descriptor.GetConcatAxis())
845*89c4ff92SAndroid Build Coastguard Worker     {
846*89c4ff92SAndroid Build Coastguard Worker         SetValueChecked(reasonIfUnsupported, "Neon Concat: Concat axis > Number of dimensions.");
847*89c4ff92SAndroid Build Coastguard Worker         return false;
848*89c4ff92SAndroid Build Coastguard Worker     }
849*89c4ff92SAndroid Build Coastguard Worker 
850*89c4ff92SAndroid Build Coastguard Worker     unsigned int concatInnerAxis = (descriptor.GetNumDimensions() - descriptor.GetConcatAxis()) - 1;
851*89c4ff92SAndroid Build Coastguard Worker     if(concatInnerAxis < 3) // Width, height, or channels
852*89c4ff92SAndroid Build Coastguard Worker     {
853*89c4ff92SAndroid Build Coastguard Worker         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonConcatWorkloadValidate,
854*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported,
855*89c4ff92SAndroid Build Coastguard Worker                                        inputs,
856*89c4ff92SAndroid Build Coastguard Worker                                        output,
857*89c4ff92SAndroid Build Coastguard Worker                                        descriptor);
858*89c4ff92SAndroid Build Coastguard Worker     }
859*89c4ff92SAndroid Build Coastguard Worker     else if (concatInnerAxis == 3)
860*89c4ff92SAndroid Build Coastguard Worker     {
861*89c4ff92SAndroid Build Coastguard Worker         for (auto& input : inputs)
862*89c4ff92SAndroid Build Coastguard Worker         {
863*89c4ff92SAndroid Build Coastguard Worker             if (input && !output.IsTypeSpaceMatch(*input)) // Cannot use sub-tensors if the types are not same space
864*89c4ff92SAndroid Build Coastguard Worker             {
865*89c4ff92SAndroid Build Coastguard Worker                 SetValueChecked(reasonIfUnsupported, "Neon Concat: Types and quantization parameters must match.");
866*89c4ff92SAndroid Build Coastguard Worker                 return false;
867*89c4ff92SAndroid Build Coastguard Worker             }
868*89c4ff92SAndroid Build Coastguard Worker         }
869*89c4ff92SAndroid Build Coastguard Worker         return true; // Sub-tensors support concat along batch
870*89c4ff92SAndroid Build Coastguard Worker     }
871*89c4ff92SAndroid Build Coastguard Worker     else // > 4 dimensions not supported.
872*89c4ff92SAndroid Build Coastguard Worker     {
873*89c4ff92SAndroid Build Coastguard Worker         SetValueChecked(reasonIfUnsupported, "Neon Concat: Maximum of 4 dimensions supported.");
874*89c4ff92SAndroid Build Coastguard Worker         return false;
875*89c4ff92SAndroid Build Coastguard Worker     }
876*89c4ff92SAndroid Build Coastguard Worker }
877*89c4ff92SAndroid Build Coastguard Worker 
IsConstantSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const878*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConstantSupported(const TensorInfo& output,
879*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
880*89c4ff92SAndroid Build Coastguard Worker {
881*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonConstantWorkloadValidate,
882*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
883*89c4ff92SAndroid Build Coastguard Worker                                    output);
884*89c4ff92SAndroid Build Coastguard Worker }
885*89c4ff92SAndroid Build Coastguard Worker 
IsConvertFp16ToFp32Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const886*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConvertFp16ToFp32Supported(const TensorInfo& input,
887*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& output,
888*89c4ff92SAndroid Build Coastguard Worker                                                     Optional<std::string&> reasonIfUnsupported) const
889*89c4ff92SAndroid Build Coastguard Worker {
890*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(input);
891*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(output);
892*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(reasonIfUnsupported);
893*89c4ff92SAndroid Build Coastguard Worker     return true;
894*89c4ff92SAndroid Build Coastguard Worker }
895*89c4ff92SAndroid Build Coastguard Worker 
IsConvertFp32ToFp16Supported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const896*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConvertFp32ToFp16Supported(const TensorInfo& input,
897*89c4ff92SAndroid Build Coastguard Worker                                                     const TensorInfo& output,
898*89c4ff92SAndroid Build Coastguard Worker                                                     Optional<std::string&> reasonIfUnsupported) const
899*89c4ff92SAndroid Build Coastguard Worker {
900*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(input);
901*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(output);
902*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(reasonIfUnsupported);
903*89c4ff92SAndroid Build Coastguard Worker     return true;
904*89c4ff92SAndroid Build Coastguard Worker }
905*89c4ff92SAndroid Build Coastguard Worker 
IsConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const906*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
907*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
908*89c4ff92SAndroid Build Coastguard Worker                                                 const Convolution2dDescriptor& descriptor,
909*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& weights,
910*89c4ff92SAndroid Build Coastguard Worker                                                 const Optional<TensorInfo>& biases,
911*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
912*89c4ff92SAndroid Build Coastguard Worker {
913*89c4ff92SAndroid Build Coastguard Worker     bool isFastMathEnabled = false;
914*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
915*89c4ff92SAndroid Build Coastguard Worker     if (m_ModelContextPtr)
916*89c4ff92SAndroid Build Coastguard Worker     {
917*89c4ff92SAndroid Build Coastguard Worker         if (m_ModelContextPtr.get() != nullptr)
918*89c4ff92SAndroid Build Coastguard Worker         {
919*89c4ff92SAndroid Build Coastguard Worker             auto modelOptions = dynamic_cast<NeonBackendModelContext*>(m_ModelContextPtr.get());
920*89c4ff92SAndroid Build Coastguard Worker             if (modelOptions)
921*89c4ff92SAndroid Build Coastguard Worker             {
922*89c4ff92SAndroid Build Coastguard Worker                 isFastMathEnabled = modelOptions->IsFastMathEnabled();
923*89c4ff92SAndroid Build Coastguard Worker             }
924*89c4ff92SAndroid Build Coastguard Worker         }
925*89c4ff92SAndroid Build Coastguard Worker     }
926*89c4ff92SAndroid Build Coastguard Worker #endif
927*89c4ff92SAndroid Build Coastguard Worker 
928*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonConvolution2dWorkloadValidate,
929*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
930*89c4ff92SAndroid Build Coastguard Worker                                    input,
931*89c4ff92SAndroid Build Coastguard Worker                                    output,
932*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
933*89c4ff92SAndroid Build Coastguard Worker                                    weights,
934*89c4ff92SAndroid Build Coastguard Worker                                    biases,
935*89c4ff92SAndroid Build Coastguard Worker                                    isFastMathEnabled,
936*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
937*89c4ff92SAndroid Build Coastguard Worker }
938*89c4ff92SAndroid Build Coastguard Worker 
IsConvolution3dSupported(const TensorInfo & input,const TensorInfo & output,const Convolution3dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const939*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsConvolution3dSupported(const TensorInfo& input,
940*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
941*89c4ff92SAndroid Build Coastguard Worker                                                 const Convolution3dDescriptor& descriptor,
942*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& weights,
943*89c4ff92SAndroid Build Coastguard Worker                                                 const Optional<TensorInfo>& biases,
944*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
945*89c4ff92SAndroid Build Coastguard Worker {
946*89c4ff92SAndroid Build Coastguard Worker     bool isFastMathEnabled = false;
947*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
948*89c4ff92SAndroid Build Coastguard Worker     if (m_ModelContextPtr)
949*89c4ff92SAndroid Build Coastguard Worker     {
950*89c4ff92SAndroid Build Coastguard Worker         if (m_ModelContextPtr.get() != nullptr)
951*89c4ff92SAndroid Build Coastguard Worker         {
952*89c4ff92SAndroid Build Coastguard Worker             auto modelOptions = dynamic_cast<NeonBackendModelContext*>(m_ModelContextPtr.get());
953*89c4ff92SAndroid Build Coastguard Worker             if (modelOptions)
954*89c4ff92SAndroid Build Coastguard Worker             {
955*89c4ff92SAndroid Build Coastguard Worker                 isFastMathEnabled = modelOptions->IsFastMathEnabled();
956*89c4ff92SAndroid Build Coastguard Worker             }
957*89c4ff92SAndroid Build Coastguard Worker         }
958*89c4ff92SAndroid Build Coastguard Worker     }
959*89c4ff92SAndroid Build Coastguard Worker #endif
960*89c4ff92SAndroid Build Coastguard Worker 
961*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonConvolution3dWorkloadValidate,
962*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
963*89c4ff92SAndroid Build Coastguard Worker                                    input,
964*89c4ff92SAndroid Build Coastguard Worker                                    output,
965*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
966*89c4ff92SAndroid Build Coastguard Worker                                    weights,
967*89c4ff92SAndroid Build Coastguard Worker                                    biases,
968*89c4ff92SAndroid Build Coastguard Worker                                    isFastMathEnabled,
969*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
970*89c4ff92SAndroid Build Coastguard Worker }
971*89c4ff92SAndroid Build Coastguard Worker 
IsDepthToSpaceSupported(const TensorInfo & input,const TensorInfo & output,const DepthToSpaceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const972*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsDepthToSpaceSupported(const TensorInfo& input,
973*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
974*89c4ff92SAndroid Build Coastguard Worker                                                const DepthToSpaceDescriptor& descriptor,
975*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
976*89c4ff92SAndroid Build Coastguard Worker {
977*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDepthToSpaceWorkloadValidate,
978*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
979*89c4ff92SAndroid Build Coastguard Worker                                    input,
980*89c4ff92SAndroid Build Coastguard Worker                                    output,
981*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
982*89c4ff92SAndroid Build Coastguard Worker }
983*89c4ff92SAndroid Build Coastguard Worker 
IsDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const984*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsDepthwiseConvolutionSupported(const TensorInfo& input,
985*89c4ff92SAndroid Build Coastguard Worker                                                        const TensorInfo& output,
986*89c4ff92SAndroid Build Coastguard Worker                                                        const DepthwiseConvolution2dDescriptor& descriptor,
987*89c4ff92SAndroid Build Coastguard Worker                                                        const TensorInfo& weights,
988*89c4ff92SAndroid Build Coastguard Worker                                                        const Optional<TensorInfo>& biases,
989*89c4ff92SAndroid Build Coastguard Worker                                                        Optional<std::string&> reasonIfUnsupported) const
990*89c4ff92SAndroid Build Coastguard Worker {
991*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDepthwiseConvolutionWorkloadValidate,
992*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
993*89c4ff92SAndroid Build Coastguard Worker                                    input,
994*89c4ff92SAndroid Build Coastguard Worker                                    output,
995*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
996*89c4ff92SAndroid Build Coastguard Worker                                    weights,
997*89c4ff92SAndroid Build Coastguard Worker                                    biases,
998*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
999*89c4ff92SAndroid Build Coastguard Worker }
1000*89c4ff92SAndroid Build Coastguard Worker 
IsDequantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1001*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsDequantizeSupported(const TensorInfo& input,
1002*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
1003*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string&> reasonIfUnsupported) const
1004*89c4ff92SAndroid Build Coastguard Worker {
1005*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDequantizeWorkloadValidate,
1006*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1007*89c4ff92SAndroid Build Coastguard Worker                                    input,
1008*89c4ff92SAndroid Build Coastguard Worker                                    output);
1009*89c4ff92SAndroid Build Coastguard Worker }
1010*89c4ff92SAndroid Build Coastguard Worker 
IsDilatedDepthwiseConvolutionSupported(const TensorInfo & input,const TensorInfo & output,const DepthwiseConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1011*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsDilatedDepthwiseConvolutionSupported(const TensorInfo& input,
1012*89c4ff92SAndroid Build Coastguard Worker                                                               const TensorInfo& output,
1013*89c4ff92SAndroid Build Coastguard Worker                                                               const DepthwiseConvolution2dDescriptor& descriptor,
1014*89c4ff92SAndroid Build Coastguard Worker                                                               const TensorInfo& weights,
1015*89c4ff92SAndroid Build Coastguard Worker                                                               const Optional<TensorInfo>& biases,
1016*89c4ff92SAndroid Build Coastguard Worker                                                               Optional<std::string&> reasonIfUnsupported) const
1017*89c4ff92SAndroid Build Coastguard Worker {
1018*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDepthwiseConvolutionWorkloadValidate,
1019*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1020*89c4ff92SAndroid Build Coastguard Worker                                    input,
1021*89c4ff92SAndroid Build Coastguard Worker                                    output,
1022*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
1023*89c4ff92SAndroid Build Coastguard Worker                                    weights,
1024*89c4ff92SAndroid Build Coastguard Worker                                    biases,
1025*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
1026*89c4ff92SAndroid Build Coastguard Worker }
1027*89c4ff92SAndroid Build Coastguard Worker 
IsElementwiseUnarySupported(const TensorInfo & input,const TensorInfo & output,const ElementwiseUnaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1028*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsElementwiseUnarySupported(const TensorInfo& input,
1029*89c4ff92SAndroid Build Coastguard Worker                                                    const TensorInfo& output,
1030*89c4ff92SAndroid Build Coastguard Worker                                                    const ElementwiseUnaryDescriptor& descriptor,
1031*89c4ff92SAndroid Build Coastguard Worker                                                    Optional<std::string&> reasonIfUnsupported) const
1032*89c4ff92SAndroid Build Coastguard Worker {
1033*89c4ff92SAndroid Build Coastguard Worker     switch(descriptor.m_Operation)
1034*89c4ff92SAndroid Build Coastguard Worker     {
1035*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Abs:
1036*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonAbsWorkloadValidate,
1037*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1038*89c4ff92SAndroid Build Coastguard Worker                                            input,
1039*89c4ff92SAndroid Build Coastguard Worker                                            output);
1040*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Exp:
1041*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonExpWorkloadValidate,
1042*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1043*89c4ff92SAndroid Build Coastguard Worker                                            input,
1044*89c4ff92SAndroid Build Coastguard Worker                                            output);
1045*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::LogicalNot:
1046*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLogicalNotWorkloadValidate,
1047*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1048*89c4ff92SAndroid Build Coastguard Worker                                            input,
1049*89c4ff92SAndroid Build Coastguard Worker                                            output);
1050*89c4ff92SAndroid Build Coastguard Worker        case UnaryOperation::Log:
1051*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLogWorkloadValidate,
1052*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1053*89c4ff92SAndroid Build Coastguard Worker                                            input,
1054*89c4ff92SAndroid Build Coastguard Worker                                            output);
1055*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Neg:
1056*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonNegWorkloadValidate,
1057*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1058*89c4ff92SAndroid Build Coastguard Worker                                            input,
1059*89c4ff92SAndroid Build Coastguard Worker                                            output);
1060*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Rsqrt:
1061*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonRsqrtWorkloadValidate,
1062*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1063*89c4ff92SAndroid Build Coastguard Worker                                            input,
1064*89c4ff92SAndroid Build Coastguard Worker                                            output);
1065*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Sin:
1066*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSinWorkloadValidate,
1067*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1068*89c4ff92SAndroid Build Coastguard Worker                                            input,
1069*89c4ff92SAndroid Build Coastguard Worker                                            output);
1070*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Sqrt:
1071*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSqrtWorkloadValidate,
1072*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1073*89c4ff92SAndroid Build Coastguard Worker                                            input,
1074*89c4ff92SAndroid Build Coastguard Worker                                            output);
1075*89c4ff92SAndroid Build Coastguard Worker         default:
1076*89c4ff92SAndroid Build Coastguard Worker             return false;
1077*89c4ff92SAndroid Build Coastguard Worker     }
1078*89c4ff92SAndroid Build Coastguard Worker }
1079*89c4ff92SAndroid Build Coastguard Worker 
IsFillSupported(const TensorInfo & input,const TensorInfo & output,const FillDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1080*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsFillSupported(const TensorInfo& input,
1081*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
1082*89c4ff92SAndroid Build Coastguard Worker                                        const FillDescriptor& descriptor,
1083*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
1084*89c4ff92SAndroid Build Coastguard Worker {
1085*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(input);
1086*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(output);
1087*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(descriptor);
1088*89c4ff92SAndroid Build Coastguard Worker 
1089*89c4ff92SAndroid Build Coastguard Worker     return IsNeonBackendSupported(reasonIfUnsupported);
1090*89c4ff92SAndroid Build Coastguard Worker }
1091*89c4ff92SAndroid Build Coastguard Worker 
IsFloorSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1092*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsFloorSupported(const TensorInfo& input,
1093*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
1094*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
1095*89c4ff92SAndroid Build Coastguard Worker {
1096*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(output);
1097*89c4ff92SAndroid Build Coastguard Worker     return IsNeonBackendSupported(reasonIfUnsupported) &&
1098*89c4ff92SAndroid Build Coastguard Worker            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
1099*89c4ff92SAndroid Build Coastguard Worker                                          input.GetDataType(),
1100*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFuncF16<>,
1101*89c4ff92SAndroid Build Coastguard Worker                                          &TrueFunc<>,
1102*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFuncU8<>,
1103*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFuncI32<>,
1104*89c4ff92SAndroid Build Coastguard Worker                                          &FalseFuncU8<>);
1105*89c4ff92SAndroid Build Coastguard Worker }
1106*89c4ff92SAndroid Build Coastguard Worker 
IsFullyConnectedSupported(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const TensorInfo & biases,const FullyConnectedDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1107*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
1108*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
1109*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& weights,
1110*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& biases,
1111*89c4ff92SAndroid Build Coastguard Worker                                                  const FullyConnectedDescriptor& descriptor,
1112*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
1113*89c4ff92SAndroid Build Coastguard Worker {
1114*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonFullyConnectedWorkloadValidate,
1115*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1116*89c4ff92SAndroid Build Coastguard Worker                                    input,
1117*89c4ff92SAndroid Build Coastguard Worker                                    output,
1118*89c4ff92SAndroid Build Coastguard Worker                                    weights,
1119*89c4ff92SAndroid Build Coastguard Worker                                    biases,
1120*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
1121*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
1122*89c4ff92SAndroid Build Coastguard Worker }
1123*89c4ff92SAndroid Build Coastguard Worker 
IsGatherSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const GatherDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1124*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsGatherSupported(const TensorInfo& input0,
1125*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& input1,
1126*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
1127*89c4ff92SAndroid Build Coastguard Worker                                          const GatherDescriptor& descriptor,
1128*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1129*89c4ff92SAndroid Build Coastguard Worker {
1130*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGatherWorkloadValidate,
1131*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1132*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1133*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1134*89c4ff92SAndroid Build Coastguard Worker                                    output,
1135*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1136*89c4ff92SAndroid Build Coastguard Worker }
1137*89c4ff92SAndroid Build Coastguard Worker 
IsGatherNdSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1138*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsGatherNdSupported(const TensorInfo& input0,
1139*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& input1,
1140*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
1141*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
1142*89c4ff92SAndroid Build Coastguard Worker {
1143*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGatherNdWorkloadValidate,
1144*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1145*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1146*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1147*89c4ff92SAndroid Build Coastguard Worker                                    output);
1148*89c4ff92SAndroid Build Coastguard Worker }
1149*89c4ff92SAndroid Build Coastguard Worker 
IsInputSupported(const TensorInfo & input,Optional<std::string &> reasonIfUnsupported) const1150*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsInputSupported(const TensorInfo& input,
1151*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
1152*89c4ff92SAndroid Build Coastguard Worker {
1153*89c4ff92SAndroid Build Coastguard Worker     return IsNeonBackendSupported(reasonIfUnsupported, input);
1154*89c4ff92SAndroid Build Coastguard Worker }
1155*89c4ff92SAndroid Build Coastguard Worker 
IsInstanceNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const InstanceNormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1156*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
1157*89c4ff92SAndroid Build Coastguard Worker                                                         const TensorInfo& output,
1158*89c4ff92SAndroid Build Coastguard Worker                                                         const InstanceNormalizationDescriptor& descriptor,
1159*89c4ff92SAndroid Build Coastguard Worker                                                         Optional<std::string&> reasonIfUnsupported) const
1160*89c4ff92SAndroid Build Coastguard Worker {
1161*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonInstanceNormalizationWorkloadValidate,
1162*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1163*89c4ff92SAndroid Build Coastguard Worker                                    input,
1164*89c4ff92SAndroid Build Coastguard Worker                                    output,
1165*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1166*89c4ff92SAndroid Build Coastguard Worker }
1167*89c4ff92SAndroid Build Coastguard Worker 
IsL2NormalizationSupported(const TensorInfo & input,const TensorInfo & output,const L2NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1168*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
1169*89c4ff92SAndroid Build Coastguard Worker                                                   const TensorInfo& output,
1170*89c4ff92SAndroid Build Coastguard Worker                                                   const L2NormalizationDescriptor& descriptor,
1171*89c4ff92SAndroid Build Coastguard Worker                                                   Optional<std::string&> reasonIfUnsupported) const
1172*89c4ff92SAndroid Build Coastguard Worker {
1173*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1174*89c4ff92SAndroid Build Coastguard Worker }
1175*89c4ff92SAndroid Build Coastguard Worker 
IsLogicalBinarySupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,const LogicalBinaryDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1176*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
1177*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& input1,
1178*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
1179*89c4ff92SAndroid Build Coastguard Worker                                                 const LogicalBinaryDescriptor& descriptor,
1180*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
1181*89c4ff92SAndroid Build Coastguard Worker {
1182*89c4ff92SAndroid Build Coastguard Worker     switch(descriptor.m_Operation)
1183*89c4ff92SAndroid Build Coastguard Worker     {
1184*89c4ff92SAndroid Build Coastguard Worker         case LogicalBinaryOperation::LogicalAnd:
1185*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLogicalAndWorkloadValidate,
1186*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1187*89c4ff92SAndroid Build Coastguard Worker                                            input0,
1188*89c4ff92SAndroid Build Coastguard Worker                                            input1,
1189*89c4ff92SAndroid Build Coastguard Worker                                            output);
1190*89c4ff92SAndroid Build Coastguard Worker         case LogicalBinaryOperation::LogicalOr:
1191*89c4ff92SAndroid Build Coastguard Worker             FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLogicalOrWorkloadValidate,
1192*89c4ff92SAndroid Build Coastguard Worker                                            reasonIfUnsupported,
1193*89c4ff92SAndroid Build Coastguard Worker                                            input0,
1194*89c4ff92SAndroid Build Coastguard Worker                                            input1,
1195*89c4ff92SAndroid Build Coastguard Worker                                            output);
1196*89c4ff92SAndroid Build Coastguard Worker         default:
1197*89c4ff92SAndroid Build Coastguard Worker             return false;
1198*89c4ff92SAndroid Build Coastguard Worker     }
1199*89c4ff92SAndroid Build Coastguard Worker }
1200*89c4ff92SAndroid Build Coastguard Worker 
IsLogSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const LogSoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1201*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
1202*89c4ff92SAndroid Build Coastguard Worker                                              const TensorInfo& output,
1203*89c4ff92SAndroid Build Coastguard Worker                                              const LogSoftmaxDescriptor& descriptor,
1204*89c4ff92SAndroid Build Coastguard Worker                                              Optional<std::string&> reasonIfUnsupported) const
1205*89c4ff92SAndroid Build Coastguard Worker {
1206*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLogSoftmaxWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1207*89c4ff92SAndroid Build Coastguard Worker }
1208*89c4ff92SAndroid Build Coastguard Worker 
IsLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & scratchBuffer,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const LstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1209*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsLstmSupported(const TensorInfo& input,
1210*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& outputStateIn,
1211*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& cellStateIn,
1212*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& scratchBuffer,
1213*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& outputStateOut,
1214*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& cellStateOut,
1215*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
1216*89c4ff92SAndroid Build Coastguard Worker                                        const LstmDescriptor& descriptor,
1217*89c4ff92SAndroid Build Coastguard Worker                                        const LstmInputParamsInfo& paramsInfo,
1218*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
1219*89c4ff92SAndroid Build Coastguard Worker {
1220*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonLstmFloatWorkloadValidate,
1221*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1222*89c4ff92SAndroid Build Coastguard Worker                                    input,
1223*89c4ff92SAndroid Build Coastguard Worker                                    outputStateIn,
1224*89c4ff92SAndroid Build Coastguard Worker                                    cellStateIn,
1225*89c4ff92SAndroid Build Coastguard Worker                                    scratchBuffer,
1226*89c4ff92SAndroid Build Coastguard Worker                                    outputStateOut,
1227*89c4ff92SAndroid Build Coastguard Worker                                    cellStateOut,
1228*89c4ff92SAndroid Build Coastguard Worker                                    output,
1229*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
1230*89c4ff92SAndroid Build Coastguard Worker                                    paramsInfo);
1231*89c4ff92SAndroid Build Coastguard Worker }
1232*89c4ff92SAndroid Build Coastguard Worker 
IsMaximumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1233*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsMaximumSupported(const TensorInfo& input0,
1234*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& input1,
1235*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1236*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1237*89c4ff92SAndroid Build Coastguard Worker {
1238*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMaximumWorkloadValidate,
1239*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1240*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1241*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1242*89c4ff92SAndroid Build Coastguard Worker                                    output);
1243*89c4ff92SAndroid Build Coastguard Worker }
1244*89c4ff92SAndroid Build Coastguard Worker 
IsMeanSupported(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1245*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsMeanSupported(const TensorInfo& input,
1246*89c4ff92SAndroid Build Coastguard Worker                                        const TensorInfo& output,
1247*89c4ff92SAndroid Build Coastguard Worker                                        const MeanDescriptor& descriptor,
1248*89c4ff92SAndroid Build Coastguard Worker                                        Optional<std::string&> reasonIfUnsupported) const
1249*89c4ff92SAndroid Build Coastguard Worker {
1250*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMeanWorkloadValidate,
1251*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1252*89c4ff92SAndroid Build Coastguard Worker                                    input,
1253*89c4ff92SAndroid Build Coastguard Worker                                    output,
1254*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1255*89c4ff92SAndroid Build Coastguard Worker }
1256*89c4ff92SAndroid Build Coastguard Worker 
IsMinimumSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1257*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsMinimumSupported(const TensorInfo& input0,
1258*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& input1,
1259*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1260*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1261*89c4ff92SAndroid Build Coastguard Worker {
1262*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMinimumWorkloadValidate,
1263*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1264*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1265*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1266*89c4ff92SAndroid Build Coastguard Worker                                    output);
1267*89c4ff92SAndroid Build Coastguard Worker }
1268*89c4ff92SAndroid Build Coastguard Worker 
IsMultiplicationSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1269*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsMultiplicationSupported(const TensorInfo& input0,
1270*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& input1,
1271*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
1272*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
1273*89c4ff92SAndroid Build Coastguard Worker {
1274*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonMultiplicationWorkloadValidate,
1275*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1276*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1277*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1278*89c4ff92SAndroid Build Coastguard Worker                                    output,
1279*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
1280*89c4ff92SAndroid Build Coastguard Worker }
1281*89c4ff92SAndroid Build Coastguard Worker 
IsDivisionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1282*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsDivisionSupported(const TensorInfo& input0,
1283*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& input1,
1284*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
1285*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
1286*89c4ff92SAndroid Build Coastguard Worker {
1287*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonDivisionWorkloadValidate,
1288*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1289*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1290*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1291*89c4ff92SAndroid Build Coastguard Worker                                    output,
1292*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
1293*89c4ff92SAndroid Build Coastguard Worker }
1294*89c4ff92SAndroid Build Coastguard Worker 
IsNormalizationSupported(const TensorInfo & input,const TensorInfo & output,const NormalizationDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1295*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsNormalizationSupported(const TensorInfo& input,
1296*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& output,
1297*89c4ff92SAndroid Build Coastguard Worker                                                 const NormalizationDescriptor& descriptor,
1298*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
1299*89c4ff92SAndroid Build Coastguard Worker {
1300*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonNormalizationWorkloadValidate,
1301*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1302*89c4ff92SAndroid Build Coastguard Worker                                    input,
1303*89c4ff92SAndroid Build Coastguard Worker                                    output,
1304*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1305*89c4ff92SAndroid Build Coastguard Worker }
1306*89c4ff92SAndroid Build Coastguard Worker 
IsOutputSupported(const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1307*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsOutputSupported(const TensorInfo& output,
1308*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1309*89c4ff92SAndroid Build Coastguard Worker {
1310*89c4ff92SAndroid Build Coastguard Worker     return IsNeonBackendSupported(reasonIfUnsupported, output);
1311*89c4ff92SAndroid Build Coastguard Worker }
1312*89c4ff92SAndroid Build Coastguard Worker 
IsPadSupported(const TensorInfo & input,const TensorInfo & output,const PadDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1313*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsPadSupported(const TensorInfo& input,
1314*89c4ff92SAndroid Build Coastguard Worker                                       const TensorInfo& output,
1315*89c4ff92SAndroid Build Coastguard Worker                                       const PadDescriptor& descriptor,
1316*89c4ff92SAndroid Build Coastguard Worker                                       Optional<std::string&> reasonIfUnsupported) const
1317*89c4ff92SAndroid Build Coastguard Worker {
1318*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPadWorkloadValidate,
1319*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1320*89c4ff92SAndroid Build Coastguard Worker                                    input,
1321*89c4ff92SAndroid Build Coastguard Worker                                    output,
1322*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1323*89c4ff92SAndroid Build Coastguard Worker }
1324*89c4ff92SAndroid Build Coastguard Worker 
IsPermuteSupported(const TensorInfo & input,const TensorInfo & output,const PermuteDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1325*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsPermuteSupported(const TensorInfo& input,
1326*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1327*89c4ff92SAndroid Build Coastguard Worker                                           const PermuteDescriptor& descriptor,
1328*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1329*89c4ff92SAndroid Build Coastguard Worker {
1330*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPermuteWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1331*89c4ff92SAndroid Build Coastguard Worker }
1332*89c4ff92SAndroid Build Coastguard Worker 
IsPooling2dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling2dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1333*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsPooling2dSupported(const TensorInfo& input,
1334*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
1335*89c4ff92SAndroid Build Coastguard Worker                                             const Pooling2dDescriptor& descriptor,
1336*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
1337*89c4ff92SAndroid Build Coastguard Worker {
1338*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1339*89c4ff92SAndroid Build Coastguard Worker }
1340*89c4ff92SAndroid Build Coastguard Worker 
IsPooling3dSupported(const TensorInfo & input,const TensorInfo & output,const Pooling3dDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1341*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsPooling3dSupported(const TensorInfo& input,
1342*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
1343*89c4ff92SAndroid Build Coastguard Worker                                             const Pooling3dDescriptor& descriptor,
1344*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
1345*89c4ff92SAndroid Build Coastguard Worker {
1346*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPooling3dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1347*89c4ff92SAndroid Build Coastguard Worker }
1348*89c4ff92SAndroid Build Coastguard Worker 
IsPreluSupported(const armnn::TensorInfo & input,const armnn::TensorInfo & alpha,const armnn::TensorInfo & output,armnn::Optional<std::string &> reasonIfUnsupported) const1349*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsPreluSupported(const armnn::TensorInfo &input,
1350*89c4ff92SAndroid Build Coastguard Worker                                         const armnn::TensorInfo &alpha,
1351*89c4ff92SAndroid Build Coastguard Worker                                         const armnn::TensorInfo &output,
1352*89c4ff92SAndroid Build Coastguard Worker                                         armnn::Optional<std::string &> reasonIfUnsupported) const
1353*89c4ff92SAndroid Build Coastguard Worker {
1354*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output);
1355*89c4ff92SAndroid Build Coastguard Worker }
1356*89c4ff92SAndroid Build Coastguard Worker 
IsQLstmSupported(const TensorInfo & input,const TensorInfo & previousOutputIn,const TensorInfo & previousCellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const QLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1357*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsQLstmSupported(const TensorInfo& input,
1358*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& previousOutputIn,
1359*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& previousCellStateIn,
1360*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& outputStateOut,
1361*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& cellStateOut,
1362*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
1363*89c4ff92SAndroid Build Coastguard Worker                                         const QLstmDescriptor& descriptor,
1364*89c4ff92SAndroid Build Coastguard Worker                                         const LstmInputParamsInfo& paramsInfo,
1365*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
1366*89c4ff92SAndroid Build Coastguard Worker {
1367*89c4ff92SAndroid Build Coastguard Worker     // Check required here in order to pass IsLayerSupported for datatypes tests
1368*89c4ff92SAndroid Build Coastguard Worker     if (input.GetDataType()               == armnn::DataType::QAsymmS8 &&
1369*89c4ff92SAndroid Build Coastguard Worker         previousOutputIn.GetDataType()    == armnn::DataType::QAsymmS8 &&
1370*89c4ff92SAndroid Build Coastguard Worker         previousCellStateIn.GetDataType() == armnn::DataType::QSymmS16 &&
1371*89c4ff92SAndroid Build Coastguard Worker         outputStateOut.GetDataType()      == armnn::DataType::QAsymmS8 &&
1372*89c4ff92SAndroid Build Coastguard Worker         cellStateOut.GetDataType()        == armnn::DataType::QSymmS16 &&
1373*89c4ff92SAndroid Build Coastguard Worker         output.GetDataType()              == armnn::DataType::QAsymmS8)
1374*89c4ff92SAndroid Build Coastguard Worker     {
1375*89c4ff92SAndroid Build Coastguard Worker         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQLstmWorkloadValidate,
1376*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported,
1377*89c4ff92SAndroid Build Coastguard Worker                                        input,
1378*89c4ff92SAndroid Build Coastguard Worker                                        previousCellStateIn,
1379*89c4ff92SAndroid Build Coastguard Worker                                        previousOutputIn,
1380*89c4ff92SAndroid Build Coastguard Worker                                        cellStateOut,
1381*89c4ff92SAndroid Build Coastguard Worker                                        outputStateOut,
1382*89c4ff92SAndroid Build Coastguard Worker                                        output,
1383*89c4ff92SAndroid Build Coastguard Worker                                        descriptor,
1384*89c4ff92SAndroid Build Coastguard Worker                                        paramsInfo);
1385*89c4ff92SAndroid Build Coastguard Worker     }
1386*89c4ff92SAndroid Build Coastguard Worker     else
1387*89c4ff92SAndroid Build Coastguard Worker     {
1388*89c4ff92SAndroid Build Coastguard Worker         return false;
1389*89c4ff92SAndroid Build Coastguard Worker     }
1390*89c4ff92SAndroid Build Coastguard Worker }
1391*89c4ff92SAndroid Build Coastguard Worker 
IsQuantizeSupported(const TensorInfo & input,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1392*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsQuantizeSupported(const TensorInfo& input,
1393*89c4ff92SAndroid Build Coastguard Worker                                            const TensorInfo& output,
1394*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
1395*89c4ff92SAndroid Build Coastguard Worker {
1396*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQuantizeWorkloadValidate,
1397*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1398*89c4ff92SAndroid Build Coastguard Worker                                    input,
1399*89c4ff92SAndroid Build Coastguard Worker                                    output);
1400*89c4ff92SAndroid Build Coastguard Worker }
1401*89c4ff92SAndroid Build Coastguard Worker 
IsQuantizedLstmSupported(const TensorInfo & input,const TensorInfo & cellStateIn,const TensorInfo & outputStateIn,const TensorInfo & cellStateOut,const TensorInfo & outputStateOut,const QuantizedLstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1402*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input,
1403*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& cellStateIn,
1404*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& outputStateIn,
1405*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& cellStateOut,
1406*89c4ff92SAndroid Build Coastguard Worker                                                 const TensorInfo& outputStateOut,
1407*89c4ff92SAndroid Build Coastguard Worker                                                 const QuantizedLstmInputParamsInfo& paramsInfo,
1408*89c4ff92SAndroid Build Coastguard Worker                                                 Optional<std::string&> reasonIfUnsupported) const
1409*89c4ff92SAndroid Build Coastguard Worker {
1410*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonQuantizedLstmWorkloadValidate,
1411*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1412*89c4ff92SAndroid Build Coastguard Worker                                    input,
1413*89c4ff92SAndroid Build Coastguard Worker                                    cellStateIn,
1414*89c4ff92SAndroid Build Coastguard Worker                                    outputStateIn,
1415*89c4ff92SAndroid Build Coastguard Worker                                    cellStateOut,
1416*89c4ff92SAndroid Build Coastguard Worker                                    outputStateOut,
1417*89c4ff92SAndroid Build Coastguard Worker                                    paramsInfo);
1418*89c4ff92SAndroid Build Coastguard Worker }
1419*89c4ff92SAndroid Build Coastguard Worker 
IsReduceSupported(const TensorInfo & input,const TensorInfo & output,const ReduceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1420*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsReduceSupported(const TensorInfo& input,
1421*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
1422*89c4ff92SAndroid Build Coastguard Worker                                          const ReduceDescriptor& descriptor,
1423*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1424*89c4ff92SAndroid Build Coastguard Worker {
1425*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonReduceWorkloadValidate,
1426*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1427*89c4ff92SAndroid Build Coastguard Worker                                    input,
1428*89c4ff92SAndroid Build Coastguard Worker                                    output,
1429*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1430*89c4ff92SAndroid Build Coastguard Worker }
1431*89c4ff92SAndroid Build Coastguard Worker 
IsReshapeSupported(const TensorInfo & input,const TensorInfo & output,const ReshapeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1432*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsReshapeSupported(const TensorInfo& input,
1433*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1434*89c4ff92SAndroid Build Coastguard Worker                                           const ReshapeDescriptor& descriptor,
1435*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1436*89c4ff92SAndroid Build Coastguard Worker {
1437*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(descriptor);
1438*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonReshapeWorkloadValidate,
1439*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1440*89c4ff92SAndroid Build Coastguard Worker                                    input,
1441*89c4ff92SAndroid Build Coastguard Worker                                    output);
1442*89c4ff92SAndroid Build Coastguard Worker }
1443*89c4ff92SAndroid Build Coastguard Worker 
IsResizeSupported(const TensorInfo & input,const TensorInfo & output,const ResizeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1444*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsResizeSupported(const TensorInfo& input,
1445*89c4ff92SAndroid Build Coastguard Worker                                          const TensorInfo& output,
1446*89c4ff92SAndroid Build Coastguard Worker                                          const ResizeDescriptor& descriptor,
1447*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::string&> reasonIfUnsupported) const
1448*89c4ff92SAndroid Build Coastguard Worker {
1449*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonResizeWorkloadValidate,
1450*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1451*89c4ff92SAndroid Build Coastguard Worker                                    input,
1452*89c4ff92SAndroid Build Coastguard Worker                                    output,
1453*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1454*89c4ff92SAndroid Build Coastguard Worker }
1455*89c4ff92SAndroid Build Coastguard Worker 
IsSliceSupported(const TensorInfo & input,const TensorInfo & output,const SliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1456*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSliceSupported(const TensorInfo& input,
1457*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
1458*89c4ff92SAndroid Build Coastguard Worker                                         const SliceDescriptor& descriptor,
1459*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
1460*89c4ff92SAndroid Build Coastguard Worker {
1461*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSliceWorkloadValidate,
1462*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1463*89c4ff92SAndroid Build Coastguard Worker                                    input,
1464*89c4ff92SAndroid Build Coastguard Worker                                    output,
1465*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1466*89c4ff92SAndroid Build Coastguard Worker }
1467*89c4ff92SAndroid Build Coastguard Worker 
IsSoftmaxSupported(const TensorInfo & input,const TensorInfo & output,const SoftmaxDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1468*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
1469*89c4ff92SAndroid Build Coastguard Worker                                           const TensorInfo& output,
1470*89c4ff92SAndroid Build Coastguard Worker                                           const SoftmaxDescriptor& descriptor,
1471*89c4ff92SAndroid Build Coastguard Worker                                           Optional<std::string&> reasonIfUnsupported) const
1472*89c4ff92SAndroid Build Coastguard Worker {
1473*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSoftmaxWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1474*89c4ff92SAndroid Build Coastguard Worker }
1475*89c4ff92SAndroid Build Coastguard Worker 
IsSpaceToBatchNdSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1476*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
1477*89c4ff92SAndroid Build Coastguard Worker                                                  const TensorInfo& output,
1478*89c4ff92SAndroid Build Coastguard Worker                                                  const SpaceToBatchNdDescriptor& descriptor,
1479*89c4ff92SAndroid Build Coastguard Worker                                                  Optional<std::string&> reasonIfUnsupported) const
1480*89c4ff92SAndroid Build Coastguard Worker {
1481*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSpaceToBatchNdWorkloadValidate,
1482*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1483*89c4ff92SAndroid Build Coastguard Worker                                    input,
1484*89c4ff92SAndroid Build Coastguard Worker                                    output,
1485*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1486*89c4ff92SAndroid Build Coastguard Worker }
1487*89c4ff92SAndroid Build Coastguard Worker 
IsSpaceToDepthSupported(const TensorInfo & input,const TensorInfo & output,const SpaceToDepthDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1488*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
1489*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
1490*89c4ff92SAndroid Build Coastguard Worker                                                const SpaceToDepthDescriptor& descriptor,
1491*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
1492*89c4ff92SAndroid Build Coastguard Worker {
1493*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSpaceToDepthWorkloadValidate,
1494*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1495*89c4ff92SAndroid Build Coastguard Worker                                    input,
1496*89c4ff92SAndroid Build Coastguard Worker                                    output,
1497*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1498*89c4ff92SAndroid Build Coastguard Worker }
1499*89c4ff92SAndroid Build Coastguard Worker 
IsSplitterSupported(const TensorInfo & input,const std::vector<std::reference_wrapper<TensorInfo>> & outputs,const ViewsDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1500*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSplitterSupported(const TensorInfo& input,
1501*89c4ff92SAndroid Build Coastguard Worker                                            const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
1502*89c4ff92SAndroid Build Coastguard Worker                                            const ViewsDescriptor& descriptor,
1503*89c4ff92SAndroid Build Coastguard Worker                                            Optional<std::string&> reasonIfUnsupported) const
1504*89c4ff92SAndroid Build Coastguard Worker {
1505*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTENEON_ENABLED)
1506*89c4ff92SAndroid Build Coastguard Worker     // Split along the last dimension, cannot use sub-tensors
1507*89c4ff92SAndroid Build Coastguard Worker     // as width and height of the sub-tensors do not match
1508*89c4ff92SAndroid Build Coastguard Worker     // the width and height of the parent tensor
1509*89c4ff92SAndroid Build Coastguard Worker     // in case of input with more than 2D.
1510*89c4ff92SAndroid Build Coastguard Worker     std::set<unsigned int> splitAxis = ComputeSplitAxis(descriptor, input.GetShape());
1511*89c4ff92SAndroid Build Coastguard Worker     if (descriptor.GetNumDimensions() > 2 && splitAxis.size() == 1 &&
1512*89c4ff92SAndroid Build Coastguard Worker         *splitAxis.begin() == descriptor.GetNumDimensions() - 1 )
1513*89c4ff92SAndroid Build Coastguard Worker     {
1514*89c4ff92SAndroid Build Coastguard Worker         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSplitterWorkloadValidate,
1515*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported,
1516*89c4ff92SAndroid Build Coastguard Worker                                        input,
1517*89c4ff92SAndroid Build Coastguard Worker                                        outputs,
1518*89c4ff92SAndroid Build Coastguard Worker                                        *splitAxis.begin());
1519*89c4ff92SAndroid Build Coastguard Worker     }
1520*89c4ff92SAndroid Build Coastguard Worker #endif
1521*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(descriptor);
1522*89c4ff92SAndroid Build Coastguard Worker     for (auto output : outputs)
1523*89c4ff92SAndroid Build Coastguard Worker     {
1524*89c4ff92SAndroid Build Coastguard Worker         if (!input.IsTypeSpaceMatch(output)) // Cannot use sub-tensors if the types are not same space
1525*89c4ff92SAndroid Build Coastguard Worker         {
1526*89c4ff92SAndroid Build Coastguard Worker             SetValueChecked(reasonIfUnsupported, "Neon Splitter: Types and quantization parameters must match.");
1527*89c4ff92SAndroid Build Coastguard Worker             return false;
1528*89c4ff92SAndroid Build Coastguard Worker         }
1529*89c4ff92SAndroid Build Coastguard Worker     }
1530*89c4ff92SAndroid Build Coastguard Worker     return true;
1531*89c4ff92SAndroid Build Coastguard Worker }
1532*89c4ff92SAndroid Build Coastguard Worker 
IsStackSupported(const std::vector<const TensorInfo * > & inputs,const TensorInfo & output,const StackDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1533*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsStackSupported(const std::vector<const TensorInfo*>& inputs,
1534*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& output,
1535*89c4ff92SAndroid Build Coastguard Worker                                         const StackDescriptor& descriptor,
1536*89c4ff92SAndroid Build Coastguard Worker                                         Optional<std::string&> reasonIfUnsupported) const
1537*89c4ff92SAndroid Build Coastguard Worker {
1538*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonStackWorkloadValidate,
1539*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1540*89c4ff92SAndroid Build Coastguard Worker                                    inputs,
1541*89c4ff92SAndroid Build Coastguard Worker                                    output,
1542*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1543*89c4ff92SAndroid Build Coastguard Worker }
1544*89c4ff92SAndroid Build Coastguard Worker 
IsStridedSliceSupported(const TensorInfo & input,const TensorInfo & output,const StridedSliceDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1545*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsStridedSliceSupported(const TensorInfo& input,
1546*89c4ff92SAndroid Build Coastguard Worker                                                const TensorInfo& output,
1547*89c4ff92SAndroid Build Coastguard Worker                                                const StridedSliceDescriptor& descriptor,
1548*89c4ff92SAndroid Build Coastguard Worker                                                Optional<std::string&> reasonIfUnsupported) const
1549*89c4ff92SAndroid Build Coastguard Worker {
1550*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonStridedSliceWorkloadValidate,
1551*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1552*89c4ff92SAndroid Build Coastguard Worker                                    input,
1553*89c4ff92SAndroid Build Coastguard Worker                                    output,
1554*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
1555*89c4ff92SAndroid Build Coastguard Worker }
1556*89c4ff92SAndroid Build Coastguard Worker 
IsSubtractionSupported(const TensorInfo & input0,const TensorInfo & input1,const TensorInfo & output,Optional<std::string &> reasonIfUnsupported) const1557*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsSubtractionSupported(const TensorInfo& input0,
1558*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& input1,
1559*89c4ff92SAndroid Build Coastguard Worker                                               const TensorInfo& output,
1560*89c4ff92SAndroid Build Coastguard Worker                                               Optional<std::string&> reasonIfUnsupported) const
1561*89c4ff92SAndroid Build Coastguard Worker {
1562*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonSubtractionWorkloadValidate,
1563*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1564*89c4ff92SAndroid Build Coastguard Worker                                    input0,
1565*89c4ff92SAndroid Build Coastguard Worker                                    input1,
1566*89c4ff92SAndroid Build Coastguard Worker                                    output,
1567*89c4ff92SAndroid Build Coastguard Worker                                    nullptr);
1568*89c4ff92SAndroid Build Coastguard Worker }
1569*89c4ff92SAndroid Build Coastguard Worker 
IsTransposeConvolution2dSupported(const TensorInfo & input,const TensorInfo & output,const TransposeConvolution2dDescriptor & descriptor,const TensorInfo & weights,const Optional<TensorInfo> & biases,Optional<std::string &> reasonIfUnsupported) const1570*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input,
1571*89c4ff92SAndroid Build Coastguard Worker                                                          const TensorInfo& output,
1572*89c4ff92SAndroid Build Coastguard Worker                                                          const TransposeConvolution2dDescriptor& descriptor,
1573*89c4ff92SAndroid Build Coastguard Worker                                                          const TensorInfo& weights,
1574*89c4ff92SAndroid Build Coastguard Worker                                                          const Optional<TensorInfo>& biases,
1575*89c4ff92SAndroid Build Coastguard Worker                                                          Optional<std::string&> reasonIfUnsupported) const
1576*89c4ff92SAndroid Build Coastguard Worker {
1577*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonTransposeConvolution2dWorkloadValidate,
1578*89c4ff92SAndroid Build Coastguard Worker                                    reasonIfUnsupported,
1579*89c4ff92SAndroid Build Coastguard Worker                                    input,
1580*89c4ff92SAndroid Build Coastguard Worker                                    output,
1581*89c4ff92SAndroid Build Coastguard Worker                                    descriptor,
1582*89c4ff92SAndroid Build Coastguard Worker                                    weights,
1583*89c4ff92SAndroid Build Coastguard Worker                                    biases);
1584*89c4ff92SAndroid Build Coastguard Worker }
1585*89c4ff92SAndroid Build Coastguard Worker 
IsTransposeSupported(const TensorInfo & input,const TensorInfo & output,const TransposeDescriptor & descriptor,Optional<std::string &> reasonIfUnsupported) const1586*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsTransposeSupported(const TensorInfo& input,
1587*89c4ff92SAndroid Build Coastguard Worker                                             const TensorInfo& output,
1588*89c4ff92SAndroid Build Coastguard Worker                                             const TransposeDescriptor& descriptor,
1589*89c4ff92SAndroid Build Coastguard Worker                                             Optional<std::string&> reasonIfUnsupported) const
1590*89c4ff92SAndroid Build Coastguard Worker {
1591*89c4ff92SAndroid Build Coastguard Worker     FORWARD_WORKLOAD_VALIDATE_FUNC(NeonTransposeWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
1592*89c4ff92SAndroid Build Coastguard Worker }
1593*89c4ff92SAndroid Build Coastguard Worker 
IsUnidirectionalSequenceLstmSupported(const TensorInfo & input,const TensorInfo & outputStateIn,const TensorInfo & cellStateIn,const TensorInfo & outputStateOut,const TensorInfo & cellStateOut,const TensorInfo & output,const UnidirectionalSequenceLstmDescriptor & descriptor,const LstmInputParamsInfo & paramsInfo,Optional<std::string &> reasonIfUnsupported) const1594*89c4ff92SAndroid Build Coastguard Worker bool NeonLayerSupport::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input,
1595*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& outputStateIn,
1596*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& cellStateIn,
1597*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& outputStateOut,
1598*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& cellStateOut,
1599*89c4ff92SAndroid Build Coastguard Worker                                                              const TensorInfo& output,
1600*89c4ff92SAndroid Build Coastguard Worker                                                              const UnidirectionalSequenceLstmDescriptor& descriptor,
1601*89c4ff92SAndroid Build Coastguard Worker                                                              const LstmInputParamsInfo& paramsInfo,
1602*89c4ff92SAndroid Build Coastguard Worker                                                              Optional<std::string&> reasonIfUnsupported) const
1603*89c4ff92SAndroid Build Coastguard Worker {
1604*89c4ff92SAndroid Build Coastguard Worker     if (input.GetDataType() == armnn::DataType::QAsymmS8 &&
1605*89c4ff92SAndroid Build Coastguard Worker         outputStateIn.GetDataType() == armnn::DataType::QAsymmS8 &&
1606*89c4ff92SAndroid Build Coastguard Worker         cellStateIn.GetDataType() == armnn::DataType::QSymmS16 &&
1607*89c4ff92SAndroid Build Coastguard Worker         outputStateOut.GetDataType() == armnn::DataType::QAsymmS8 &&
1608*89c4ff92SAndroid Build Coastguard Worker         cellStateOut.GetDataType() == armnn::DataType::QSymmS16 &&
1609*89c4ff92SAndroid Build Coastguard Worker         output.GetDataType() == armnn::DataType::QAsymmS8)
1610*89c4ff92SAndroid Build Coastguard Worker     {
1611*89c4ff92SAndroid Build Coastguard Worker         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonUnidirectionalSequenceLstmWorkloadValidate,
1612*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported,
1613*89c4ff92SAndroid Build Coastguard Worker                                        input,
1614*89c4ff92SAndroid Build Coastguard Worker                                        outputStateIn,
1615*89c4ff92SAndroid Build Coastguard Worker                                        cellStateIn,
1616*89c4ff92SAndroid Build Coastguard Worker                                        outputStateOut,
1617*89c4ff92SAndroid Build Coastguard Worker                                        cellStateOut,
1618*89c4ff92SAndroid Build Coastguard Worker                                        output,
1619*89c4ff92SAndroid Build Coastguard Worker                                        descriptor,
1620*89c4ff92SAndroid Build Coastguard Worker                                        paramsInfo);
1621*89c4ff92SAndroid Build Coastguard Worker     }
1622*89c4ff92SAndroid Build Coastguard Worker     else
1623*89c4ff92SAndroid Build Coastguard Worker     {
1624*89c4ff92SAndroid Build Coastguard Worker         FORWARD_WORKLOAD_VALIDATE_FUNC(NeonUnidirectionalSequenceLstmFloatWorkloadValidate,
1625*89c4ff92SAndroid Build Coastguard Worker                                        reasonIfUnsupported,
1626*89c4ff92SAndroid Build Coastguard Worker                                        input,
1627*89c4ff92SAndroid Build Coastguard Worker                                        outputStateIn,
1628*89c4ff92SAndroid Build Coastguard Worker                                        cellStateIn,
1629*89c4ff92SAndroid Build Coastguard Worker                                        outputStateOut,
1630*89c4ff92SAndroid Build Coastguard Worker                                        cellStateOut,
1631*89c4ff92SAndroid Build Coastguard Worker                                        output,
1632*89c4ff92SAndroid Build Coastguard Worker                                        descriptor,
1633*89c4ff92SAndroid Build Coastguard Worker                                        paramsInfo);
1634*89c4ff92SAndroid Build Coastguard Worker     }
1635*89c4ff92SAndroid Build Coastguard Worker }
1636*89c4ff92SAndroid Build Coastguard Worker 
1637*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
1638