//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ConversionUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
///
/// Helper classes
///
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
LayerInputHandle()16*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle::LayerInputHandle()
17*89c4ff92SAndroid Build Coastguard Worker : m_OutputSlot(nullptr)
18*89c4ff92SAndroid Build Coastguard Worker , m_Valid(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker
LayerInputHandle(bool valid,armnn::IOutputSlot * outputSlot,armnn::TensorInfo tensorInfo)21*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle::LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo)
22*89c4ff92SAndroid Build Coastguard Worker : m_OutputSlot(outputSlot)
23*89c4ff92SAndroid Build Coastguard Worker , m_Valid(valid)
24*89c4ff92SAndroid Build Coastguard Worker , m_TensorInfo(tensorInfo)
25*89c4ff92SAndroid Build Coastguard Worker {}
26*89c4ff92SAndroid Build Coastguard Worker
IsValid() const27*89c4ff92SAndroid Build Coastguard Worker bool LayerInputHandle::IsValid() const
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker return m_Valid;
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker
Connect(armnn::IInputSlot & inputSlot)32*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::Connect(armnn::IInputSlot& inputSlot)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(IsValid());
35*89c4ff92SAndroid Build Coastguard Worker if (m_OutputSlot)
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker m_OutputSlot->Connect(inputSlot);
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
Disconnect(armnn::IInputSlot & inputSlot)41*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::Disconnect(armnn::IInputSlot& inputSlot)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(IsValid());
44*89c4ff92SAndroid Build Coastguard Worker if (m_OutputSlot)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker m_OutputSlot->Disconnect(inputSlot);
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker
GetTensorInfo() const50*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& LayerInputHandle::GetTensorInfo() const
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker return m_TensorInfo;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker
SanitizeQuantizationScale(LayerInputHandle & weight,LayerInputHandle & input)55*89c4ff92SAndroid Build Coastguard Worker void LayerInputHandle::SanitizeQuantizationScale(LayerInputHandle& weight, LayerInputHandle& input)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker if (m_OutputSlot)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightInfo = weight.GetTensorInfo();
60*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo = input.GetTensorInfo();
61*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasInfo = GetTensorInfo();
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker SanitizeBiasQuantizationScale(biasInfo, weightInfo, inputInfo);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker m_TensorInfo = biasInfo;
66*89c4ff92SAndroid Build Coastguard Worker m_OutputSlot->SetTensorInfo(biasInfo);
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
GetOutputSlot() const70*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* LayerInputHandle::GetOutputSlot() const
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker return m_OutputSlot;
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker
ConstTensorPin(bool optional)75*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin::ConstTensorPin(bool optional)
76*89c4ff92SAndroid Build Coastguard Worker : m_Optional(optional)
77*89c4ff92SAndroid Build Coastguard Worker {}
78*89c4ff92SAndroid Build Coastguard Worker
ConstTensorPin(armnn::TensorInfo & tensorInfo,const void * valueStart,uint32_t numBytes,const armnn::PermutationVector & mappings)79*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo,
80*89c4ff92SAndroid Build Coastguard Worker const void* valueStart,
81*89c4ff92SAndroid Build Coastguard Worker uint32_t numBytes,
82*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mappings)
83*89c4ff92SAndroid Build Coastguard Worker : m_Optional(false)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(numBytes);
86*89c4ff92SAndroid Build Coastguard Worker if (tensorInfo.GetNumBytes() != numBytes)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "The size of ConstTensor does not match its TensorInfo.";
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker const bool needsSwizzling = (mappings.GetSize() > 0);
92*89c4ff92SAndroid Build Coastguard Worker if (needsSwizzling)
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
95*89c4ff92SAndroid Build Coastguard Worker SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data());
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker else
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker m_ConstTensor = armnn::ConstTensor(tensorInfo, valueStart);
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker
IsValid() const105*89c4ff92SAndroid Build Coastguard Worker bool ConstTensorPin::IsValid() const
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker return m_ConstTensor.GetMemoryArea() != nullptr;
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker
IsOptional() const110*89c4ff92SAndroid Build Coastguard Worker bool ConstTensorPin::IsOptional() const
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker return m_Optional;
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker
GetConstTensor() const115*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor& ConstTensorPin::GetConstTensor() const
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker return m_ConstTensor;
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker
GetConstTensorPtr() const120*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor* ConstTensorPin::GetConstTensorPtr() const
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker if (IsValid() && m_ConstTensor.GetNumElements() > 0)
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker return &m_ConstTensor;
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker // tensor is either invalid, or has no elements (indicating an optional tensor that was not provided)
127*89c4ff92SAndroid Build Coastguard Worker return nullptr;
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker
///
/// Utility functions
///
133*89c4ff92SAndroid Build Coastguard Worker
IsWeightsValid(const Operation & operation,uint32_t inputIndex,const Model & model)134*89c4ff92SAndroid Build Coastguard Worker bool IsWeightsValid(const Operation& operation,
135*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
136*89c4ff92SAndroid Build Coastguard Worker const Model& model)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
139*89c4ff92SAndroid Build Coastguard Worker if (!operand)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker Fail("%s: failed to get input operand %i", __func__, inputIndex);
142*89c4ff92SAndroid Build Coastguard Worker return false;
143*89c4ff92SAndroid Build Coastguard Worker }
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker if (operand->lifetime != OperandLifeTime::CONSTANT_COPY
146*89c4ff92SAndroid Build Coastguard Worker && operand->lifetime != OperandLifeTime::CONSTANT_REFERENCE
147*89c4ff92SAndroid Build Coastguard Worker && operand->lifetime != OperandLifeTime::NO_VALUE)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker return false;
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker return true;
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker
ConvertOperandToConstTensorPin(const Operand & operand,const Model & model,const ConversionData & data,const armnn::PermutationVector & dimensionMappings,const armnn::TensorShape * overrideTensorShape,bool optional,const armnn::DataType * overrideDataType)154*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
155*89c4ff92SAndroid Build Coastguard Worker const Model& model,
156*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
157*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& dimensionMappings,
158*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape* overrideTensorShape,
159*89c4ff92SAndroid Build Coastguard Worker bool optional,
160*89c4ff92SAndroid Build Coastguard Worker const armnn::DataType* overrideDataType)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker if (!IsOperandTypeSupportedForTensors(operand.type))
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unsupported operand type for tensor" << operand.type;
165*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin();
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker if (!optional && !IsOperandConstant(operand))
169*89c4ff92SAndroid Build Coastguard Worker {
170*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": lifetime for input tensor: r" << operand.lifetime;
171*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin();
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker const void* const valueStart = GetOperandValueReadOnlyAddress(operand, model, data, optional);
175*89c4ff92SAndroid Build Coastguard Worker if (!valueStart)
176*89c4ff92SAndroid Build Coastguard Worker {
177*89c4ff92SAndroid Build Coastguard Worker if (optional)
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker // optional tensor with no values is not really an error; return it as invalid, but marked as optional
180*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin(true);
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker // mandatory tensor with no values
183*89c4ff92SAndroid Build Coastguard Worker Fail("%s: failed to get operand address", __func__);
184*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin();
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker if (overrideTensorShape)
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetShape(*overrideTensorShape);
192*89c4ff92SAndroid Build Coastguard Worker }
193*89c4ff92SAndroid Build Coastguard Worker
194*89c4ff92SAndroid Build Coastguard Worker if (overrideDataType)
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetDataType(*overrideDataType);
197*89c4ff92SAndroid Build Coastguard Worker }
198*89c4ff92SAndroid Build Coastguard Worker
199*89c4ff92SAndroid Build Coastguard Worker // Make sure isConstant flag is set.
200*89c4ff92SAndroid Build Coastguard Worker tensorInfo.SetConstant();
201*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin(tensorInfo, valueStart, operand.location.length, dimensionMappings);
202*89c4ff92SAndroid Build Coastguard Worker }
203*89c4ff92SAndroid Build Coastguard Worker
ConvertToLayerInputHandle(const Operation & operation,uint32_t inputIndex,const Model & model,ConversionData & data,const armnn::PermutationVector & dimensionMappings,const LayerInputHandle * inputHandle)204*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
205*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
206*89c4ff92SAndroid Build Coastguard Worker const Model& model,
207*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
208*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& dimensionMappings,
209*89c4ff92SAndroid Build Coastguard Worker const LayerInputHandle* inputHandle)
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
213*89c4ff92SAndroid Build Coastguard Worker if (!operand)
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker Fail("%s: failed to get input operand %i", __func__, inputIndex);
216*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker
219*89c4ff92SAndroid Build Coastguard Worker if (!IsOperandTypeSupportedForTensors(operand->type))
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unsupported operand type for tensor: " << operand->type;
222*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker try
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo operandTensorInfo = GetTensorInfoForOperand(*operand);
228*89c4ff92SAndroid Build Coastguard Worker
229*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(operandTensorInfo))
230*89c4ff92SAndroid Build Coastguard Worker {
231*89c4ff92SAndroid Build Coastguard Worker data.m_DynamicInputsEncountered = true;
232*89c4ff92SAndroid Build Coastguard Worker
233*89c4ff92SAndroid Build Coastguard Worker const uint32_t operandIndex = operation.inputs[inputIndex];
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker // Check if the dynamic input tensors have been inferred by one of the previous layers
236*89c4ff92SAndroid Build Coastguard Worker // If not we can't support them
237*89c4ff92SAndroid Build Coastguard Worker if (data.m_OutputSlotForOperand.size() >= operandIndex && data.m_OutputSlotForOperand[operandIndex])
238*89c4ff92SAndroid Build Coastguard Worker {
239*89c4ff92SAndroid Build Coastguard Worker operandTensorInfo = data.m_OutputSlotForOperand[operandIndex]->GetTensorInfo();
240*89c4ff92SAndroid Build Coastguard Worker }
241*89c4ff92SAndroid Build Coastguard Worker else
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker Fail("%s: Type 2 dynamic input tensors are not supported", __func__);
244*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
245*89c4ff92SAndroid Build Coastguard Worker }
246*89c4ff92SAndroid Build Coastguard Worker }
247*89c4ff92SAndroid Build Coastguard Worker
248*89c4ff92SAndroid Build Coastguard Worker switch (operand->lifetime)
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::SUBGRAPH_INPUT:
251*89c4ff92SAndroid Build Coastguard Worker {
252*89c4ff92SAndroid Build Coastguard Worker // NOTE: We must check whether we can support the input tensor on at least one
253*89c4ff92SAndroid Build Coastguard Worker // of the provided backends; otherwise we cannot convert the operation
254*89c4ff92SAndroid Build Coastguard Worker bool isInputSupported = false;
255*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
256*89c4ff92SAndroid Build Coastguard Worker IsInputSupported,
257*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
258*89c4ff92SAndroid Build Coastguard Worker isInputSupported,
259*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId(),
260*89c4ff92SAndroid Build Coastguard Worker operandTensorInfo);
261*89c4ff92SAndroid Build Coastguard Worker
262*89c4ff92SAndroid Build Coastguard Worker if (!isInputSupported)
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker Fail("%s: unsupported input tensor", __func__);
265*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
266*89c4ff92SAndroid Build Coastguard Worker }
267*89c4ff92SAndroid Build Coastguard Worker
268*89c4ff92SAndroid Build Coastguard Worker [[clang::fallthrough]]; // intentional fallthrough
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::TEMPORARY_VARIABLE: // intentional fallthrough
271*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::SUBGRAPH_OUTPUT:
272*89c4ff92SAndroid Build Coastguard Worker {
273*89c4ff92SAndroid Build Coastguard Worker // The tensor is either an operand internal to the model, or a model input.
274*89c4ff92SAndroid Build Coastguard Worker // It can be associated with an ArmNN output slot for an existing layer.
275*89c4ff92SAndroid Build Coastguard Worker
276*89c4ff92SAndroid Build Coastguard Worker // m_OutputSlotForOperand[...] can be nullptr if the previous layer could not be converted
277*89c4ff92SAndroid Build Coastguard Worker const uint32_t operandIndex = operation.inputs[inputIndex];
278*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle(true, data.m_OutputSlotForOperand[operandIndex], operandTensorInfo);
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::CONSTANT_COPY: // intentional fallthrough
281*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::POINTER:
282*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::CONSTANT_REFERENCE:
283*89c4ff92SAndroid Build Coastguard Worker {
284*89c4ff92SAndroid Build Coastguard Worker auto constantTensorDataType = operandTensorInfo.GetDataType();
285*89c4ff92SAndroid Build Coastguard Worker // The tensor has an already known constant value, and can be converted into an ArmNN Constant layer.
286*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand,
287*89c4ff92SAndroid Build Coastguard Worker model,
288*89c4ff92SAndroid Build Coastguard Worker data,
289*89c4ff92SAndroid Build Coastguard Worker dimensionMappings,
290*89c4ff92SAndroid Build Coastguard Worker nullptr,
291*89c4ff92SAndroid Build Coastguard Worker false,
292*89c4ff92SAndroid Build Coastguard Worker &constantTensorDataType);
293*89c4ff92SAndroid Build Coastguard Worker if (tensorPin.IsValid())
294*89c4ff92SAndroid Build Coastguard Worker {
295*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
296*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
297*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
298*89c4ff92SAndroid Build Coastguard Worker IsConstantSupported,
299*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
300*89c4ff92SAndroid Build Coastguard Worker isSupported,
301*89c4ff92SAndroid Build Coastguard Worker setBackend,
302*89c4ff92SAndroid Build Coastguard Worker tensorPin.GetConstTensor().GetInfo());
303*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
304*89c4ff92SAndroid Build Coastguard Worker {
305*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
306*89c4ff92SAndroid Build Coastguard Worker }
307*89c4ff92SAndroid Build Coastguard Worker
308*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* constantLayer =
309*89c4ff92SAndroid Build Coastguard Worker data.m_Network->AddConstantLayer(tensorPin.GetConstTensor());
310*89c4ff92SAndroid Build Coastguard Worker constantLayer->SetBackendId(setBackend);
311*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = constantLayer->GetOutputSlot(0);
312*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo constantTensorInfo = tensorPin.GetConstTensor().GetInfo();
313*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(constantTensorInfo);
314*89c4ff92SAndroid Build Coastguard Worker
315*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle(true, &outputSlot, constantTensorInfo);
316*89c4ff92SAndroid Build Coastguard Worker }
317*89c4ff92SAndroid Build Coastguard Worker else
318*89c4ff92SAndroid Build Coastguard Worker {
319*89c4ff92SAndroid Build Coastguard Worker Fail("%s: invalid operand tensor", __func__);
320*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
321*89c4ff92SAndroid Build Coastguard Worker }
322*89c4ff92SAndroid Build Coastguard Worker break;
323*89c4ff92SAndroid Build Coastguard Worker }
324*89c4ff92SAndroid Build Coastguard Worker default:
325*89c4ff92SAndroid Build Coastguard Worker {
326*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unsupported lifetime for input tensor: " << operand->lifetime;
327*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
328*89c4ff92SAndroid Build Coastguard Worker }
329*89c4ff92SAndroid Build Coastguard Worker }
330*89c4ff92SAndroid Build Coastguard Worker }
331*89c4ff92SAndroid Build Coastguard Worker catch (UnsupportedOperand<OperandType>& e)
332*89c4ff92SAndroid Build Coastguard Worker {
333*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": Operand type: " << e.m_type << " not supported in ArmnnDriver";
334*89c4ff92SAndroid Build Coastguard Worker return LayerInputHandle();
335*89c4ff92SAndroid Build Coastguard Worker }
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker
ConvertPaddings(const Operation & operation,const Model & model,ConversionData & data,unsigned int rank,armnn::PadDescriptor & padDescriptor)338*89c4ff92SAndroid Build Coastguard Worker bool ConvertPaddings(const Operation& operation,
339*89c4ff92SAndroid Build Coastguard Worker const Model& model,
340*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
341*89c4ff92SAndroid Build Coastguard Worker unsigned int rank,
342*89c4ff92SAndroid Build Coastguard Worker armnn::PadDescriptor& padDescriptor)
343*89c4ff92SAndroid Build Coastguard Worker {
344*89c4ff92SAndroid Build Coastguard Worker const Operand* paddingsOperand = GetInputOperand(operation, 1, model);
345*89c4ff92SAndroid Build Coastguard Worker if (!paddingsOperand)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Could not read paddings operand", __func__);
348*89c4ff92SAndroid Build Coastguard Worker }
349*89c4ff92SAndroid Build Coastguard Worker
350*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape paddingsOperandShape = GetTensorShapeForOperand(*paddingsOperand);
351*89c4ff92SAndroid Build Coastguard Worker if (paddingsOperandShape.GetNumDimensions() != 2 || paddingsOperandShape.GetNumElements() != rank * 2)
352*89c4ff92SAndroid Build Coastguard Worker {
353*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid paddings operand: expected shape [%d, 2]", __func__, rank);
354*89c4ff92SAndroid Build Coastguard Worker }
355*89c4ff92SAndroid Build Coastguard Worker
356*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> paddings;
357*89c4ff92SAndroid Build Coastguard Worker if (!GetTensorInt32Values(*paddingsOperand, paddings, model, data))
358*89c4ff92SAndroid Build Coastguard Worker {
359*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid or unsupported paddings operand", __func__);
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker
362*89c4ff92SAndroid Build Coastguard Worker // add padding for each dimension of input tensor.
363*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < paddings.size() - 1; i += 2)
364*89c4ff92SAndroid Build Coastguard Worker {
365*89c4ff92SAndroid Build Coastguard Worker int paddingBeforeInput = paddings[i];
366*89c4ff92SAndroid Build Coastguard Worker int paddingAfterInput = paddings[i + 1];
367*89c4ff92SAndroid Build Coastguard Worker
368*89c4ff92SAndroid Build Coastguard Worker if (paddingBeforeInput < 0 || paddingAfterInput < 0)
369*89c4ff92SAndroid Build Coastguard Worker {
370*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid paddings operand, invalid padding values.", __func__);
371*89c4ff92SAndroid Build Coastguard Worker }
372*89c4ff92SAndroid Build Coastguard Worker
373*89c4ff92SAndroid Build Coastguard Worker padDescriptor.m_PadList.emplace_back((unsigned int) paddingBeforeInput, (unsigned int) paddingAfterInput);
374*89c4ff92SAndroid Build Coastguard Worker }
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker return true;
377*89c4ff92SAndroid Build Coastguard Worker }
378*89c4ff92SAndroid Build Coastguard Worker
379*89c4ff92SAndroid Build Coastguard Worker
ConvertPooling2d(const Operation & operation,const char * operationName,armnn::PoolingAlgorithm poolType,const Model & model,ConversionData & data)380*89c4ff92SAndroid Build Coastguard Worker bool ConvertPooling2d(const Operation& operation,
381*89c4ff92SAndroid Build Coastguard Worker const char* operationName,
382*89c4ff92SAndroid Build Coastguard Worker armnn::PoolingAlgorithm poolType,
383*89c4ff92SAndroid Build Coastguard Worker const Model& model,
384*89c4ff92SAndroid Build Coastguard Worker ConversionData& data)
385*89c4ff92SAndroid Build Coastguard Worker {
386*89c4ff92SAndroid Build Coastguard Worker
387*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Converter::ConvertL2Pool2d()";
388*89c4ff92SAndroid Build Coastguard Worker
389*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
390*89c4ff92SAndroid Build Coastguard Worker if (!input.IsValid())
391*89c4ff92SAndroid Build Coastguard Worker {
392*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation Could not read input 0", operationName);
393*89c4ff92SAndroid Build Coastguard Worker }
394*89c4ff92SAndroid Build Coastguard Worker
395*89c4ff92SAndroid Build Coastguard Worker const Operand* output = GetOutputOperand(operation, 0, model);
396*89c4ff92SAndroid Build Coastguard Worker if (!output)
397*89c4ff92SAndroid Build Coastguard Worker {
398*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Could not read output 0", __func__);
399*89c4ff92SAndroid Build Coastguard Worker }
400*89c4ff92SAndroid Build Coastguard Worker
401*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
402*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
403*89c4ff92SAndroid Build Coastguard Worker
404*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling2dDescriptor desc;
405*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = poolType;
406*89c4ff92SAndroid Build Coastguard Worker desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
407*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = armnn::DataLayout::NHWC;
408*89c4ff92SAndroid Build Coastguard Worker
409*89c4ff92SAndroid Build Coastguard Worker ActivationFn activation;
410*89c4ff92SAndroid Build Coastguard Worker
411*89c4ff92SAndroid Build Coastguard Worker auto inputSize = operation.inputs.size();
412*89c4ff92SAndroid Build Coastguard Worker
413*89c4ff92SAndroid Build Coastguard Worker if (inputSize >= 10)
414*89c4ff92SAndroid Build Coastguard Worker {
415*89c4ff92SAndroid Build Coastguard Worker // one input, 9 parameters (padding l r t b, stridex, stridey, width, height, activation type)
416*89c4ff92SAndroid Build Coastguard Worker if (!GetInputScalar(operation, 1, OperandType::INT32, desc.m_PadLeft, model, data) ||
417*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 2, OperandType::INT32, desc.m_PadRight, model, data) ||
418*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 3, OperandType::INT32, desc.m_PadTop, model, data) ||
419*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 4, OperandType::INT32, desc.m_PadBottom, model, data) ||
420*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 5, OperandType::INT32, desc.m_StrideX, model, data) ||
421*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 6, OperandType::INT32, desc.m_StrideY, model, data) ||
422*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 7, OperandType::INT32, desc.m_PoolWidth, model, data) ||
423*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 8, OperandType::INT32, desc.m_PoolHeight, model, data) ||
424*89c4ff92SAndroid Build Coastguard Worker !GetInputActivationFunction(operation, 9, activation, model, data))
425*89c4ff92SAndroid Build Coastguard Worker {
426*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid inputs", operationName);
427*89c4ff92SAndroid Build Coastguard Worker }
428*89c4ff92SAndroid Build Coastguard Worker
429*89c4ff92SAndroid Build Coastguard Worker if (Is12OrLaterOperand(*output))
430*89c4ff92SAndroid Build Coastguard Worker {
431*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = OptionalDataLayout(operation, 10, model, data);
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker }
434*89c4ff92SAndroid Build Coastguard Worker else
435*89c4ff92SAndroid Build Coastguard Worker {
436*89c4ff92SAndroid Build Coastguard Worker // one input, 6 parameters (padding, stridex, stridey, width, height, activation type)
437*89c4ff92SAndroid Build Coastguard Worker ::android::nn::PaddingScheme scheme;
438*89c4ff92SAndroid Build Coastguard Worker if (!GetInputPaddingScheme(operation, 1, scheme, model, data) ||
439*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 2, OperandType::INT32, desc.m_StrideX, model, data) ||
440*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 3, OperandType::INT32, desc.m_StrideY, model, data) ||
441*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 4, OperandType::INT32, desc.m_PoolWidth, model, data) ||
442*89c4ff92SAndroid Build Coastguard Worker !GetInputScalar(operation, 5, OperandType::INT32, desc.m_PoolHeight, model, data) ||
443*89c4ff92SAndroid Build Coastguard Worker !GetInputActivationFunction(operation, 6, activation, model, data))
444*89c4ff92SAndroid Build Coastguard Worker {
445*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid inputs", operationName);
446*89c4ff92SAndroid Build Coastguard Worker }
447*89c4ff92SAndroid Build Coastguard Worker
448*89c4ff92SAndroid Build Coastguard Worker if (Is12OrLaterOperand(*output))
449*89c4ff92SAndroid Build Coastguard Worker {
450*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = OptionalDataLayout(operation, 7, model, data);
451*89c4ff92SAndroid Build Coastguard Worker }
452*89c4ff92SAndroid Build Coastguard Worker
453*89c4ff92SAndroid Build Coastguard Worker const armnnUtils::DataLayoutIndexed dataLayout(desc.m_DataLayout);
454*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputWidth = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
455*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputHeight = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
456*89c4ff92SAndroid Build Coastguard Worker
457*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, scheme);
458*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, scheme);
459*89c4ff92SAndroid Build Coastguard Worker }
460*89c4ff92SAndroid Build Coastguard Worker
461*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
462*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
463*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
464*89c4ff92SAndroid Build Coastguard Worker {
465*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
466*89c4ff92SAndroid Build Coastguard Worker IsPooling2dSupported,
467*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
468*89c4ff92SAndroid Build Coastguard Worker isSupported,
469*89c4ff92SAndroid Build Coastguard Worker setBackend,
470*89c4ff92SAndroid Build Coastguard Worker inputInfo,
471*89c4ff92SAndroid Build Coastguard Worker outputInfo,
472*89c4ff92SAndroid Build Coastguard Worker desc);
473*89c4ff92SAndroid Build Coastguard Worker
474*89c4ff92SAndroid Build Coastguard Worker };
475*89c4ff92SAndroid Build Coastguard Worker
476*89c4ff92SAndroid Build Coastguard Worker if(IsDynamicTensor(outputInfo))
477*89c4ff92SAndroid Build Coastguard Worker {
478*89c4ff92SAndroid Build Coastguard Worker isSupported = AreDynamicTensorsSupported();
479*89c4ff92SAndroid Build Coastguard Worker }
480*89c4ff92SAndroid Build Coastguard Worker else
481*89c4ff92SAndroid Build Coastguard Worker {
482*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
483*89c4ff92SAndroid Build Coastguard Worker }
484*89c4ff92SAndroid Build Coastguard Worker
485*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker return false;
488*89c4ff92SAndroid Build Coastguard Worker }
489*89c4ff92SAndroid Build Coastguard Worker
490*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* pooling2dLayer = data.m_Network->AddPooling2dLayer(desc);
491*89c4ff92SAndroid Build Coastguard Worker pooling2dLayer->SetBackendId(setBackend);
492*89c4ff92SAndroid Build Coastguard Worker if (!pooling2dLayer)
493*89c4ff92SAndroid Build Coastguard Worker {
494*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: AddPooling2dLayer failed", __func__);
495*89c4ff92SAndroid Build Coastguard Worker }
496*89c4ff92SAndroid Build Coastguard Worker
497*89c4ff92SAndroid Build Coastguard Worker input.Connect(pooling2dLayer->GetInputSlot(0));
498*89c4ff92SAndroid Build Coastguard Worker
499*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
500*89c4ff92SAndroid Build Coastguard Worker {
501*89c4ff92SAndroid Build Coastguard Worker return false;
502*89c4ff92SAndroid Build Coastguard Worker }
503*89c4ff92SAndroid Build Coastguard Worker
504*89c4ff92SAndroid Build Coastguard Worker return SetupAndTrackLayerOutputSlot(operation, 0, *pooling2dLayer, model,
505*89c4ff92SAndroid Build Coastguard Worker data, nullptr, validateFunc, activation);
506*89c4ff92SAndroid Build Coastguard Worker }
507*89c4ff92SAndroid Build Coastguard Worker
ConvertReduce(const Operation & operation,const Model & model,ConversionData & data,armnn::ReduceOperation reduceOperation)508*89c4ff92SAndroid Build Coastguard Worker bool ConvertReduce(const Operation& operation,
509*89c4ff92SAndroid Build Coastguard Worker const Model& model,
510*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
511*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceOperation reduceOperation)
512*89c4ff92SAndroid Build Coastguard Worker {
513*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceDescriptor descriptor;
514*89c4ff92SAndroid Build Coastguard Worker descriptor.m_ReduceOperation = reduceOperation;
515*89c4ff92SAndroid Build Coastguard Worker
516*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
517*89c4ff92SAndroid Build Coastguard Worker if (!input.IsValid())
518*89c4ff92SAndroid Build Coastguard Worker {
519*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid inputs", __func__);
520*89c4ff92SAndroid Build Coastguard Worker }
521*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
522*89c4ff92SAndroid Build Coastguard Worker
523*89c4ff92SAndroid Build Coastguard Worker const Operand* output = GetOutputOperand(operation, 0, model);
524*89c4ff92SAndroid Build Coastguard Worker if (!output)
525*89c4ff92SAndroid Build Coastguard Worker {
526*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Could not read output 0", __func__);
527*89c4ff92SAndroid Build Coastguard Worker }
528*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
529*89c4ff92SAndroid Build Coastguard Worker
530*89c4ff92SAndroid Build Coastguard Worker const Operand* axisOperand = GetInputOperand(operation, 1, model);
531*89c4ff92SAndroid Build Coastguard Worker if (!axisOperand)
532*89c4ff92SAndroid Build Coastguard Worker {
533*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Could not read input 1", __func__);
534*89c4ff92SAndroid Build Coastguard Worker }
535*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axis;
536*89c4ff92SAndroid Build Coastguard Worker if (!GetTensorInt32Values(*axisOperand, axis, model, data))
537*89c4ff92SAndroid Build Coastguard Worker {
538*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Input 1 has invalid values", __func__);
539*89c4ff92SAndroid Build Coastguard Worker }
540*89c4ff92SAndroid Build Coastguard Worker
541*89c4ff92SAndroid Build Coastguard Worker // Convert the axis to unsigned int and remove duplicates.
542*89c4ff92SAndroid Build Coastguard Worker unsigned int rank = inputInfo.GetNumDimensions();
543*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> uniqueAxis;
544*89c4ff92SAndroid Build Coastguard Worker std::transform(axis.begin(), axis.end(),
545*89c4ff92SAndroid Build Coastguard Worker std::inserter(uniqueAxis, uniqueAxis.begin()),
546*89c4ff92SAndroid Build Coastguard Worker [rank](int i) -> unsigned int { return (i + rank) % rank; });
547*89c4ff92SAndroid Build Coastguard Worker descriptor.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
548*89c4ff92SAndroid Build Coastguard Worker
549*89c4ff92SAndroid Build Coastguard Worker // Get the "keep dims" flag.
550*89c4ff92SAndroid Build Coastguard Worker if (!GetInputScalar(operation, 2, OperandType::BOOL, descriptor.m_KeepDims, model, data))
551*89c4ff92SAndroid Build Coastguard Worker {
552*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Could not read input 2", __func__);
553*89c4ff92SAndroid Build Coastguard Worker }
554*89c4ff92SAndroid Build Coastguard Worker
555*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
556*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
557*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
560*89c4ff92SAndroid Build Coastguard Worker IsReduceSupported,
561*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
562*89c4ff92SAndroid Build Coastguard Worker isSupported,
563*89c4ff92SAndroid Build Coastguard Worker setBackend,
564*89c4ff92SAndroid Build Coastguard Worker inputInfo,
565*89c4ff92SAndroid Build Coastguard Worker outputInfo,
566*89c4ff92SAndroid Build Coastguard Worker descriptor);
567*89c4ff92SAndroid Build Coastguard Worker };
568*89c4ff92SAndroid Build Coastguard Worker
569*89c4ff92SAndroid Build Coastguard Worker if(!IsDynamicTensor(outputInfo))
570*89c4ff92SAndroid Build Coastguard Worker {
571*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputInfo, isSupported);
572*89c4ff92SAndroid Build Coastguard Worker }
573*89c4ff92SAndroid Build Coastguard Worker else
574*89c4ff92SAndroid Build Coastguard Worker {
575*89c4ff92SAndroid Build Coastguard Worker isSupported = AreDynamicTensorsSupported();
576*89c4ff92SAndroid Build Coastguard Worker }
577*89c4ff92SAndroid Build Coastguard Worker
578*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
579*89c4ff92SAndroid Build Coastguard Worker {
580*89c4ff92SAndroid Build Coastguard Worker return false;
581*89c4ff92SAndroid Build Coastguard Worker }
582*89c4ff92SAndroid Build Coastguard Worker
583*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const layer = data.m_Network->AddReduceLayer(descriptor);
584*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
585*89c4ff92SAndroid Build Coastguard Worker assert(layer != nullptr);
586*89c4ff92SAndroid Build Coastguard Worker input.Connect(layer->GetInputSlot(0));
587*89c4ff92SAndroid Build Coastguard Worker
588*89c4ff92SAndroid Build Coastguard Worker return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc);
589*89c4ff92SAndroid Build Coastguard Worker }
590*89c4ff92SAndroid Build Coastguard Worker
591*89c4ff92SAndroid Build Coastguard Worker
ConvertToActivation(const Operation & operation,const char * operationName,const armnn::ActivationDescriptor & activationDesc,const Model & model,ConversionData & data)592*89c4ff92SAndroid Build Coastguard Worker bool ConvertToActivation(const Operation& operation,
593*89c4ff92SAndroid Build Coastguard Worker const char* operationName,
594*89c4ff92SAndroid Build Coastguard Worker const armnn::ActivationDescriptor& activationDesc,
595*89c4ff92SAndroid Build Coastguard Worker const Model& model,
596*89c4ff92SAndroid Build Coastguard Worker ConversionData& data)
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data);
599*89c4ff92SAndroid Build Coastguard Worker if (!input.IsValid())
600*89c4ff92SAndroid Build Coastguard Worker {
601*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Input 0 is invalid", operationName);
602*89c4ff92SAndroid Build Coastguard Worker }
603*89c4ff92SAndroid Build Coastguard Worker
604*89c4ff92SAndroid Build Coastguard Worker const Operand* outputOperand = GetOutputOperand(operation, 0, model);
605*89c4ff92SAndroid Build Coastguard Worker if (!outputOperand)
606*89c4ff92SAndroid Build Coastguard Worker {
607*89c4ff92SAndroid Build Coastguard Worker return false;
608*89c4ff92SAndroid Build Coastguard Worker }
609*89c4ff92SAndroid Build Coastguard Worker
610*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outInfo = GetTensorInfoForOperand(*outputOperand);
611*89c4ff92SAndroid Build Coastguard Worker
612*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
613*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
614*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
615*89c4ff92SAndroid Build Coastguard Worker {
616*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
617*89c4ff92SAndroid Build Coastguard Worker IsActivationSupported,
618*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
619*89c4ff92SAndroid Build Coastguard Worker isSupported,
620*89c4ff92SAndroid Build Coastguard Worker setBackend,
621*89c4ff92SAndroid Build Coastguard Worker input.GetTensorInfo(),
622*89c4ff92SAndroid Build Coastguard Worker outInfo,
623*89c4ff92SAndroid Build Coastguard Worker activationDesc);
624*89c4ff92SAndroid Build Coastguard Worker };
625*89c4ff92SAndroid Build Coastguard Worker
626*89c4ff92SAndroid Build Coastguard Worker if(IsDynamicTensor(outInfo))
627*89c4ff92SAndroid Build Coastguard Worker {
628*89c4ff92SAndroid Build Coastguard Worker isSupported = AreDynamicTensorsSupported();
629*89c4ff92SAndroid Build Coastguard Worker }
630*89c4ff92SAndroid Build Coastguard Worker else
631*89c4ff92SAndroid Build Coastguard Worker {
632*89c4ff92SAndroid Build Coastguard Worker validateFunc(outInfo, isSupported);
633*89c4ff92SAndroid Build Coastguard Worker }
634*89c4ff92SAndroid Build Coastguard Worker
635*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
636*89c4ff92SAndroid Build Coastguard Worker {
637*89c4ff92SAndroid Build Coastguard Worker return false;
638*89c4ff92SAndroid Build Coastguard Worker }
639*89c4ff92SAndroid Build Coastguard Worker
640*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = data.m_Network->AddActivationLayer(activationDesc);
641*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
642*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
643*89c4ff92SAndroid Build Coastguard Worker input.Connect(layer->GetInputSlot(0));
644*89c4ff92SAndroid Build Coastguard Worker
645*89c4ff92SAndroid Build Coastguard Worker return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc);
646*89c4ff92SAndroid Build Coastguard Worker }
647*89c4ff92SAndroid Build Coastguard Worker
DequantizeIfRequired(size_t operand_index,const Operation & operation,const Model & model,const ConversionData & data)648*89c4ff92SAndroid Build Coastguard Worker DequantizeResult DequantizeIfRequired(size_t operand_index,
649*89c4ff92SAndroid Build Coastguard Worker const Operation& operation,
650*89c4ff92SAndroid Build Coastguard Worker const Model& model,
651*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
652*89c4ff92SAndroid Build Coastguard Worker {
653*89c4ff92SAndroid Build Coastguard Worker const Operand* weightsOperand = GetInputOperand(operation, operand_index, model);
654*89c4ff92SAndroid Build Coastguard Worker if (!weightsOperand)
655*89c4ff92SAndroid Build Coastguard Worker {
656*89c4ff92SAndroid Build Coastguard Worker return { nullptr, 0, armnn::TensorInfo(), DequantizeStatus::INVALID_OPERAND };
657*89c4ff92SAndroid Build Coastguard Worker }
658*89c4ff92SAndroid Build Coastguard Worker
659*89c4ff92SAndroid Build Coastguard Worker if (IsOperandConstant(*weightsOperand))
660*89c4ff92SAndroid Build Coastguard Worker {
661*89c4ff92SAndroid Build Coastguard Worker // Weights are already constant
662*89c4ff92SAndroid Build Coastguard Worker return { nullptr, 0, armnn::TensorInfo(), DequantizeStatus::NOT_REQUIRED };
663*89c4ff92SAndroid Build Coastguard Worker }
664*89c4ff92SAndroid Build Coastguard Worker
665*89c4ff92SAndroid Build Coastguard Worker const size_t weightsInputIndex = operation.inputs[operand_index];
666*89c4ff92SAndroid Build Coastguard Worker
667*89c4ff92SAndroid Build Coastguard Worker // The weights are a non const tensor, this indicates they might be the output of a dequantize op.
668*89c4ff92SAndroid Build Coastguard Worker // Iterate over the nodes and find the previous operation which should be DEQUANTIZE
669*89c4ff92SAndroid Build Coastguard Worker for (uint32_t operationIdx = 0; operationIdx < getMainModel(model).operations.size(); ++operationIdx)
670*89c4ff92SAndroid Build Coastguard Worker {
671*89c4ff92SAndroid Build Coastguard Worker // Search for the DEQUANTIZE op which has the operand with index equal to operandIndex
672*89c4ff92SAndroid Build Coastguard Worker const auto& operationIt = getMainModel(model).operations[operationIdx];
673*89c4ff92SAndroid Build Coastguard Worker if (operationIt.type != OperationType::DEQUANTIZE)
674*89c4ff92SAndroid Build Coastguard Worker {
675*89c4ff92SAndroid Build Coastguard Worker continue;
676*89c4ff92SAndroid Build Coastguard Worker }
677*89c4ff92SAndroid Build Coastguard Worker
678*89c4ff92SAndroid Build Coastguard Worker size_t outOpIndex = weightsInputIndex + 1;
679*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; outOpIndex != weightsInputIndex && i < operationIt.outputs.size(); ++i)
680*89c4ff92SAndroid Build Coastguard Worker {
681*89c4ff92SAndroid Build Coastguard Worker outOpIndex = operationIt.outputs[i];
682*89c4ff92SAndroid Build Coastguard Worker }
683*89c4ff92SAndroid Build Coastguard Worker
684*89c4ff92SAndroid Build Coastguard Worker if (outOpIndex != weightsInputIndex)
685*89c4ff92SAndroid Build Coastguard Worker {
686*89c4ff92SAndroid Build Coastguard Worker continue;
687*89c4ff92SAndroid Build Coastguard Worker }
688*89c4ff92SAndroid Build Coastguard Worker
689*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operationIt, 0, model);
690*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(operand);
691*89c4ff92SAndroid Build Coastguard Worker
692*89c4ff92SAndroid Build Coastguard Worker if (!IsQSymm8(*operand))
693*89c4ff92SAndroid Build Coastguard Worker {
694*89c4ff92SAndroid Build Coastguard Worker // Only supporting dequantize from QSYMM8 to FLOAT
695*89c4ff92SAndroid Build Coastguard Worker break;
696*89c4ff92SAndroid Build Coastguard Worker }
697*89c4ff92SAndroid Build Coastguard Worker
698*89c4ff92SAndroid Build Coastguard Worker // Allocate a new buffer for the dequantized data and manually dequantize
699*89c4ff92SAndroid Build Coastguard Worker const void* startValue = GetOperandValueReadOnlyAddress(*operand, model, data);
700*89c4ff92SAndroid Build Coastguard Worker if (!startValue)
701*89c4ff92SAndroid Build Coastguard Worker {
702*89c4ff92SAndroid Build Coastguard Worker // Failed to get the operand address
703*89c4ff92SAndroid Build Coastguard Worker break;
704*89c4ff92SAndroid Build Coastguard Worker }
705*89c4ff92SAndroid Build Coastguard Worker
706*89c4ff92SAndroid Build Coastguard Worker const uint8_t* quantizedBuffer = reinterpret_cast<const uint8_t*>(startValue);
707*89c4ff92SAndroid Build Coastguard Worker size_t dequantizedBufferLength = operand->location.length;
708*89c4ff92SAndroid Build Coastguard Worker const float quantizationScale = operand->scale;
709*89c4ff92SAndroid Build Coastguard Worker
710*89c4ff92SAndroid Build Coastguard Worker auto dequantizedBuffer = std::make_unique<float[]>(dequantizedBufferLength + 1);
711*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < dequantizedBufferLength; ++i)
712*89c4ff92SAndroid Build Coastguard Worker {
713*89c4ff92SAndroid Build Coastguard Worker float* dstPtr = dequantizedBuffer.get();
714*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(dstPtr);
715*89c4ff92SAndroid Build Coastguard Worker *dstPtr++ = quantizedBuffer[i] * quantizationScale;
716*89c4ff92SAndroid Build Coastguard Worker }
717*89c4ff92SAndroid Build Coastguard Worker
718*89c4ff92SAndroid Build Coastguard Worker // Construct tensor info for dequantized ConstTensor
719*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo(operand->dimensions.size(),
720*89c4ff92SAndroid Build Coastguard Worker operand->dimensions.data(),
721*89c4ff92SAndroid Build Coastguard Worker armnn::DataType::Float32);
722*89c4ff92SAndroid Build Coastguard Worker
723*89c4ff92SAndroid Build Coastguard Worker return { std::move(dequantizedBuffer), dequantizedBufferLength * sizeof(float),
724*89c4ff92SAndroid Build Coastguard Worker std::move(tensorInfo),
725*89c4ff92SAndroid Build Coastguard Worker DequantizeStatus::SUCCESS };
726*89c4ff92SAndroid Build Coastguard Worker }
727*89c4ff92SAndroid Build Coastguard Worker
728*89c4ff92SAndroid Build Coastguard Worker return { nullptr, 0, armnn::TensorInfo() , DequantizeStatus::NOT_REQUIRED};
729*89c4ff92SAndroid Build Coastguard Worker }
730*89c4ff92SAndroid Build Coastguard Worker
DequantizeAndMakeConstTensorPin(const Operation & operation,const Model & model,const ConversionData & data,size_t operandIndex,bool optional)731*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
732*89c4ff92SAndroid Build Coastguard Worker const Model& model,
733*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
734*89c4ff92SAndroid Build Coastguard Worker size_t operandIndex,
735*89c4ff92SAndroid Build Coastguard Worker bool optional)
736*89c4ff92SAndroid Build Coastguard Worker {
737*89c4ff92SAndroid Build Coastguard Worker DequantizeResult dequantized = DequantizeIfRequired(operandIndex,operation, model, data);
738*89c4ff92SAndroid Build Coastguard Worker
739*89c4ff92SAndroid Build Coastguard Worker DequantizeStatus status = std::get<3>(dequantized);
740*89c4ff92SAndroid Build Coastguard Worker switch (status)
741*89c4ff92SAndroid Build Coastguard Worker {
742*89c4ff92SAndroid Build Coastguard Worker case DequantizeStatus::INVALID_OPERAND:
743*89c4ff92SAndroid Build Coastguard Worker {
744*89c4ff92SAndroid Build Coastguard Worker // return invalid const tensor pin
745*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin();
746*89c4ff92SAndroid Build Coastguard Worker }
747*89c4ff92SAndroid Build Coastguard Worker case DequantizeStatus::NOT_REQUIRED:
748*89c4ff92SAndroid Build Coastguard Worker {
749*89c4ff92SAndroid Build Coastguard Worker return ConvertOperationInputToConstTensorPin(
750*89c4ff92SAndroid Build Coastguard Worker operation, operandIndex, model, data, g_DontPermute, nullptr, optional);
751*89c4ff92SAndroid Build Coastguard Worker }
752*89c4ff92SAndroid Build Coastguard Worker case DequantizeStatus::SUCCESS:
753*89c4ff92SAndroid Build Coastguard Worker default:
754*89c4ff92SAndroid Build Coastguard Worker {
755*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin(
756*89c4ff92SAndroid Build Coastguard Worker std::get<2>(dequantized), std::get<0>(dequantized).get(), std::get<1>(dequantized), g_DontPermute);
757*89c4ff92SAndroid Build Coastguard Worker }
758*89c4ff92SAndroid Build Coastguard Worker }
759*89c4ff92SAndroid Build Coastguard Worker }
760*89c4ff92SAndroid Build Coastguard Worker
GetInputPaddingScheme(const Operation & operation,uint32_t inputIndex,PaddingScheme & outPaddingScheme,const Model & model,const ConversionData & data)761*89c4ff92SAndroid Build Coastguard Worker bool GetInputPaddingScheme(const Operation& operation,
762*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
763*89c4ff92SAndroid Build Coastguard Worker PaddingScheme& outPaddingScheme,
764*89c4ff92SAndroid Build Coastguard Worker const Model& model,
765*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
766*89c4ff92SAndroid Build Coastguard Worker {
767*89c4ff92SAndroid Build Coastguard Worker int32_t paddingSchemeAsInt;
768*89c4ff92SAndroid Build Coastguard Worker if (!GetInputInt32(operation, inputIndex, paddingSchemeAsInt, model, data))
769*89c4ff92SAndroid Build Coastguard Worker {
770*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: failed to get padding scheme input value", __func__);
771*89c4ff92SAndroid Build Coastguard Worker }
772*89c4ff92SAndroid Build Coastguard Worker
773*89c4ff92SAndroid Build Coastguard Worker outPaddingScheme = static_cast<::android::nn::PaddingScheme>(paddingSchemeAsInt);
774*89c4ff92SAndroid Build Coastguard Worker return true;
775*89c4ff92SAndroid Build Coastguard Worker }
776*89c4ff92SAndroid Build Coastguard Worker
GetOperandValueReadOnlyAddress(const Operand & operand,const Model & model,const ConversionData & data,bool optional)777*89c4ff92SAndroid Build Coastguard Worker const void* GetOperandValueReadOnlyAddress(const Operand& operand,
778*89c4ff92SAndroid Build Coastguard Worker const Model& model,
779*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
780*89c4ff92SAndroid Build Coastguard Worker bool optional)
781*89c4ff92SAndroid Build Coastguard Worker {
782*89c4ff92SAndroid Build Coastguard Worker const void* valueStart = nullptr;
783*89c4ff92SAndroid Build Coastguard Worker switch (operand.lifetime)
784*89c4ff92SAndroid Build Coastguard Worker {
785*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::CONSTANT_COPY:
786*89c4ff92SAndroid Build Coastguard Worker {
787*89c4ff92SAndroid Build Coastguard Worker valueStart = model.operandValues.data() + operand.location.offset;
788*89c4ff92SAndroid Build Coastguard Worker break;
789*89c4ff92SAndroid Build Coastguard Worker }
790*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::POINTER:
791*89c4ff92SAndroid Build Coastguard Worker {
792*89c4ff92SAndroid Build Coastguard Worker // Pointer specified in the model
793*89c4ff92SAndroid Build Coastguard Worker valueStart = std::get<const void*>(operand.location.pointer);
794*89c4ff92SAndroid Build Coastguard Worker break;
795*89c4ff92SAndroid Build Coastguard Worker }
796*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::CONSTANT_REFERENCE:
797*89c4ff92SAndroid Build Coastguard Worker {
798*89c4ff92SAndroid Build Coastguard Worker // Constant specified via a Memory object
799*89c4ff92SAndroid Build Coastguard Worker valueStart = GetMemoryFromPool(operand.location, data.m_MemPools);
800*89c4ff92SAndroid Build Coastguard Worker break;
801*89c4ff92SAndroid Build Coastguard Worker }
802*89c4ff92SAndroid Build Coastguard Worker case OperandLifeTime::NO_VALUE:
803*89c4ff92SAndroid Build Coastguard Worker {
804*89c4ff92SAndroid Build Coastguard Worker // An optional input tensor with no values is not an error so should not register as a fail
805*89c4ff92SAndroid Build Coastguard Worker if (optional)
806*89c4ff92SAndroid Build Coastguard Worker {
807*89c4ff92SAndroid Build Coastguard Worker valueStart = nullptr;
808*89c4ff92SAndroid Build Coastguard Worker break;
809*89c4ff92SAndroid Build Coastguard Worker }
810*89c4ff92SAndroid Build Coastguard Worker [[fallthrough]];
811*89c4ff92SAndroid Build Coastguard Worker }
812*89c4ff92SAndroid Build Coastguard Worker default:
813*89c4ff92SAndroid Build Coastguard Worker {
814*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unsupported/invalid operand lifetime:: " << operand.lifetime;
815*89c4ff92SAndroid Build Coastguard Worker valueStart = nullptr;
816*89c4ff92SAndroid Build Coastguard Worker }
817*89c4ff92SAndroid Build Coastguard Worker }
818*89c4ff92SAndroid Build Coastguard Worker
819*89c4ff92SAndroid Build Coastguard Worker return valueStart;
820*89c4ff92SAndroid Build Coastguard Worker }
821*89c4ff92SAndroid Build Coastguard Worker
GetTensorInt32Values(const Operand & operand,std::vector<int32_t> & outValues,const Model & model,const ConversionData & data)822*89c4ff92SAndroid Build Coastguard Worker bool GetTensorInt32Values(const Operand& operand,
823*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t>& outValues,
824*89c4ff92SAndroid Build Coastguard Worker const Model& model,
825*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
826*89c4ff92SAndroid Build Coastguard Worker {
827*89c4ff92SAndroid Build Coastguard Worker if (operand.type != OperandType::TENSOR_INT32)
828*89c4ff92SAndroid Build Coastguard Worker {
829*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": invalid operand type: " << operand.type;
830*89c4ff92SAndroid Build Coastguard Worker return false;
831*89c4ff92SAndroid Build Coastguard Worker }
832*89c4ff92SAndroid Build Coastguard Worker
833*89c4ff92SAndroid Build Coastguard Worker const void* startAddress = GetOperandValueReadOnlyAddress(operand, model, data);
834*89c4ff92SAndroid Build Coastguard Worker if (!startAddress)
835*89c4ff92SAndroid Build Coastguard Worker {
836*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": failed to get operand address " << operand.type;
837*89c4ff92SAndroid Build Coastguard Worker return false;
838*89c4ff92SAndroid Build Coastguard Worker }
839*89c4ff92SAndroid Build Coastguard Worker
840*89c4ff92SAndroid Build Coastguard Worker // Check number of bytes is sensible
841*89c4ff92SAndroid Build Coastguard Worker const uint32_t numBytes = operand.location.length;
842*89c4ff92SAndroid Build Coastguard Worker if (numBytes % sizeof(int32_t) != 0)
843*89c4ff92SAndroid Build Coastguard Worker {
844*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: invalid number of bytes: %i, expected to be a multiple of %i",
845*89c4ff92SAndroid Build Coastguard Worker __func__, numBytes, sizeof(int32_t));
846*89c4ff92SAndroid Build Coastguard Worker }
847*89c4ff92SAndroid Build Coastguard Worker
848*89c4ff92SAndroid Build Coastguard Worker outValues.resize(numBytes / sizeof(int32_t));
849*89c4ff92SAndroid Build Coastguard Worker memcpy(outValues.data(), startAddress, numBytes);
850*89c4ff92SAndroid Build Coastguard Worker return true;
851*89c4ff92SAndroid Build Coastguard Worker }
852*89c4ff92SAndroid Build Coastguard Worker
OptionalDataLayout(const Operation & operation,uint32_t inputIndex,const Model & model,ConversionData & data)853*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout OptionalDataLayout(const Operation& operation,
854*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
855*89c4ff92SAndroid Build Coastguard Worker const Model& model,
856*89c4ff92SAndroid Build Coastguard Worker ConversionData& data)
857*89c4ff92SAndroid Build Coastguard Worker {
858*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
859*89c4ff92SAndroid Build Coastguard Worker if (!operand)
860*89c4ff92SAndroid Build Coastguard Worker {
861*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NHWC;
862*89c4ff92SAndroid Build Coastguard Worker }
863*89c4ff92SAndroid Build Coastguard Worker
864*89c4ff92SAndroid Build Coastguard Worker if (!IsBool(*operand))
865*89c4ff92SAndroid Build Coastguard Worker {
866*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NHWC;
867*89c4ff92SAndroid Build Coastguard Worker }
868*89c4ff92SAndroid Build Coastguard Worker
869*89c4ff92SAndroid Build Coastguard Worker const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
870*89c4ff92SAndroid Build Coastguard Worker if (!valueAddress)
871*89c4ff92SAndroid Build Coastguard Worker {
872*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NHWC;
873*89c4ff92SAndroid Build Coastguard Worker }
874*89c4ff92SAndroid Build Coastguard Worker
875*89c4ff92SAndroid Build Coastguard Worker if (*(static_cast<const bool*>(valueAddress)))
876*89c4ff92SAndroid Build Coastguard Worker {
877*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NCHW;
878*89c4ff92SAndroid Build Coastguard Worker }
879*89c4ff92SAndroid Build Coastguard Worker else
880*89c4ff92SAndroid Build Coastguard Worker {
881*89c4ff92SAndroid Build Coastguard Worker return armnn::DataLayout::NHWC;
882*89c4ff92SAndroid Build Coastguard Worker }
883*89c4ff92SAndroid Build Coastguard Worker }
884*89c4ff92SAndroid Build Coastguard Worker
/// Appends the fused activation described by @p activation after @p prevLayer.
///
/// The output slot 0 of @p prevLayer is always given @p tensorInfo. When @p activation is
/// kActivationNone no layer is added and @p prevLayer itself is returned. Otherwise an ArmNN
/// activation layer is created (only if a backend in data.m_Backends reports support for it),
/// connected to output slot 0 of @p prevLayer, assigned the selected backend, and its own
/// output slot is also given @p tensorInfo.
///
/// @param tensorInfo Tensor info applied to prevLayer's output and to the activation's output.
/// @param activation Fused activation to apply.
/// @param prevLayer  Layer producing the activation input; expected to have exactly one output slot.
/// @param data       Conversion state providing the network and the candidate backends.
/// @return The last layer of the chain (prevLayer, or the new activation layer), or nullptr if
///         prevLayer is null, the activation value is not recognised, or no backend supports it.
armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
                                            ActivationFn activation,
                                            armnn::IConnectableLayer* prevLayer,
                                            ConversionData& data)
{
    // ARMNN_ASSERT may be compiled out, so guard the dereference explicitly.
    if (prevLayer == nullptr)
    {
        Fail("%s: prevLayer is null", __func__);
        return nullptr;
    }

    ARMNN_ASSERT(prevLayer->GetNumOutputSlots() == 1);

    prevLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

    if (activation == ActivationFn::kActivationNone)
    {
        return prevLayer;
    }

    armnn::ActivationDescriptor activationDesc;
    switch (activation)
    {
        case ActivationFn::kActivationRelu:
        {
            activationDesc.m_Function = armnn::ActivationFunction::ReLu;
            break;
        }
        case ActivationFn::kActivationRelu1:
        {
            // BoundedReLu: m_A is the upper bound and m_B the lower bound, giving [-1, 1].
            activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
            activationDesc.m_A = 1.0f;
            activationDesc.m_B = -1.0f;
            break;
        }
        case ActivationFn::kActivationRelu6:
        {
            // Upper bound 6; m_B keeps the descriptor default (presumably 0, i.e. [0, 6]).
            activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
            activationDesc.m_A = 6.0f;
            break;
        }
        case ActivationFn::kActivationSigmoid:
        {
            activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
            break;
        }
        case ActivationFn::kActivationTanh:
        {
            activationDesc.m_Function = armnn::ActivationFunction::TanH;
            activationDesc.m_A = 1.0f;
            activationDesc.m_B = 1.0f;
            break;
        }
        default:
        {
            // Enumerations (scoped ones in particular) are not promoted to int through C
            // varargs, so cast explicitly to match the %i conversion.
            Fail("%s: Invalid activation enum value %i", __func__, static_cast<int>(activation));
            return nullptr;
        }
    }

    bool isSupported = false;
    armnn::BackendId setBackend;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsActivationSupported,
                               data.m_Backends,
                               isSupported,
                               setBackend,
                               prevLayer->GetOutputSlot(0).GetTensorInfo(),
                               tensorInfo,
                               activationDesc);
    if (!isSupported)
    {
        return nullptr;
    }

    armnn::IConnectableLayer* activationLayer = data.m_Network->AddActivationLayer(activationDesc);
    activationLayer->SetBackendId(setBackend);

    prevLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
    activationLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

    return activationLayer;
}
962*89c4ff92SAndroid Build Coastguard Worker
SetupAndTrackLayerOutputSlot(const Operation & operation,uint32_t operationOutputIndex,armnn::IConnectableLayer & layer,uint32_t layerOutputIndex,const Model & model,ConversionData & data,const armnn::TensorInfo * overrideOutputInfo,const std::function<void (const armnn::TensorInfo &,bool &)> & validateFunc,const ActivationFn & activationFunction,bool inferOutputShapes)963*89c4ff92SAndroid Build Coastguard Worker bool SetupAndTrackLayerOutputSlot(const Operation& operation,
964*89c4ff92SAndroid Build Coastguard Worker uint32_t operationOutputIndex,
965*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& layer,
966*89c4ff92SAndroid Build Coastguard Worker uint32_t layerOutputIndex,
967*89c4ff92SAndroid Build Coastguard Worker const Model& model,
968*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
969*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo* overrideOutputInfo,
970*89c4ff92SAndroid Build Coastguard Worker const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc,
971*89c4ff92SAndroid Build Coastguard Worker const ActivationFn& activationFunction,
972*89c4ff92SAndroid Build Coastguard Worker bool inferOutputShapes)
973*89c4ff92SAndroid Build Coastguard Worker {
974*89c4ff92SAndroid Build Coastguard Worker const Operand* outputOperand = GetOutputOperand(operation, operationOutputIndex, model);
975*89c4ff92SAndroid Build Coastguard Worker if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots()))
976*89c4ff92SAndroid Build Coastguard Worker {
977*89c4ff92SAndroid Build Coastguard Worker return false;
978*89c4ff92SAndroid Build Coastguard Worker }
979*89c4ff92SAndroid Build Coastguard Worker
980*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
981*89c4ff92SAndroid Build Coastguard Worker if (overrideOutputInfo == nullptr)
982*89c4ff92SAndroid Build Coastguard Worker {
983*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
984*89c4ff92SAndroid Build Coastguard Worker }
985*89c4ff92SAndroid Build Coastguard Worker else
986*89c4ff92SAndroid Build Coastguard Worker {
987*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(*overrideOutputInfo);
988*89c4ff92SAndroid Build Coastguard Worker }
989*89c4ff92SAndroid Build Coastguard Worker
990*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
991*89c4ff92SAndroid Build Coastguard Worker if (validateFunc && (IsDynamicTensor(outputSlot.GetTensorInfo()) || inferOutputShapes))
992*89c4ff92SAndroid Build Coastguard Worker {
993*89c4ff92SAndroid Build Coastguard Worker // Type one dynamic tensors require the previous layer's output shape for inference
994*89c4ff92SAndroid Build Coastguard Worker for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
995*89c4ff92SAndroid Build Coastguard Worker {
996*89c4ff92SAndroid Build Coastguard Worker if(!layer.GetInputSlot(inputSlotIndex).GetConnection())
997*89c4ff92SAndroid Build Coastguard Worker {
998*89c4ff92SAndroid Build Coastguard Worker return false;
999*89c4ff92SAndroid Build Coastguard Worker }
1000*89c4ff92SAndroid Build Coastguard Worker }
1001*89c4ff92SAndroid Build Coastguard Worker // IsTensorInfoSet will infer the dynamic output shape
1002*89c4ff92SAndroid Build Coastguard Worker outputSlot.IsTensorInfoSet();
1003*89c4ff92SAndroid Build Coastguard Worker // Once the shape is inferred we can validate it
1004*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputSlot.GetTensorInfo(), isSupported);
1005*89c4ff92SAndroid Build Coastguard Worker
1006*89c4ff92SAndroid Build Coastguard Worker if(!isSupported)
1007*89c4ff92SAndroid Build Coastguard Worker {
1008*89c4ff92SAndroid Build Coastguard Worker for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
1009*89c4ff92SAndroid Build Coastguard Worker {
1010*89c4ff92SAndroid Build Coastguard Worker layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
1011*89c4ff92SAndroid Build Coastguard Worker }
1012*89c4ff92SAndroid Build Coastguard Worker return false;
1013*89c4ff92SAndroid Build Coastguard Worker }
1014*89c4ff92SAndroid Build Coastguard Worker }
1015*89c4ff92SAndroid Build Coastguard Worker
1016*89c4ff92SAndroid Build Coastguard Worker const uint32_t operandIndex = operation.outputs[operationOutputIndex];
1017*89c4ff92SAndroid Build Coastguard Worker
1018*89c4ff92SAndroid Build Coastguard Worker if (activationFunction != ActivationFn::kActivationNone)
1019*89c4ff92SAndroid Build Coastguard Worker {
1020*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& activationOutputInfo = outputSlot.GetTensorInfo();
1021*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const endLayer = ProcessActivation(activationOutputInfo, activationFunction,
1022*89c4ff92SAndroid Build Coastguard Worker &layer, data);
1023*89c4ff92SAndroid Build Coastguard Worker
1024*89c4ff92SAndroid Build Coastguard Worker if (!endLayer)
1025*89c4ff92SAndroid Build Coastguard Worker {
1026*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: ProcessActivation failed", __func__);
1027*89c4ff92SAndroid Build Coastguard Worker }
1028*89c4ff92SAndroid Build Coastguard Worker
1029*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& activationOutputSlot = endLayer->GetOutputSlot(layerOutputIndex);
1030*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForOperand[operandIndex] = &activationOutputSlot;
1031*89c4ff92SAndroid Build Coastguard Worker }
1032*89c4ff92SAndroid Build Coastguard Worker else
1033*89c4ff92SAndroid Build Coastguard Worker {
1034*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
1035*89c4ff92SAndroid Build Coastguard Worker }
1036*89c4ff92SAndroid Build Coastguard Worker
1037*89c4ff92SAndroid Build Coastguard Worker return true;
1038*89c4ff92SAndroid Build Coastguard Worker }
1039*89c4ff92SAndroid Build Coastguard Worker
IsConnectedToDequantize(armnn::IOutputSlot * ioutputSlot)1040*89c4ff92SAndroid Build Coastguard Worker bool IsConnectedToDequantize(armnn::IOutputSlot* ioutputSlot)
1041*89c4ff92SAndroid Build Coastguard Worker {
1042*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize()";
1043*89c4ff92SAndroid Build Coastguard Worker if (!ioutputSlot)
1044*89c4ff92SAndroid Build Coastguard Worker {
1045*89c4ff92SAndroid Build Coastguard Worker return false;
1046*89c4ff92SAndroid Build Coastguard Worker }
1047*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() ioutputSlot is valid.";
1048*89c4ff92SAndroid Build Coastguard Worker // Find the connections and layers..
1049*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& owningLayer = ioutputSlot->GetOwningIConnectableLayer();
1050*89c4ff92SAndroid Build Coastguard Worker if (owningLayer.GetType() == armnn::LayerType::Dequantize)
1051*89c4ff92SAndroid Build Coastguard Worker {
1052*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() connected to Dequantize Layer.";
1053*89c4ff92SAndroid Build Coastguard Worker armnn::IInputSlot& inputSlot = owningLayer.GetInputSlot(0);
1054*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* connection = inputSlot.GetConnection();
1055*89c4ff92SAndroid Build Coastguard Worker if (connection)
1056*89c4ff92SAndroid Build Coastguard Worker {
1057*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() Dequantize Layer has a connection.";
1058*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& connectedLayer =
1059*89c4ff92SAndroid Build Coastguard Worker connection->GetOwningIConnectableLayer();
1060*89c4ff92SAndroid Build Coastguard Worker if (connectedLayer.GetType() == armnn::LayerType::Constant)
1061*89c4ff92SAndroid Build Coastguard Worker {
1062*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ConversionUtils::IsConnectedToDequantize() Dequantize Layer connected to Constant";
1063*89c4ff92SAndroid Build Coastguard Worker return true;
1064*89c4ff92SAndroid Build Coastguard Worker }
1065*89c4ff92SAndroid Build Coastguard Worker }
1066*89c4ff92SAndroid Build Coastguard Worker }
1067*89c4ff92SAndroid Build Coastguard Worker return false;
1068*89c4ff92SAndroid Build Coastguard Worker }
1069*89c4ff92SAndroid Build Coastguard Worker
1070*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
1071