xref: /aosp_15_r20/external/armnn/shim/sl/canonical/ConversionUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ConversionUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker ///
10*89c4ff92SAndroid Build Coastguard Worker /// Helper classes
11*89c4ff92SAndroid Build Coastguard Worker ///
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
LayerInputHandle()16*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle::LayerInputHandle()
17*89c4ff92SAndroid Build Coastguard Worker     : m_OutputSlot(nullptr)
18*89c4ff92SAndroid Build Coastguard Worker     , m_Valid(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker 
LayerInputHandle(bool valid,armnn::IOutputSlot * outputSlot,armnn::TensorInfo tensorInfo)21*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle::LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo)
22*89c4ff92SAndroid Build Coastguard Worker     : m_OutputSlot(outputSlot)
23*89c4ff92SAndroid Build Coastguard Worker     , m_Valid(valid)
24*89c4ff92SAndroid Build Coastguard Worker     , m_TensorInfo(tensorInfo)
25*89c4ff92SAndroid Build Coastguard Worker {}
26*89c4ff92SAndroid Build Coastguard Worker 
IsValid() const27*89c4ff92SAndroid Build Coastguard Worker bool LayerInputHandle::IsValid() const
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     return m_Valid;
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker 
Connect(armnn::IInputSlot & inputSlot)32*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::Connect(armnn::IInputSlot& inputSlot)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(IsValid());
35*89c4ff92SAndroid Build Coastguard Worker     if (m_OutputSlot)
36*89c4ff92SAndroid Build Coastguard Worker     {
37*89c4ff92SAndroid Build Coastguard Worker         m_OutputSlot->Connect(inputSlot);
38*89c4ff92SAndroid Build Coastguard Worker     }
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker 
Disconnect(armnn::IInputSlot & inputSlot)41*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::Disconnect(armnn::IInputSlot& inputSlot)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(IsValid());
44*89c4ff92SAndroid Build Coastguard Worker     if (m_OutputSlot)
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         m_OutputSlot->Disconnect(inputSlot);
47*89c4ff92SAndroid Build Coastguard Worker     }
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker 
GetTensorInfo() const50*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& LayerInputHandle::GetTensorInfo() const
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     return m_TensorInfo;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker 
SanitizeQuantizationScale(LayerInputHandle & weight,LayerInputHandle & input)55*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::SanitizeQuantizationScale(LayerInputHandle& weight, LayerInputHandle& input)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker     if (m_OutputSlot)
58*89c4ff92SAndroid Build Coastguard Worker     {
59*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo weightInfo = weight.GetTensorInfo();
60*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo inputInfo = input.GetTensorInfo();
61*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo biasInfo = GetTensorInfo();
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker         SanitizeBiasQuantizationScale(biasInfo, weightInfo, inputInfo);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker         m_TensorInfo = biasInfo;
66*89c4ff92SAndroid Build Coastguard Worker         m_OutputSlot->SetTensorInfo(biasInfo);
67*89c4ff92SAndroid Build Coastguard Worker     }
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker 
GetOutputSlot() const70*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* LayerInputHandle::GetOutputSlot() const
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker     return m_OutputSlot;
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker 
ConstTensorPin(bool optional)75*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin::ConstTensorPin(bool optional)
76*89c4ff92SAndroid Build Coastguard Worker     : m_Optional(optional)
77*89c4ff92SAndroid Build Coastguard Worker {}
78*89c4ff92SAndroid Build Coastguard Worker 
ConstTensorPin(armnn::TensorInfo & tensorInfo,const void * valueStart,uint32_t numBytes,const armnn::PermutationVector & mappings)79*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo,
80*89c4ff92SAndroid Build Coastguard Worker                                const void* valueStart,
81*89c4ff92SAndroid Build Coastguard Worker                                uint32_t numBytes,
82*89c4ff92SAndroid Build Coastguard Worker                                const armnn::PermutationVector& mappings)
83*89c4ff92SAndroid Build Coastguard Worker     : m_Optional(false)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(numBytes);
86*89c4ff92SAndroid Build Coastguard Worker     if (tensorInfo.GetNumBytes() != numBytes)
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << "The size of ConstTensor does not match its TensorInfo.";
89*89c4ff92SAndroid Build Coastguard Worker     }
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     const bool needsSwizzling = (mappings.GetSize() > 0);
92*89c4ff92SAndroid Build Coastguard Worker     if (needsSwizzling)
93*89c4ff92SAndroid Build Coastguard Worker     {
94*89c4ff92SAndroid Build Coastguard Worker         m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
95*89c4ff92SAndroid Build Coastguard Worker         SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker         m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data());
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker     else
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         m_ConstTensor = armnn::ConstTensor(tensorInfo, valueStart);
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker 
IsValid() const105*89c4ff92SAndroid Build Coastguard Worker bool ConstTensorPin::IsValid() const
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker     return m_ConstTensor.GetMemoryArea() != nullptr;
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker 
IsOptional() const110*89c4ff92SAndroid Build Coastguard Worker bool ConstTensorPin::IsOptional() const
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker     return m_Optional;
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
GetConstTensor() const115*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor& ConstTensorPin::GetConstTensor() const
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker     return m_ConstTensor;
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker 
GetConstTensorPtr() const120*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor* ConstTensorPin::GetConstTensorPtr() const
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker     if (IsValid() && m_ConstTensor.GetNumElements() > 0)
123*89c4ff92SAndroid Build Coastguard Worker     {
124*89c4ff92SAndroid Build Coastguard Worker         return &m_ConstTensor;
125*89c4ff92SAndroid Build Coastguard Worker     }
126*89c4ff92SAndroid Build Coastguard Worker     // tensor is either invalid, or has no elements (indicating an optional tensor that was not provided)
127*89c4ff92SAndroid Build Coastguard Worker     return nullptr;
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker ///
131*89c4ff92SAndroid Build Coastguard Worker /// Utility functions
132*89c4ff92SAndroid Build Coastguard Worker ///
133*89c4ff92SAndroid Build Coastguard Worker 
IsWeightsValid(const Operation & operation,uint32_t inputIndex,const Model & model)134*89c4ff92SAndroid Build Coastguard Worker bool IsWeightsValid(const Operation& operation,
135*89c4ff92SAndroid Build Coastguard Worker                     uint32_t inputIndex,
136*89c4ff92SAndroid Build Coastguard Worker                     const Model& model)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
139*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
140*89c4ff92SAndroid Build Coastguard Worker     {
141*89c4ff92SAndroid Build Coastguard Worker         Fail("%s: failed to get input operand %i", __func__, inputIndex);
142*89c4ff92SAndroid Build Coastguard Worker         return false;
143*89c4ff92SAndroid Build Coastguard Worker     }
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     if (operand->lifetime    != OperandLifeTime::CONSTANT_COPY
146*89c4ff92SAndroid Build Coastguard Worker         && operand->lifetime != OperandLifeTime::CONSTANT_REFERENCE
147*89c4ff92SAndroid Build Coastguard Worker         && operand->lifetime != OperandLifeTime::NO_VALUE)
148*89c4ff92SAndroid Build Coastguard Worker     {
149*89c4ff92SAndroid Build Coastguard Worker         return false;
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker     return true;
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker 
ConvertOperandToConstTensorPin(const Operand & operand,const Model & model,const ConversionData & data,const armnn::PermutationVector & dimensionMappings,const armnn::TensorShape * overrideTensorShape,bool optional,const armnn::DataType * overrideDataType)154*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
155*89c4ff92SAndroid Build Coastguard Worker                                               const Model& model,
156*89c4ff92SAndroid Build Coastguard Worker                                               const ConversionData& data,
157*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::PermutationVector& dimensionMappings,
158*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::TensorShape* overrideTensorShape,
159*89c4ff92SAndroid Build Coastguard Worker                                               bool optional,
160*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::DataType* overrideDataType)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker     if (!IsOperandTypeSupportedForTensors(operand.type))
163*89c4ff92SAndroid Build Coastguard Worker     {
164*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": unsupported operand type for tensor" << operand.type;
165*89c4ff92SAndroid Build Coastguard Worker         return ConstTensorPin();
166*89c4ff92SAndroid Build Coastguard Worker     }
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker     if (!optional && !IsOperandConstant(operand))
169*89c4ff92SAndroid Build Coastguard Worker     {
170*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": lifetime for input tensor: r" << operand.lifetime;
171*89c4ff92SAndroid Build Coastguard Worker         return ConstTensorPin();
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     const void* const valueStart = GetOperandValueReadOnlyAddress(operand, model, data, optional);
175*89c4ff92SAndroid Build Coastguard Worker     if (!valueStart)
176*89c4ff92SAndroid Build Coastguard Worker     {
177*89c4ff92SAndroid Build Coastguard Worker         if (optional)
178*89c4ff92SAndroid Build Coastguard Worker         {
179*89c4ff92SAndroid Build Coastguard Worker             // optional tensor with no values is not really an error; return it as invalid, but marked as optional
180*89c4ff92SAndroid Build Coastguard Worker             return ConstTensorPin(true);
181*89c4ff92SAndroid Build Coastguard Worker         }
182*89c4ff92SAndroid Build Coastguard Worker         // mandatory tensor with no values
183*89c4ff92SAndroid Build Coastguard Worker         Fail("%s: failed to get operand address", __func__);
184*89c4ff92SAndroid Build Coastguard Worker         return ConstTensorPin();
185*89c4ff92SAndroid Build Coastguard Worker     }
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker     if (overrideTensorShape)
190*89c4ff92SAndroid Build Coastguard Worker     {
191*89c4ff92SAndroid Build Coastguard Worker         tensorInfo.SetShape(*overrideTensorShape);
192*89c4ff92SAndroid Build Coastguard Worker     }
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     if (overrideDataType)
195*89c4ff92SAndroid Build Coastguard Worker     {
196*89c4ff92SAndroid Build Coastguard Worker         tensorInfo.SetDataType(*overrideDataType);
197*89c4ff92SAndroid Build Coastguard Worker     }
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
200*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
201*89c4ff92SAndroid Build Coastguard Worker     return ConstTensorPin(tensorInfo, valueStart, operand.location.length, dimensionMappings);
202*89c4ff92SAndroid Build Coastguard Worker }
203*89c4ff92SAndroid Build Coastguard Worker 
ConvertToLayerInputHandle(const Operation & operation,uint32_t inputIndex,const Model & model,ConversionData & data,const armnn::PermutationVector & dimensionMappings,const LayerInputHandle * inputHandle)204*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
205*89c4ff92SAndroid Build Coastguard Worker                                            uint32_t inputIndex,
206*89c4ff92SAndroid Build Coastguard Worker                                            const Model& model,
207*89c4ff92SAndroid Build Coastguard Worker                                            ConversionData& data,
208*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::PermutationVector& dimensionMappings,
209*89c4ff92SAndroid Build Coastguard Worker                                            const LayerInputHandle* inputHandle)
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
213*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
214*89c4ff92SAndroid Build Coastguard Worker     {
215*89c4ff92SAndroid Build Coastguard Worker         Fail("%s: failed to get input operand %i", __func__, inputIndex);
216*89c4ff92SAndroid Build Coastguard Worker         return LayerInputHandle();
217*89c4ff92SAndroid Build Coastguard Worker     }
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker     if (!IsOperandTypeSupportedForTensors(operand->type))
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": unsupported operand type for tensor: " << operand->type;
222*89c4ff92SAndroid Build Coastguard Worker         return LayerInputHandle();
223*89c4ff92SAndroid Build Coastguard Worker     }
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     try
226*89c4ff92SAndroid Build Coastguard Worker     {
227*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo operandTensorInfo = GetTensorInfoForOperand(*operand);
228*89c4ff92SAndroid Build Coastguard Worker 
229*89c4ff92SAndroid Build Coastguard Worker         if (IsDynamicTensor(operandTensorInfo))
230*89c4ff92SAndroid Build Coastguard Worker         {
231*89c4ff92SAndroid Build Coastguard Worker             data.m_DynamicInputsEncountered = true;
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker             const uint32_t operandIndex = operation.inputs[inputIndex];
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker             // Check if the dynamic input tensors have been inferred by one of the previous layers
236*89c4ff92SAndroid Build Coastguard Worker             // If not we can't support them
237*89c4ff92SAndroid Build Coastguard Worker             if (data.m_OutputSlotForOperand.size() >= operandIndex && data.m_OutputSlotForOperand[operandIndex])
238*89c4ff92SAndroid Build Coastguard Worker             {
239*89c4ff92SAndroid Build Coastguard Worker                 operandTensorInfo = data.m_OutputSlotForOperand[operandIndex]->GetTensorInfo();
240*89c4ff92SAndroid Build Coastguard Worker             }
241*89c4ff92SAndroid Build Coastguard Worker             else
242*89c4ff92SAndroid Build Coastguard Worker             {
243*89c4ff92SAndroid Build Coastguard Worker                 Fail("%s: Type 2 dynamic input tensors are not supported", __func__);
244*89c4ff92SAndroid Build Coastguard Worker                 return LayerInputHandle();
245*89c4ff92SAndroid Build Coastguard Worker             }
246*89c4ff92SAndroid Build Coastguard Worker         }
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker         switch (operand->lifetime)
249*89c4ff92SAndroid Build Coastguard Worker         {
250*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::SUBGRAPH_INPUT:
251*89c4ff92SAndroid Build Coastguard Worker             {
252*89c4ff92SAndroid Build Coastguard Worker                 // NOTE: We must check whether we can support the input tensor on at least one
253*89c4ff92SAndroid Build Coastguard Worker                 // of the provided backends; otherwise we cannot convert the operation
254*89c4ff92SAndroid Build Coastguard Worker                 bool isInputSupported = false;
255*89c4ff92SAndroid Build Coastguard Worker                 FORWARD_LAYER_SUPPORT_FUNC(__func__,
256*89c4ff92SAndroid Build Coastguard Worker                                            IsInputSupported,
257*89c4ff92SAndroid Build Coastguard Worker                                            data.m_Backends,
258*89c4ff92SAndroid Build Coastguard Worker                                            isInputSupported,
259*89c4ff92SAndroid Build Coastguard Worker                                            armnn::BackendId(),
260*89c4ff92SAndroid Build Coastguard Worker                                            operandTensorInfo);
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker                 if (!isInputSupported)
263*89c4ff92SAndroid Build Coastguard Worker                 {
264*89c4ff92SAndroid Build Coastguard Worker                     Fail("%s: unsupported input tensor", __func__);
265*89c4ff92SAndroid Build Coastguard Worker                     return LayerInputHandle();
266*89c4ff92SAndroid Build Coastguard Worker                 }
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker                 [[clang::fallthrough]]; // intentional fallthrough
269*89c4ff92SAndroid Build Coastguard Worker             }
270*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::TEMPORARY_VARIABLE: // intentional fallthrough
271*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::SUBGRAPH_OUTPUT:
272*89c4ff92SAndroid Build Coastguard Worker             {
273*89c4ff92SAndroid Build Coastguard Worker                 // The tensor is either an operand internal to the model, or a model input.
274*89c4ff92SAndroid Build Coastguard Worker                 // It can be associated with an ArmNN output slot for an existing layer.
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker                 // m_OutputSlotForOperand[...] can be nullptr if the previous layer could not be converted
277*89c4ff92SAndroid Build Coastguard Worker                 const uint32_t operandIndex = operation.inputs[inputIndex];
278*89c4ff92SAndroid Build Coastguard Worker                 return LayerInputHandle(true, data.m_OutputSlotForOperand[operandIndex], operandTensorInfo);
279*89c4ff92SAndroid Build Coastguard Worker             }
280*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::CONSTANT_COPY: // intentional fallthrough
281*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::POINTER:
282*89c4ff92SAndroid Build Coastguard Worker             case OperandLifeTime::CONSTANT_REFERENCE:
283*89c4ff92SAndroid Build Coastguard Worker             {
284*89c4ff92SAndroid Build Coastguard Worker                 auto constantTensorDataType = operandTensorInfo.GetDataType();
285*89c4ff92SAndroid Build Coastguard Worker                 // The tensor has an already known constant value, and can be converted into an ArmNN Constant layer.
286*89c4ff92SAndroid Build Coastguard Worker                 ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand,
287*89c4ff92SAndroid Build Coastguard Worker                                                                           model,
288*89c4ff92SAndroid Build Coastguard Worker                                                                           data,
289*89c4ff92SAndroid Build Coastguard Worker                                                                           dimensionMappings,
290*89c4ff92SAndroid Build Coastguard Worker                                                                           nullptr,
291*89c4ff92SAndroid Build Coastguard Worker                                                                           false,
292*89c4ff92SAndroid Build Coastguard Worker                                                                           &constantTensorDataType);
293*89c4ff92SAndroid Build Coastguard Worker                 if (tensorPin.IsValid())
294*89c4ff92SAndroid Build Coastguard Worker                 {
295*89c4ff92SAndroid Build Coastguard Worker                     bool isSupported = false;
296*89c4ff92SAndroid Build Coastguard Worker                     armnn::BackendId setBackend;
297*89c4ff92SAndroid Build Coastguard Worker                     FORWARD_LAYER_SUPPORT_FUNC(__func__,
298*89c4ff92SAndroid Build Coastguard Worker                                                IsConstantSupported,
299*89c4ff92SAndroid Build Coastguard Worker                                                data.m_Backends,
300*89c4ff92SAndroid Build Coastguard Worker                                                isSupported,
301*89c4ff92SAndroid Build Coastguard Worker                                                setBackend,
302*89c4ff92SAndroid Build Coastguard Worker                                                tensorPin.GetConstTensor().GetInfo());
303*89c4ff92SAndroid Build Coastguard Worker                     if (!isSupported)
304*89c4ff92SAndroid Build Coastguard Worker                     {
305*89c4ff92SAndroid Build Coastguard Worker                         return LayerInputHandle();
306*89c4ff92SAndroid Build Coastguard Worker                     }
307*89c4ff92SAndroid Build Coastguard Worker 
308*89c4ff92SAndroid Build Coastguard Worker                     armnn::IConnectableLayer* constantLayer =
309*89c4ff92SAndroid Build Coastguard Worker                         data.m_Network->AddConstantLayer(tensorPin.GetConstTensor());
310*89c4ff92SAndroid Build Coastguard Worker                     constantLayer->SetBackendId(setBackend);
311*89c4ff92SAndroid Build Coastguard Worker                     armnn::IOutputSlot& outputSlot = constantLayer->GetOutputSlot(0);
312*89c4ff92SAndroid Build Coastguard Worker                     armnn::TensorInfo constantTensorInfo = tensorPin.GetConstTensor().GetInfo();
313*89c4ff92SAndroid Build Coastguard Worker                     outputSlot.SetTensorInfo(constantTensorInfo);
314*89c4ff92SAndroid Build Coastguard Worker 
315*89c4ff92SAndroid Build Coastguard Worker                     return LayerInputHandle(true, &outputSlot, constantTensorInfo);
316*89c4ff92SAndroid Build Coastguard Worker                 }
317*89c4ff92SAndroid Build Coastguard Worker                 else
318*89c4ff92SAndroid Build Coastguard Worker                 {
319*89c4ff92SAndroid Build Coastguard Worker                     Fail("%s: invalid operand tensor", __func__);
320*89c4ff92SAndroid Build Coastguard Worker                     return LayerInputHandle();
321*89c4ff92SAndroid Build Coastguard Worker                 }
322*89c4ff92SAndroid Build Coastguard Worker                 break;
323*89c4ff92SAndroid Build Coastguard Worker             }
324*89c4ff92SAndroid Build Coastguard Worker             default:
325*89c4ff92SAndroid Build Coastguard Worker             {
326*89c4ff92SAndroid Build Coastguard Worker                 VLOG(DRIVER) << __func__ << ": unsupported lifetime for input tensor: " << operand->lifetime;
327*89c4ff92SAndroid Build Coastguard Worker                 return LayerInputHandle();
328*89c4ff92SAndroid Build Coastguard Worker             }
329*89c4ff92SAndroid Build Coastguard Worker         }
330*89c4ff92SAndroid Build Coastguard Worker     }
331*89c4ff92SAndroid Build Coastguard Worker     catch (UnsupportedOperand<OperandType>& e)
332*89c4ff92SAndroid Build Coastguard Worker     {
333*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": Operand type: " << e.m_type << " not supported in ArmnnDriver";
334*89c4ff92SAndroid Build Coastguard Worker         return LayerInputHandle();
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker 
ConvertPaddings(const Operation & operation,const Model & model,ConversionData & data,unsigned int rank,armnn::PadDescriptor & padDescriptor)338*89c4ff92SAndroid Build Coastguard Worker bool ConvertPaddings(const Operation& operation,
339*89c4ff92SAndroid Build Coastguard Worker                      const Model& model,
340*89c4ff92SAndroid Build Coastguard Worker                      ConversionData& data,
341*89c4ff92SAndroid Build Coastguard Worker                      unsigned int rank,
342*89c4ff92SAndroid Build Coastguard Worker                      armnn::PadDescriptor& padDescriptor)
343*89c4ff92SAndroid Build Coastguard Worker {
344*89c4ff92SAndroid Build Coastguard Worker     const Operand* paddingsOperand = GetInputOperand(operation, 1, model);
345*89c4ff92SAndroid Build Coastguard Worker     if (!paddingsOperand)
346*89c4ff92SAndroid Build Coastguard Worker     {
347*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Could not read paddings operand", __func__);
348*89c4ff92SAndroid Build Coastguard Worker     }
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape paddingsOperandShape = GetTensorShapeForOperand(*paddingsOperand);
351*89c4ff92SAndroid Build Coastguard Worker     if (paddingsOperandShape.GetNumDimensions() != 2 || paddingsOperandShape.GetNumElements() != rank * 2)
352*89c4ff92SAndroid Build Coastguard Worker     {
353*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Operation has invalid paddings operand: expected shape [%d, 2]",  __func__, rank);
354*89c4ff92SAndroid Build Coastguard Worker     }
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> paddings;
357*89c4ff92SAndroid Build Coastguard Worker     if (!GetTensorInt32Values(*paddingsOperand, paddings, model, data))
358*89c4ff92SAndroid Build Coastguard Worker     {
359*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Operation has invalid or unsupported paddings operand", __func__);
360*89c4ff92SAndroid Build Coastguard Worker     }
361*89c4ff92SAndroid Build Coastguard Worker 
362*89c4ff92SAndroid Build Coastguard Worker     // add padding for each dimension of input tensor.
363*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < paddings.size() - 1; i += 2)
364*89c4ff92SAndroid Build Coastguard Worker     {
365*89c4ff92SAndroid Build Coastguard Worker         int paddingBeforeInput = paddings[i];
366*89c4ff92SAndroid Build Coastguard Worker         int paddingAfterInput  = paddings[i + 1];
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker         if (paddingBeforeInput < 0 || paddingAfterInput < 0)
369*89c4ff92SAndroid Build Coastguard Worker         {
370*89c4ff92SAndroid Build Coastguard Worker             return Fail("%s: Operation has invalid paddings operand, invalid padding values.",  __func__);
371*89c4ff92SAndroid Build Coastguard Worker         }
372*89c4ff92SAndroid Build Coastguard Worker 
373*89c4ff92SAndroid Build Coastguard Worker         padDescriptor.m_PadList.emplace_back((unsigned int) paddingBeforeInput, (unsigned int) paddingAfterInput);
374*89c4ff92SAndroid Build Coastguard Worker     }
375*89c4ff92SAndroid Build Coastguard Worker 
376*89c4ff92SAndroid Build Coastguard Worker     return true;
377*89c4ff92SAndroid Build Coastguard Worker }
378*89c4ff92SAndroid Build Coastguard Worker 
379*89c4ff92SAndroid Build Coastguard Worker 
ConvertPooling2d(const Operation & operation,const char * operationName,armnn::PoolingAlgorithm poolType,const Model & model,ConversionData & data)380*89c4ff92SAndroid Build Coastguard Worker bool ConvertPooling2d(const Operation& operation,
381*89c4ff92SAndroid Build Coastguard Worker                       const char* operationName,
382*89c4ff92SAndroid Build Coastguard Worker                       armnn::PoolingAlgorithm poolType,
383*89c4ff92SAndroid Build Coastguard Worker                       const Model& model,
384*89c4ff92SAndroid Build Coastguard Worker                       ConversionData& data)
385*89c4ff92SAndroid Build Coastguard Worker {
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "Converter::ConvertL2Pool2d()";
388*89c4ff92SAndroid Build Coastguard Worker 
389*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
390*89c4ff92SAndroid Build Coastguard Worker     if (!input.IsValid())
391*89c4ff92SAndroid Build Coastguard Worker     {
392*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Operation Could not read input 0", operationName);
393*89c4ff92SAndroid Build Coastguard Worker     }
394*89c4ff92SAndroid Build Coastguard Worker 
395*89c4ff92SAndroid Build Coastguard Worker     const Operand* output = GetOutputOperand(operation, 0, model);
396*89c4ff92SAndroid Build Coastguard Worker     if (!output)
397*89c4ff92SAndroid Build Coastguard Worker     {
398*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Could not read output 0", __func__);
399*89c4ff92SAndroid Build Coastguard Worker     }
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputInfo  = input.GetTensorInfo();
402*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
403*89c4ff92SAndroid Build Coastguard Worker 
404*89c4ff92SAndroid Build Coastguard Worker     armnn::Pooling2dDescriptor desc;
405*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolType = poolType;
406*89c4ff92SAndroid Build Coastguard Worker     desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
407*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
408*89c4ff92SAndroid Build Coastguard Worker 
409*89c4ff92SAndroid Build Coastguard Worker     ActivationFn activation;
410*89c4ff92SAndroid Build Coastguard Worker 
411*89c4ff92SAndroid Build Coastguard Worker     auto inputSize = operation.inputs.size();
412*89c4ff92SAndroid Build Coastguard Worker 
413*89c4ff92SAndroid Build Coastguard Worker     if (inputSize >= 10)
414*89c4ff92SAndroid Build Coastguard Worker     {
415*89c4ff92SAndroid Build Coastguard Worker         // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type)
416*89c4ff92SAndroid Build Coastguard Worker         if (!GetInputScalar(operation, 1, OperandType::INT32, desc.m_PadLeft, model, data) ||
417*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 2, OperandType::INT32, desc.m_PadRight, model, data) ||
418*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 3, OperandType::INT32, desc.m_PadTop, model, data) ||
419*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 4, OperandType::INT32, desc.m_PadBottom, model, data) ||
420*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 5, OperandType::INT32, desc.m_StrideX, model, data) ||
421*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 6, OperandType::INT32, desc.m_StrideY, model, data) ||
422*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 7, OperandType::INT32, desc.m_PoolWidth, model, data) ||
423*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 8, OperandType::INT32, desc.m_PoolHeight, model, data) ||
424*89c4ff92SAndroid Build Coastguard Worker             !GetInputActivationFunction(operation, 9, activation, model, data))
425*89c4ff92SAndroid Build Coastguard Worker         {
426*89c4ff92SAndroid Build Coastguard Worker             return Fail("%s: Operation has invalid inputs", operationName);
427*89c4ff92SAndroid Build Coastguard Worker         }
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker         if (Is12OrLaterOperand(*output))
430*89c4ff92SAndroid Build Coastguard Worker         {
431*89c4ff92SAndroid Build Coastguard Worker             desc.m_DataLayout = OptionalDataLayout(operation, 10, model, data);
432*89c4ff92SAndroid Build Coastguard Worker         }
433*89c4ff92SAndroid Build Coastguard Worker     }
434*89c4ff92SAndroid Build Coastguard Worker     else
435*89c4ff92SAndroid Build Coastguard Worker     {
436*89c4ff92SAndroid Build Coastguard Worker         // one input, 6 parameters (padding, stridex, stridey, width, height, activation type)
437*89c4ff92SAndroid Build Coastguard Worker         ::android::nn::PaddingScheme scheme;
438*89c4ff92SAndroid Build Coastguard Worker         if (!GetInputPaddingScheme(operation, 1, scheme, model, data) ||
439*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 2, OperandType::INT32, desc.m_StrideX, model, data) ||
440*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 3, OperandType::INT32, desc.m_StrideY, model, data) ||
441*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 4, OperandType::INT32, desc.m_PoolWidth, model, data) ||
442*89c4ff92SAndroid Build Coastguard Worker             !GetInputScalar(operation, 5, OperandType::INT32, desc.m_PoolHeight, model, data) ||
443*89c4ff92SAndroid Build Coastguard Worker             !GetInputActivationFunction(operation, 6, activation, model, data))
444*89c4ff92SAndroid Build Coastguard Worker         {
445*89c4ff92SAndroid Build Coastguard Worker             return Fail("%s: Operation has invalid inputs", operationName);
446*89c4ff92SAndroid Build Coastguard Worker         }
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker         if (Is12OrLaterOperand(*output))
449*89c4ff92SAndroid Build Coastguard Worker         {
450*89c4ff92SAndroid Build Coastguard Worker             desc.m_DataLayout = OptionalDataLayout(operation, 7, model, data);
451*89c4ff92SAndroid Build Coastguard Worker         }
452*89c4ff92SAndroid Build Coastguard Worker 
453*89c4ff92SAndroid Build Coastguard Worker         const armnnUtils::DataLayoutIndexed dataLayout(desc.m_DataLayout);
454*89c4ff92SAndroid Build Coastguard Worker         const unsigned int inputWidth  = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
455*89c4ff92SAndroid Build Coastguard Worker         const unsigned int inputHeight = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
456*89c4ff92SAndroid Build Coastguard Worker 
457*89c4ff92SAndroid Build Coastguard Worker         CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
458*89c4ff92SAndroid Build Coastguard Worker         CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
459*89c4ff92SAndroid Build Coastguard Worker     }
460*89c4ff92SAndroid Build Coastguard Worker 
461*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
462*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
463*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
464*89c4ff92SAndroid Build Coastguard Worker     {
465*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC(__func__,
466*89c4ff92SAndroid Build Coastguard Worker                                    IsPooling2dSupported,
467*89c4ff92SAndroid Build Coastguard Worker                                    data.m_Backends,
468*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
469*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
470*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo,
471*89c4ff92SAndroid Build Coastguard Worker                                    outputInfo,
472*89c4ff92SAndroid Build Coastguard Worker                                    desc);
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker     };
475*89c4ff92SAndroid Build Coastguard Worker 
476*89c4ff92SAndroid Build Coastguard Worker     if(IsDynamicTensor(outputInfo))
477*89c4ff92SAndroid Build Coastguard Worker     {
478*89c4ff92SAndroid Build Coastguard Worker         isSupported = AreDynamicTensorsSupported();
479*89c4ff92SAndroid Build Coastguard Worker     }
480*89c4ff92SAndroid Build Coastguard Worker     else
481*89c4ff92SAndroid Build Coastguard Worker     {
482*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputInfo, isSupported);
483*89c4ff92SAndroid Build Coastguard Worker     }
484*89c4ff92SAndroid Build Coastguard Worker 
485*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
486*89c4ff92SAndroid Build Coastguard Worker     {
487*89c4ff92SAndroid Build Coastguard Worker         return false;
488*89c4ff92SAndroid Build Coastguard Worker     }
489*89c4ff92SAndroid Build Coastguard Worker 
490*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* pooling2dLayer = data.m_Network->AddPooling2dLayer(desc);
491*89c4ff92SAndroid Build Coastguard Worker     pooling2dLayer->SetBackendId(setBackend);
492*89c4ff92SAndroid Build Coastguard Worker     if (!pooling2dLayer)
493*89c4ff92SAndroid Build Coastguard Worker     {
494*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: AddPooling2dLayer failed", __func__);
495*89c4ff92SAndroid Build Coastguard Worker     }
496*89c4ff92SAndroid Build Coastguard Worker 
497*89c4ff92SAndroid Build Coastguard Worker     input.Connect(pooling2dLayer->GetInputSlot(0));
498*89c4ff92SAndroid Build Coastguard Worker 
499*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
500*89c4ff92SAndroid Build Coastguard Worker     {
501*89c4ff92SAndroid Build Coastguard Worker         return false;
502*89c4ff92SAndroid Build Coastguard Worker     }
503*89c4ff92SAndroid Build Coastguard Worker 
504*89c4ff92SAndroid Build Coastguard Worker     return SetupAndTrackLayerOutputSlot(operation, 0, *pooling2dLayer, model,
505*89c4ff92SAndroid Build Coastguard Worker                                         data, nullptr, validateFunc, activation);
506*89c4ff92SAndroid Build Coastguard Worker }
507*89c4ff92SAndroid Build Coastguard Worker 
ConvertReduce(const Operation & operation,const Model & model,ConversionData & data,armnn::ReduceOperation reduceOperation)508*89c4ff92SAndroid Build Coastguard Worker bool ConvertReduce(const Operation& operation,
509*89c4ff92SAndroid Build Coastguard Worker                    const Model& model,
510*89c4ff92SAndroid Build Coastguard Worker                    ConversionData& data,
511*89c4ff92SAndroid Build Coastguard Worker                    armnn::ReduceOperation reduceOperation)
512*89c4ff92SAndroid Build Coastguard Worker {
513*89c4ff92SAndroid Build Coastguard Worker     armnn::ReduceDescriptor descriptor;
514*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_ReduceOperation = reduceOperation;
515*89c4ff92SAndroid Build Coastguard Worker 
516*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
517*89c4ff92SAndroid Build Coastguard Worker     if (!input.IsValid())
518*89c4ff92SAndroid Build Coastguard Worker     {
519*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Operation has invalid inputs", __func__);
520*89c4ff92SAndroid Build Coastguard Worker     }
521*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
522*89c4ff92SAndroid Build Coastguard Worker 
523*89c4ff92SAndroid Build Coastguard Worker     const Operand* output = GetOutputOperand(operation, 0, model);
524*89c4ff92SAndroid Build Coastguard Worker     if (!output)
525*89c4ff92SAndroid Build Coastguard Worker     {
526*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Could not read output 0", __func__);
527*89c4ff92SAndroid Build Coastguard Worker     }
528*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
529*89c4ff92SAndroid Build Coastguard Worker 
530*89c4ff92SAndroid Build Coastguard Worker     const Operand* axisOperand = GetInputOperand(operation, 1, model);
531*89c4ff92SAndroid Build Coastguard Worker     if (!axisOperand)
532*89c4ff92SAndroid Build Coastguard Worker     {
533*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Could not read input 1", __func__);
534*89c4ff92SAndroid Build Coastguard Worker     }
535*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axis;
536*89c4ff92SAndroid Build Coastguard Worker     if (!GetTensorInt32Values(*axisOperand, axis, model, data))
537*89c4ff92SAndroid Build Coastguard Worker     {
538*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Input 1 has invalid values", __func__);
539*89c4ff92SAndroid Build Coastguard Worker     }
540*89c4ff92SAndroid Build Coastguard Worker 
541*89c4ff92SAndroid Build Coastguard Worker     // Convert the axis to unsigned int and remove duplicates.
542*89c4ff92SAndroid Build Coastguard Worker     unsigned int rank = inputInfo.GetNumDimensions();
543*89c4ff92SAndroid Build Coastguard Worker     std::set<unsigned int> uniqueAxis;
544*89c4ff92SAndroid Build Coastguard Worker     std::transform(axis.begin(), axis.end(),
545*89c4ff92SAndroid Build Coastguard Worker                    std::inserter(uniqueAxis, uniqueAxis.begin()),
546*89c4ff92SAndroid Build Coastguard Worker                    [rank](int i) -> unsigned int { return (i + rank) % rank; });
547*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
548*89c4ff92SAndroid Build Coastguard Worker 
549*89c4ff92SAndroid Build Coastguard Worker     // Get the "keep dims" flag.
550*89c4ff92SAndroid Build Coastguard Worker     if (!GetInputScalar(operation, 2, OperandType::BOOL, descriptor.m_KeepDims, model, data))
551*89c4ff92SAndroid Build Coastguard Worker     {
552*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Could not read input 2", __func__);
553*89c4ff92SAndroid Build Coastguard Worker     }
554*89c4ff92SAndroid Build Coastguard Worker 
555*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
556*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
557*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
558*89c4ff92SAndroid Build Coastguard Worker     {
559*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC(__func__,
560*89c4ff92SAndroid Build Coastguard Worker                                    IsReduceSupported,
561*89c4ff92SAndroid Build Coastguard Worker                                    data.m_Backends,
562*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
563*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
564*89c4ff92SAndroid Build Coastguard Worker                                    inputInfo,
565*89c4ff92SAndroid Build Coastguard Worker                                    outputInfo,
566*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
567*89c4ff92SAndroid Build Coastguard Worker     };
568*89c4ff92SAndroid Build Coastguard Worker 
569*89c4ff92SAndroid Build Coastguard Worker     if(!IsDynamicTensor(outputInfo))
570*89c4ff92SAndroid Build Coastguard Worker     {
571*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputInfo, isSupported);
572*89c4ff92SAndroid Build Coastguard Worker     }
573*89c4ff92SAndroid Build Coastguard Worker     else
574*89c4ff92SAndroid Build Coastguard Worker     {
575*89c4ff92SAndroid Build Coastguard Worker         isSupported = AreDynamicTensorsSupported();
576*89c4ff92SAndroid Build Coastguard Worker     }
577*89c4ff92SAndroid Build Coastguard Worker 
578*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
579*89c4ff92SAndroid Build Coastguard Worker     {
580*89c4ff92SAndroid Build Coastguard Worker         return false;
581*89c4ff92SAndroid Build Coastguard Worker     }
582*89c4ff92SAndroid Build Coastguard Worker 
583*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const layer = data.m_Network->AddReduceLayer(descriptor);
584*89c4ff92SAndroid Build Coastguard Worker     layer->SetBackendId(setBackend);
585*89c4ff92SAndroid Build Coastguard Worker     assert(layer != nullptr);
586*89c4ff92SAndroid Build Coastguard Worker     input.Connect(layer->GetInputSlot(0));
587*89c4ff92SAndroid Build Coastguard Worker 
588*89c4ff92SAndroid Build Coastguard Worker     return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc);
589*89c4ff92SAndroid Build Coastguard Worker }
590*89c4ff92SAndroid Build Coastguard Worker 
591*89c4ff92SAndroid Build Coastguard Worker 
ConvertToActivation(const Operation & operation,const char * operationName,const armnn::ActivationDescriptor & activationDesc,const Model & model,ConversionData & data)592*89c4ff92SAndroid Build Coastguard Worker bool ConvertToActivation(const Operation& operation,
593*89c4ff92SAndroid Build Coastguard Worker                          const char* operationName,
594*89c4ff92SAndroid Build Coastguard Worker                          const armnn::ActivationDescriptor& activationDesc,
595*89c4ff92SAndroid Build Coastguard Worker                          const Model& model,
596*89c4ff92SAndroid Build Coastguard Worker                          ConversionData& data)
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
599*89c4ff92SAndroid Build Coastguard Worker     if (!input.IsValid())
600*89c4ff92SAndroid Build Coastguard Worker     {
601*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Input 0 is invalid", operationName);
602*89c4ff92SAndroid Build Coastguard Worker     }
603*89c4ff92SAndroid Build Coastguard Worker 
604*89c4ff92SAndroid Build Coastguard Worker     const Operand* outputOperand = GetOutputOperand(operation, 0, model);
605*89c4ff92SAndroid Build Coastguard Worker     if (!outputOperand)
606*89c4ff92SAndroid Build Coastguard Worker     {
607*89c4ff92SAndroid Build Coastguard Worker         return false;
608*89c4ff92SAndroid Build Coastguard Worker     }
609*89c4ff92SAndroid Build Coastguard Worker 
610*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand);
611*89c4ff92SAndroid Build Coastguard Worker 
612*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
613*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
614*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
615*89c4ff92SAndroid Build Coastguard Worker     {
616*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC(__func__,
617*89c4ff92SAndroid Build Coastguard Worker                                    IsActivationSupported,
618*89c4ff92SAndroid Build Coastguard Worker                                    data.m_Backends,
619*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
620*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
621*89c4ff92SAndroid Build Coastguard Worker                                    input.GetTensorInfo(),
622*89c4ff92SAndroid Build Coastguard Worker                                    outInfo,
623*89c4ff92SAndroid Build Coastguard Worker                                    activationDesc);
624*89c4ff92SAndroid Build Coastguard Worker     };
625*89c4ff92SAndroid Build Coastguard Worker 
626*89c4ff92SAndroid Build Coastguard Worker     if(IsDynamicTensor(outInfo))
627*89c4ff92SAndroid Build Coastguard Worker     {
628*89c4ff92SAndroid Build Coastguard Worker         isSupported = AreDynamicTensorsSupported();
629*89c4ff92SAndroid Build Coastguard Worker     }
630*89c4ff92SAndroid Build Coastguard Worker     else
631*89c4ff92SAndroid Build Coastguard Worker     {
632*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outInfo, isSupported);
633*89c4ff92SAndroid Build Coastguard Worker     }
634*89c4ff92SAndroid Build Coastguard Worker 
635*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
636*89c4ff92SAndroid Build Coastguard Worker     {
637*89c4ff92SAndroid Build Coastguard Worker         return false;
638*89c4ff92SAndroid Build Coastguard Worker     }
639*89c4ff92SAndroid Build Coastguard Worker 
640*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = data.m_Network->AddActivationLayer(activationDesc);
641*89c4ff92SAndroid Build Coastguard Worker     layer->SetBackendId(setBackend);
642*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
643*89c4ff92SAndroid Build Coastguard Worker     input.Connect(layer->GetInputSlot(0));
644*89c4ff92SAndroid Build Coastguard Worker 
645*89c4ff92SAndroid Build Coastguard Worker     return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc);
646*89c4ff92SAndroid Build Coastguard Worker }
647*89c4ff92SAndroid Build Coastguard Worker 
DequantizeIfRequired(size_t operand_index,const Operation & operation,const Model & model,const ConversionData & data)648*89c4ff92SAndroid Build Coastguard Worker DequantizeResult DequantizeIfRequired(size_t operand_index,
649*89c4ff92SAndroid Build Coastguard Worker                                       const Operation& operation,
650*89c4ff92SAndroid Build Coastguard Worker                                       const Model& model,
651*89c4ff92SAndroid Build Coastguard Worker                                       const ConversionData& data)
652*89c4ff92SAndroid Build Coastguard Worker {
653*89c4ff92SAndroid Build Coastguard Worker     const Operand* weightsOperand = GetInputOperand(operation, operand_index, model);
654*89c4ff92SAndroid Build Coastguard Worker     if (!weightsOperand)
655*89c4ff92SAndroid Build Coastguard Worker     {
656*89c4ff92SAndroid Build Coastguard Worker         return { nullptr, 0, armnn::TensorInfo(), DequantizeStatus::INVALID_OPERAND };
657*89c4ff92SAndroid Build Coastguard Worker     }
658*89c4ff92SAndroid Build Coastguard Worker 
659*89c4ff92SAndroid Build Coastguard Worker     if (IsOperandConstant(*weightsOperand))
660*89c4ff92SAndroid Build Coastguard Worker     {
661*89c4ff92SAndroid Build Coastguard Worker         // Weights are already constant
662*89c4ff92SAndroid Build Coastguard Worker         return { nullptr, 0, armnn::TensorInfo(), DequantizeStatus::NOT_REQUIRED };
663*89c4ff92SAndroid Build Coastguard Worker     }
664*89c4ff92SAndroid Build Coastguard Worker 
665*89c4ff92SAndroid Build Coastguard Worker     const size_t weightsInputIndex = operation.inputs[operand_index];
666*89c4ff92SAndroid Build Coastguard Worker 
667*89c4ff92SAndroid Build Coastguard Worker     // The weights are a non const tensor, this indicates they might be the output of a dequantize op.
668*89c4ff92SAndroid Build Coastguard Worker     // Iterate over the nodes and find the previous operation which should be DEQUANTIZE
669*89c4ff92SAndroid Build Coastguard Worker     for (uint32_t operationIdx = 0; operationIdx < getMainModel(model).operations.size(); ++operationIdx)
670*89c4ff92SAndroid Build Coastguard Worker     {
671*89c4ff92SAndroid Build Coastguard Worker         // Search for the DEQUANTIZE op which has the operand with index equal to operandIndex
672*89c4ff92SAndroid Build Coastguard Worker         const auto& operationIt = getMainModel(model).operations[operationIdx];
673*89c4ff92SAndroid Build Coastguard Worker         if (operationIt.type != OperationType::DEQUANTIZE)
674*89c4ff92SAndroid Build Coastguard Worker         {
675*89c4ff92SAndroid Build Coastguard Worker             continue;
676*89c4ff92SAndroid Build Coastguard Worker         }
677*89c4ff92SAndroid Build Coastguard Worker 
678*89c4ff92SAndroid Build Coastguard Worker         size_t outOpIndex = weightsInputIndex + 1;
679*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0; outOpIndex != weightsInputIndex && i < operationIt.outputs.size(); ++i)
680*89c4ff92SAndroid Build Coastguard Worker         {
681*89c4ff92SAndroid Build Coastguard Worker             outOpIndex = operationIt.outputs[i];
682*89c4ff92SAndroid Build Coastguard Worker         }
683*89c4ff92SAndroid Build Coastguard Worker 
684*89c4ff92SAndroid Build Coastguard Worker         if (outOpIndex != weightsInputIndex)
685*89c4ff92SAndroid Build Coastguard Worker         {
686*89c4ff92SAndroid Build Coastguard Worker             continue;
687*89c4ff92SAndroid Build Coastguard Worker         }
688*89c4ff92SAndroid Build Coastguard Worker 
689*89c4ff92SAndroid Build Coastguard Worker         const Operand* operand = GetInputOperand(operationIt, 0, model);
690*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(operand);
691*89c4ff92SAndroid Build Coastguard Worker 
692*89c4ff92SAndroid Build Coastguard Worker         if (!IsQSymm8(*operand))
693*89c4ff92SAndroid Build Coastguard Worker         {
694*89c4ff92SAndroid Build Coastguard Worker             // Only supporting dequantize from QSYMM8 to FLOAT
695*89c4ff92SAndroid Build Coastguard Worker             break;
696*89c4ff92SAndroid Build Coastguard Worker         }
697*89c4ff92SAndroid Build Coastguard Worker 
698*89c4ff92SAndroid Build Coastguard Worker         // Allocate a new buffer for the dequantized data and manually dequantize
699*89c4ff92SAndroid Build Coastguard Worker         const void* startValue = GetOperandValueReadOnlyAddress(*operand, model, data);
700*89c4ff92SAndroid Build Coastguard Worker         if (!startValue)
701*89c4ff92SAndroid Build Coastguard Worker         {
702*89c4ff92SAndroid Build Coastguard Worker             // Failed to get the operand address
703*89c4ff92SAndroid Build Coastguard Worker             break;
704*89c4ff92SAndroid Build Coastguard Worker         }
705*89c4ff92SAndroid Build Coastguard Worker 
706*89c4ff92SAndroid Build Coastguard Worker         const uint8_t* quantizedBuffer = reinterpret_cast<const uint8_t*>(startValue);
707*89c4ff92SAndroid Build Coastguard Worker         size_t dequantizedBufferLength = operand->location.length;
708*89c4ff92SAndroid Build Coastguard Worker         const float quantizationScale  = operand->scale;
709*89c4ff92SAndroid Build Coastguard Worker 
710*89c4ff92SAndroid Build Coastguard Worker         auto dequantizedBuffer = std::make_unique<float[]>(dequantizedBufferLength + 1);
711*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0; i < dequantizedBufferLength; ++i)
712*89c4ff92SAndroid Build Coastguard Worker         {
713*89c4ff92SAndroid Build Coastguard Worker             float* dstPtr = dequantizedBuffer.get();
714*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(dstPtr);
715*89c4ff92SAndroid Build Coastguard Worker             *dstPtr++ = quantizedBuffer[i] * quantizationScale;
716*89c4ff92SAndroid Build Coastguard Worker         }
717*89c4ff92SAndroid Build Coastguard Worker 
718*89c4ff92SAndroid Build Coastguard Worker         // Construct tensor info for dequantized ConstTensor
719*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo tensorInfo(operand->dimensions.size(),
720*89c4ff92SAndroid Build Coastguard Worker                                      operand->dimensions.data(),
721*89c4ff92SAndroid Build Coastguard Worker                                      armnn::DataType::Float32);
722*89c4ff92SAndroid Build Coastguard Worker 
723*89c4ff92SAndroid Build Coastguard Worker         return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float),
724*89c4ff92SAndroid Build Coastguard Worker                  std::move(tensorInfo),
725*89c4ff92SAndroid Build Coastguard Worker                  DequantizeStatus::SUCCESS };
726*89c4ff92SAndroid Build Coastguard Worker     }
727*89c4ff92SAndroid Build Coastguard Worker 
728*89c4ff92SAndroid Build Coastguard Worker     return { nullptr, 0, armnn::TensorInfo() , DequantizeStatus::NOT_REQUIRED};
729*89c4ff92SAndroid Build Coastguard Worker }
730*89c4ff92SAndroid Build Coastguard Worker 
DequantizeAndMakeConstTensorPin(const Operation & operation,const Model & model,const ConversionData & data,size_t operandIndex,bool optional)731*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
732*89c4ff92SAndroid Build Coastguard Worker                                                const Model& model,
733*89c4ff92SAndroid Build Coastguard Worker                                                const ConversionData& data,
734*89c4ff92SAndroid Build Coastguard Worker                                                size_t operandIndex,
735*89c4ff92SAndroid Build Coastguard Worker                                                bool optional)
736*89c4ff92SAndroid Build Coastguard Worker {
737*89c4ff92SAndroid Build Coastguard Worker     DequantizeResult dequantized = DequantizeIfRequired(operandIndex,operation, model, data);
738*89c4ff92SAndroid Build Coastguard Worker 
739*89c4ff92SAndroid Build Coastguard Worker     DequantizeStatus status = std::get<3>(dequantized);
740*89c4ff92SAndroid Build Coastguard Worker     switch (status)
741*89c4ff92SAndroid Build Coastguard Worker     {
742*89c4ff92SAndroid Build Coastguard Worker         case DequantizeStatus::INVALID_OPERAND:
743*89c4ff92SAndroid Build Coastguard Worker         {
744*89c4ff92SAndroid Build Coastguard Worker             // return invalid const tensor pin
745*89c4ff92SAndroid Build Coastguard Worker             return ConstTensorPin();
746*89c4ff92SAndroid Build Coastguard Worker         }
747*89c4ff92SAndroid Build Coastguard Worker         case DequantizeStatus::NOT_REQUIRED:
748*89c4ff92SAndroid Build Coastguard Worker         {
749*89c4ff92SAndroid Build Coastguard Worker             return ConvertOperationInputToConstTensorPin(
750*89c4ff92SAndroid Build Coastguard Worker                 operation, operandIndex, model, data, g_DontPermute, nullptr, optional);
751*89c4ff92SAndroid Build Coastguard Worker         }
752*89c4ff92SAndroid Build Coastguard Worker         case DequantizeStatus::SUCCESS:
753*89c4ff92SAndroid Build Coastguard Worker         default:
754*89c4ff92SAndroid Build Coastguard Worker         {
755*89c4ff92SAndroid Build Coastguard Worker             return ConstTensorPin(
756*89c4ff92SAndroid Build Coastguard Worker                 std::get<2>(dequantized), std::get<0>(dequantized).get(), std::get<1>(dequantized), g_DontPermute);
757*89c4ff92SAndroid Build Coastguard Worker         }
758*89c4ff92SAndroid Build Coastguard Worker     }
759*89c4ff92SAndroid Build Coastguard Worker }
760*89c4ff92SAndroid Build Coastguard Worker 
GetInputPaddingScheme(const Operation & operation,uint32_t inputIndex,PaddingScheme & outPaddingScheme,const Model & model,const ConversionData & data)761*89c4ff92SAndroid Build Coastguard Worker bool GetInputPaddingScheme(const Operation& operation,
762*89c4ff92SAndroid Build Coastguard Worker                            uint32_t inputIndex,
763*89c4ff92SAndroid Build Coastguard Worker                            PaddingScheme& outPaddingScheme,
764*89c4ff92SAndroid Build Coastguard Worker                            const Model& model,
765*89c4ff92SAndroid Build Coastguard Worker                            const ConversionData& data)
766*89c4ff92SAndroid Build Coastguard Worker {
767*89c4ff92SAndroid Build Coastguard Worker     int32_t paddingSchemeAsInt;
768*89c4ff92SAndroid Build Coastguard Worker     if (!GetInputInt32(operation, inputIndex, paddingSchemeAsInt, model, data))
769*89c4ff92SAndroid Build Coastguard Worker     {
770*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: failed to get padding scheme input value", __func__);
771*89c4ff92SAndroid Build Coastguard Worker     }
772*89c4ff92SAndroid Build Coastguard Worker 
773*89c4ff92SAndroid Build Coastguard Worker     outPaddingScheme = static_cast<::android::nn::PaddingScheme>(paddingSchemeAsInt);
774*89c4ff92SAndroid Build Coastguard Worker     return true;
775*89c4ff92SAndroid Build Coastguard Worker }
776*89c4ff92SAndroid Build Coastguard Worker 
GetOperandValueReadOnlyAddress(const Operand & operand,const Model & model,const ConversionData & data,bool optional)777*89c4ff92SAndroid Build Coastguard Worker const void* GetOperandValueReadOnlyAddress(const Operand& operand,
778*89c4ff92SAndroid Build Coastguard Worker                                            const Model& model,
779*89c4ff92SAndroid Build Coastguard Worker                                            const ConversionData& data,
780*89c4ff92SAndroid Build Coastguard Worker                                            bool optional)
781*89c4ff92SAndroid Build Coastguard Worker {
782*89c4ff92SAndroid Build Coastguard Worker     const void* valueStart = nullptr;
783*89c4ff92SAndroid Build Coastguard Worker     switch (operand.lifetime)
784*89c4ff92SAndroid Build Coastguard Worker     {
785*89c4ff92SAndroid Build Coastguard Worker         case OperandLifeTime::CONSTANT_COPY:
786*89c4ff92SAndroid Build Coastguard Worker         {
787*89c4ff92SAndroid Build Coastguard Worker             valueStart = model.operandValues.data() + operand.location.offset;
788*89c4ff92SAndroid Build Coastguard Worker             break;
789*89c4ff92SAndroid Build Coastguard Worker         }
790*89c4ff92SAndroid Build Coastguard Worker         case OperandLifeTime::POINTER:
791*89c4ff92SAndroid Build Coastguard Worker         {
792*89c4ff92SAndroid Build Coastguard Worker             // Pointer specified in the model
793*89c4ff92SAndroid Build Coastguard Worker             valueStart = std::get<const void*>(operand.location.pointer);
794*89c4ff92SAndroid Build Coastguard Worker             break;
795*89c4ff92SAndroid Build Coastguard Worker         }
796*89c4ff92SAndroid Build Coastguard Worker         case OperandLifeTime::CONSTANT_REFERENCE:
797*89c4ff92SAndroid Build Coastguard Worker         {
798*89c4ff92SAndroid Build Coastguard Worker             // Constant specified via a Memory object
799*89c4ff92SAndroid Build Coastguard Worker             valueStart = GetMemoryFromPool(operand.location, data.m_MemPools);
800*89c4ff92SAndroid Build Coastguard Worker             break;
801*89c4ff92SAndroid Build Coastguard Worker         }
802*89c4ff92SAndroid Build Coastguard Worker         case OperandLifeTime::NO_VALUE:
803*89c4ff92SAndroid Build Coastguard Worker         {
804*89c4ff92SAndroid Build Coastguard Worker             // An optional input tensor with no values is not an error so should not register as a fail
805*89c4ff92SAndroid Build Coastguard Worker             if (optional)
806*89c4ff92SAndroid Build Coastguard Worker             {
807*89c4ff92SAndroid Build Coastguard Worker                 valueStart = nullptr;
808*89c4ff92SAndroid Build Coastguard Worker                 break;
809*89c4ff92SAndroid Build Coastguard Worker             }
810*89c4ff92SAndroid Build Coastguard Worker             [[fallthrough]];
811*89c4ff92SAndroid Build Coastguard Worker         }
812*89c4ff92SAndroid Build Coastguard Worker         default:
813*89c4ff92SAndroid Build Coastguard Worker         {
814*89c4ff92SAndroid Build Coastguard Worker             VLOG(DRIVER) << __func__ << ": unsupported/invalid operand lifetime:: " << operand.lifetime;
815*89c4ff92SAndroid Build Coastguard Worker             valueStart = nullptr;
816*89c4ff92SAndroid Build Coastguard Worker         }
817*89c4ff92SAndroid Build Coastguard Worker     }
818*89c4ff92SAndroid Build Coastguard Worker 
819*89c4ff92SAndroid Build Coastguard Worker     return valueStart;
820*89c4ff92SAndroid Build Coastguard Worker }
821*89c4ff92SAndroid Build Coastguard Worker 
GetTensorInt32Values(const Operand & operand,std::vector<int32_t> & outValues,const Model & model,const ConversionData & data)822*89c4ff92SAndroid Build Coastguard Worker bool GetTensorInt32Values(const Operand& operand,
823*89c4ff92SAndroid Build Coastguard Worker                                  std::vector<int32_t>& outValues,
824*89c4ff92SAndroid Build Coastguard Worker                                  const Model& model,
825*89c4ff92SAndroid Build Coastguard Worker                                  const ConversionData& data)
826*89c4ff92SAndroid Build Coastguard Worker {
827*89c4ff92SAndroid Build Coastguard Worker     if (operand.type != OperandType::TENSOR_INT32)
828*89c4ff92SAndroid Build Coastguard Worker     {
829*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": invalid operand type: " << operand.type;
830*89c4ff92SAndroid Build Coastguard Worker         return false;
831*89c4ff92SAndroid Build Coastguard Worker     }
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker     const void* startAddress = GetOperandValueReadOnlyAddress(operand, model, data);
834*89c4ff92SAndroid Build Coastguard Worker     if (!startAddress)
835*89c4ff92SAndroid Build Coastguard Worker     {
836*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": failed to get operand address " << operand.type;
837*89c4ff92SAndroid Build Coastguard Worker         return false;
838*89c4ff92SAndroid Build Coastguard Worker     }
839*89c4ff92SAndroid Build Coastguard Worker 
840*89c4ff92SAndroid Build Coastguard Worker     // Check number of bytes is sensible
841*89c4ff92SAndroid Build Coastguard Worker     const uint32_t numBytes = operand.location.length;
842*89c4ff92SAndroid Build Coastguard Worker     if (numBytes % sizeof(int32_t) != 0)
843*89c4ff92SAndroid Build Coastguard Worker     {
844*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: invalid number of bytes: %i, expected to be a multiple of %i",
845*89c4ff92SAndroid Build Coastguard Worker                     __func__, numBytes, sizeof(int32_t));
846*89c4ff92SAndroid Build Coastguard Worker     }
847*89c4ff92SAndroid Build Coastguard Worker 
848*89c4ff92SAndroid Build Coastguard Worker     outValues.resize(numBytes / sizeof(int32_t));
849*89c4ff92SAndroid Build Coastguard Worker     memcpy(outValues.data(), startAddress, numBytes);
850*89c4ff92SAndroid Build Coastguard Worker     return true;
851*89c4ff92SAndroid Build Coastguard Worker }
852*89c4ff92SAndroid Build Coastguard Worker 
OptionalDataLayout(const Operation & operation,uint32_t inputIndex,const Model & model,ConversionData & data)853*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout OptionalDataLayout(const Operation& operation,
854*89c4ff92SAndroid Build Coastguard Worker                                      uint32_t inputIndex,
855*89c4ff92SAndroid Build Coastguard Worker                                      const Model& model,
856*89c4ff92SAndroid Build Coastguard Worker                                      ConversionData& data)
857*89c4ff92SAndroid Build Coastguard Worker {
858*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
859*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
860*89c4ff92SAndroid Build Coastguard Worker     {
861*89c4ff92SAndroid Build Coastguard Worker         return armnn::DataLayout::NHWC;
862*89c4ff92SAndroid Build Coastguard Worker     }
863*89c4ff92SAndroid Build Coastguard Worker 
864*89c4ff92SAndroid Build Coastguard Worker     if (!IsBool(*operand))
865*89c4ff92SAndroid Build Coastguard Worker     {
866*89c4ff92SAndroid Build Coastguard Worker         return armnn::DataLayout::NHWC;
867*89c4ff92SAndroid Build Coastguard Worker     }
868*89c4ff92SAndroid Build Coastguard Worker 
869*89c4ff92SAndroid Build Coastguard Worker     const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
870*89c4ff92SAndroid Build Coastguard Worker     if (!valueAddress)
871*89c4ff92SAndroid Build Coastguard Worker     {
872*89c4ff92SAndroid Build Coastguard Worker         return armnn::DataLayout::NHWC;
873*89c4ff92SAndroid Build Coastguard Worker     }
874*89c4ff92SAndroid Build Coastguard Worker 
875*89c4ff92SAndroid Build Coastguard Worker     if (*(static_cast<const bool*>(valueAddress)))
876*89c4ff92SAndroid Build Coastguard Worker     {
877*89c4ff92SAndroid Build Coastguard Worker         return armnn::DataLayout::NCHW;
878*89c4ff92SAndroid Build Coastguard Worker     }
879*89c4ff92SAndroid Build Coastguard Worker     else
880*89c4ff92SAndroid Build Coastguard Worker     {
881*89c4ff92SAndroid Build Coastguard Worker         return armnn::DataLayout::NHWC;
882*89c4ff92SAndroid Build Coastguard Worker     }
883*89c4ff92SAndroid Build Coastguard Worker }
884*89c4ff92SAndroid Build Coastguard Worker 
ProcessActivation(const armnn::TensorInfo & tensorInfo,ActivationFn activation,armnn::IConnectableLayer * prevLayer,ConversionData & data)885*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
886*89c4ff92SAndroid Build Coastguard Worker                                             ActivationFn activation,
887*89c4ff92SAndroid Build Coastguard Worker                                             armnn::IConnectableLayer* prevLayer,
888*89c4ff92SAndroid Build Coastguard Worker                                             ConversionData& data)
889*89c4ff92SAndroid Build Coastguard Worker {
890*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(prevLayer->GetNumOutputSlots() == 1);
891*89c4ff92SAndroid Build Coastguard Worker 
892*89c4ff92SAndroid Build Coastguard Worker     prevLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
893*89c4ff92SAndroid Build Coastguard Worker 
894*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* activationLayer = prevLayer;
895*89c4ff92SAndroid Build Coastguard Worker 
896*89c4ff92SAndroid Build Coastguard Worker     if (activation != ActivationFn::kActivationNone)
897*89c4ff92SAndroid Build Coastguard Worker     {
898*89c4ff92SAndroid Build Coastguard Worker         armnn::ActivationDescriptor activationDesc;
899*89c4ff92SAndroid Build Coastguard Worker         switch (activation)
900*89c4ff92SAndroid Build Coastguard Worker         {
901*89c4ff92SAndroid Build Coastguard Worker             case ActivationFn::kActivationRelu:
902*89c4ff92SAndroid Build Coastguard Worker             {
903*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_Function = armnn::ActivationFunction::ReLu;
904*89c4ff92SAndroid Build Coastguard Worker                 break;
905*89c4ff92SAndroid Build Coastguard Worker             }
906*89c4ff92SAndroid Build Coastguard Worker             case ActivationFn::kActivationRelu1:
907*89c4ff92SAndroid Build Coastguard Worker             {
908*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
909*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_A = 1.0f;
910*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_B = -1.0f;
911*89c4ff92SAndroid Build Coastguard Worker                 break;
912*89c4ff92SAndroid Build Coastguard Worker             }
913*89c4ff92SAndroid Build Coastguard Worker             case ActivationFn::kActivationRelu6:
914*89c4ff92SAndroid Build Coastguard Worker             {
915*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
916*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_A = 6.0f;
917*89c4ff92SAndroid Build Coastguard Worker                 break;
918*89c4ff92SAndroid Build Coastguard Worker             }
919*89c4ff92SAndroid Build Coastguard Worker             case ActivationFn::kActivationSigmoid:
920*89c4ff92SAndroid Build Coastguard Worker             {
921*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
922*89c4ff92SAndroid Build Coastguard Worker                 break;
923*89c4ff92SAndroid Build Coastguard Worker             }
924*89c4ff92SAndroid Build Coastguard Worker             case ActivationFn::kActivationTanh:
925*89c4ff92SAndroid Build Coastguard Worker             {
926*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_Function = armnn::ActivationFunction::TanH;
927*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_A = 1.0f;
928*89c4ff92SAndroid Build Coastguard Worker                 activationDesc.m_B = 1.0f;
929*89c4ff92SAndroid Build Coastguard Worker                 break;
930*89c4ff92SAndroid Build Coastguard Worker             }
931*89c4ff92SAndroid Build Coastguard Worker             default:
932*89c4ff92SAndroid Build Coastguard Worker             {
933*89c4ff92SAndroid Build Coastguard Worker                 Fail("%s: Invalid activation enum value %i", __func__, activation);
934*89c4ff92SAndroid Build Coastguard Worker                 return nullptr;
935*89c4ff92SAndroid Build Coastguard Worker             }
936*89c4ff92SAndroid Build Coastguard Worker         }
937*89c4ff92SAndroid Build Coastguard Worker 
938*89c4ff92SAndroid Build Coastguard Worker         bool isSupported = false;
939*89c4ff92SAndroid Build Coastguard Worker         armnn::BackendId setBackend;
940*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC(__func__,
941*89c4ff92SAndroid Build Coastguard Worker                                    IsActivationSupported,
942*89c4ff92SAndroid Build Coastguard Worker                                    data.m_Backends,
943*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
944*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
945*89c4ff92SAndroid Build Coastguard Worker                                    prevLayer->GetOutputSlot(0).GetTensorInfo(),
946*89c4ff92SAndroid Build Coastguard Worker                                    tensorInfo,
947*89c4ff92SAndroid Build Coastguard Worker                                    activationDesc);
948*89c4ff92SAndroid Build Coastguard Worker         if (!isSupported)
949*89c4ff92SAndroid Build Coastguard Worker         {
950*89c4ff92SAndroid Build Coastguard Worker             return nullptr;
951*89c4ff92SAndroid Build Coastguard Worker         }
952*89c4ff92SAndroid Build Coastguard Worker 
953*89c4ff92SAndroid Build Coastguard Worker         activationLayer = data.m_Network->AddActivationLayer(activationDesc);
954*89c4ff92SAndroid Build Coastguard Worker         activationLayer->SetBackendId(setBackend);
955*89c4ff92SAndroid Build Coastguard Worker 
956*89c4ff92SAndroid Build Coastguard Worker         prevLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
957*89c4ff92SAndroid Build Coastguard Worker         activationLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
958*89c4ff92SAndroid Build Coastguard Worker     }
959*89c4ff92SAndroid Build Coastguard Worker 
960*89c4ff92SAndroid Build Coastguard Worker     return activationLayer;
961*89c4ff92SAndroid Build Coastguard Worker }
962*89c4ff92SAndroid Build Coastguard Worker 
SetupAndTrackLayerOutputSlot(const Operation & operation,uint32_t operationOutputIndex,armnn::IConnectableLayer & layer,uint32_t layerOutputIndex,const Model & model,ConversionData & data,const armnn::TensorInfo * overrideOutputInfo,const std::function<void (const armnn::TensorInfo &,bool &)> & validateFunc,const ActivationFn & activationFunction,bool inferOutputShapes)963*89c4ff92SAndroid Build Coastguard Worker bool SetupAndTrackLayerOutputSlot(const Operation& operation,
964*89c4ff92SAndroid Build Coastguard Worker                                   uint32_t operationOutputIndex,
965*89c4ff92SAndroid Build Coastguard Worker                                   armnn::IConnectableLayer& layer,
966*89c4ff92SAndroid Build Coastguard Worker                                   uint32_t layerOutputIndex,
967*89c4ff92SAndroid Build Coastguard Worker                                   const Model& model,
968*89c4ff92SAndroid Build Coastguard Worker                                   ConversionData& data,
969*89c4ff92SAndroid Build Coastguard Worker                                   const armnn::TensorInfo* overrideOutputInfo,
970*89c4ff92SAndroid Build Coastguard Worker                                   const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc,
971*89c4ff92SAndroid Build Coastguard Worker                                   const ActivationFn& activationFunction,
972*89c4ff92SAndroid Build Coastguard Worker                                   bool inferOutputShapes)
973*89c4ff92SAndroid Build Coastguard Worker {
974*89c4ff92SAndroid Build Coastguard Worker     const Operand* outputOperand = GetOutputOperand(operation, operationOutputIndex, model);
975*89c4ff92SAndroid Build Coastguard Worker     if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots()))
976*89c4ff92SAndroid Build Coastguard Worker     {
977*89c4ff92SAndroid Build Coastguard Worker         return false;
978*89c4ff92SAndroid Build Coastguard Worker     }
979*89c4ff92SAndroid Build Coastguard Worker 
980*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
981*89c4ff92SAndroid Build Coastguard Worker     if (overrideOutputInfo == nullptr)
982*89c4ff92SAndroid Build Coastguard Worker     {
983*89c4ff92SAndroid Build Coastguard Worker         outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
984*89c4ff92SAndroid Build Coastguard Worker     }
985*89c4ff92SAndroid Build Coastguard Worker     else
986*89c4ff92SAndroid Build Coastguard Worker     {
987*89c4ff92SAndroid Build Coastguard Worker         outputSlot.SetTensorInfo(*overrideOutputInfo);
988*89c4ff92SAndroid Build Coastguard Worker     }
989*89c4ff92SAndroid Build Coastguard Worker 
990*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
991*89c4ff92SAndroid Build Coastguard Worker     if (validateFunc && (IsDynamicTensor(outputSlot.GetTensorInfo()) || inferOutputShapes))
992*89c4ff92SAndroid Build Coastguard Worker     {
993*89c4ff92SAndroid Build Coastguard Worker         // Type one dynamic tensors require the previous layer's output shape for inference
994*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
995*89c4ff92SAndroid Build Coastguard Worker         {
996*89c4ff92SAndroid Build Coastguard Worker             if(!layer.GetInputSlot(inputSlotIndex).GetConnection())
997*89c4ff92SAndroid Build Coastguard Worker             {
998*89c4ff92SAndroid Build Coastguard Worker                 return false;
999*89c4ff92SAndroid Build Coastguard Worker             }
1000*89c4ff92SAndroid Build Coastguard Worker         }
1001*89c4ff92SAndroid Build Coastguard Worker         // IsTensorInfoSet will infer the dynamic output shape
1002*89c4ff92SAndroid Build Coastguard Worker         outputSlot.IsTensorInfoSet();
1003*89c4ff92SAndroid Build Coastguard Worker         // Once the shape is inferred we can validate it
1004*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputSlot.GetTensorInfo(), isSupported);
1005*89c4ff92SAndroid Build Coastguard Worker 
1006*89c4ff92SAndroid Build Coastguard Worker         if(!isSupported)
1007*89c4ff92SAndroid Build Coastguard Worker         {
1008*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
1009*89c4ff92SAndroid Build Coastguard Worker             {
1010*89c4ff92SAndroid Build Coastguard Worker                 layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
1011*89c4ff92SAndroid Build Coastguard Worker             }
1012*89c4ff92SAndroid Build Coastguard Worker             return false;
1013*89c4ff92SAndroid Build Coastguard Worker         }
1014*89c4ff92SAndroid Build Coastguard Worker     }
1015*89c4ff92SAndroid Build Coastguard Worker 
1016*89c4ff92SAndroid Build Coastguard Worker     const uint32_t operandIndex = operation.outputs[operationOutputIndex];
1017*89c4ff92SAndroid Build Coastguard Worker 
1018*89c4ff92SAndroid Build Coastguard Worker     if (activationFunction != ActivationFn::kActivationNone)
1019*89c4ff92SAndroid Build Coastguard Worker     {
1020*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo& activationOutputInfo = outputSlot.GetTensorInfo();
1021*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const endLayer = ProcessActivation(activationOutputInfo, activationFunction,
1022*89c4ff92SAndroid Build Coastguard Worker                                                                      &layer, data);
1023*89c4ff92SAndroid Build Coastguard Worker 
1024*89c4ff92SAndroid Build Coastguard Worker         if (!endLayer)
1025*89c4ff92SAndroid Build Coastguard Worker         {
1026*89c4ff92SAndroid Build Coastguard Worker             return Fail("%s: ProcessActivation failed", __func__);
1027*89c4ff92SAndroid Build Coastguard Worker         }
1028*89c4ff92SAndroid Build Coastguard Worker 
1029*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& activationOutputSlot = endLayer->GetOutputSlot(layerOutputIndex);
1030*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForOperand[operandIndex] = &activationOutputSlot;
1031*89c4ff92SAndroid Build Coastguard Worker     }
1032*89c4ff92SAndroid Build Coastguard Worker     else
1033*89c4ff92SAndroid Build Coastguard Worker     {
1034*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
1035*89c4ff92SAndroid Build Coastguard Worker     }
1036*89c4ff92SAndroid Build Coastguard Worker 
1037*89c4ff92SAndroid Build Coastguard Worker     return true;
1038*89c4ff92SAndroid Build Coastguard Worker }
1039*89c4ff92SAndroid Build Coastguard Worker 
IsConnectedToDequantize(armnn::IOutputSlot * ioutputSlot)1040*89c4ff92SAndroid Build Coastguard Worker bool IsConnectedToDequantize(armnn::IOutputSlot* ioutputSlot)
1041*89c4ff92SAndroid Build Coastguard Worker {
1042*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize()";
1043*89c4ff92SAndroid Build Coastguard Worker     if (!ioutputSlot)
1044*89c4ff92SAndroid Build Coastguard Worker     {
1045*89c4ff92SAndroid Build Coastguard Worker         return false;
1046*89c4ff92SAndroid Build Coastguard Worker     }
1047*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() ioutputSlot is valid.";
1048*89c4ff92SAndroid Build Coastguard Worker     // Find the connections and layers..
1049*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer& owningLayer = ioutputSlot->GetOwningIConnectableLayer();
1050*89c4ff92SAndroid Build Coastguard Worker     if (owningLayer.GetType() == armnn::LayerType::Dequantize)
1051*89c4ff92SAndroid Build Coastguard Worker     {
1052*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() connected to Dequantize Layer.";
1053*89c4ff92SAndroid Build Coastguard Worker         armnn::IInputSlot& inputSlot = owningLayer.GetInputSlot(0);
1054*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot* connection = inputSlot.GetConnection();
1055*89c4ff92SAndroid Build Coastguard Worker         if (connection)
1056*89c4ff92SAndroid Build Coastguard Worker         {
1057*89c4ff92SAndroid Build Coastguard Worker             VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() Dequantize Layer has a connection.";
1058*89c4ff92SAndroid Build Coastguard Worker             armnn::IConnectableLayer& connectedLayer =
1059*89c4ff92SAndroid Build Coastguard Worker                     connection->GetOwningIConnectableLayer();
1060*89c4ff92SAndroid Build Coastguard Worker             if (connectedLayer.GetType() == armnn::LayerType::Constant)
1061*89c4ff92SAndroid Build Coastguard Worker             {
1062*89c4ff92SAndroid Build Coastguard Worker                 VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() Dequantize Layer connected to Constant";
1063*89c4ff92SAndroid Build Coastguard Worker                 return true;
1064*89c4ff92SAndroid Build Coastguard Worker             }
1065*89c4ff92SAndroid Build Coastguard Worker         }
1066*89c4ff92SAndroid Build Coastguard Worker     }
1067*89c4ff92SAndroid Build Coastguard Worker     return false;
1068*89c4ff92SAndroid Build Coastguard Worker }
1069*89c4ff92SAndroid Build Coastguard Worker 
1070*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
1071