1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <SharedFunctions.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/internal/tensor.h>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker
VisitConv2dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)20*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitConv2dOperator(DelegateData& delegateData,
21*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
22*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
23*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
24*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker auto numInputs = tfLiteNode->inputs->size;
27*89c4ff92SAndroid Build Coastguard Worker if (numInputs < 2)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
30*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
31*89c4ff92SAndroid Build Coastguard Worker 2, numInputs, nodeIndex);
32*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
33*89c4ff92SAndroid Build Coastguard Worker }
34*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker armnn::Convolution2dDescriptor descriptor;
37*89c4ff92SAndroid Build Coastguard Worker const auto params = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data);
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
40*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = biasEnabled;
41*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = NonNegative(params->stride_width, nodeIndex);
42*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = NonNegative(params->stride_height, nodeIndex);
43*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NHWC;
44*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = NonNegative(params->dilation_width_factor, nodeIndex);
45*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = NonNegative(params->dilation_height_factor, nodeIndex);
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
48*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
49*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteTensors[tfLiteNode->inputs->data[0]]))
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
52*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
53*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
54*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
55*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
60*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
61*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
62*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
63*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
66*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteOutputTensor))
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
69*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
70*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
71*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
72*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
77*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
78*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
79*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
80*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteFilterTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
84*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteFilterTensor))
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
87*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
88*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid filter tensor in operator #%d node #%d: ",
89*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
90*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
91*89c4ff92SAndroid Build Coastguard Worker }
92*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteFilterTensor))
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
95*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
96*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic filter tensors are not supported in node #%d: ",
97*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
98*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
102*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConvParams*>(tfLiteNode->builtin_data);
105*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType=kTfLiteActNone;
106*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNodeParameters)
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker activationType = tfLiteNodeParameters->activation;
109*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
110*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
111*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo;
121*89c4ff92SAndroid Build Coastguard Worker if(biasEnabled)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
124*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteBiasTensor))
125*89c4ff92SAndroid Build Coastguard Worker {
126*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
127*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
128*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid bias tensor in operator #%d node #%d: ",
129*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
130*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteBiasTensor))
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
135*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
136*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic bias tensors are not supported in node #%d: ",
137*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
138*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker else
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::TensorInfo> optionalBiasInfo(biasTensorInfo);
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker // TfLite uses NHWC tensors
150*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputHeight = inputTensorInfo.GetShape()[1];
151*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputWidth = inputTensorInfo.GetShape()[2];
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterHeight = filterTensorInfo.GetShape()[1];
154*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterWidth = filterTensorInfo.GetShape()[2];
155*89c4ff92SAndroid Build Coastguard Worker
156*89c4ff92SAndroid Build Coastguard Worker // Calculate padding
157*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, filterHeight, descriptor.m_StrideY, descriptor.m_DilationY,
158*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
159*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, filterWidth, descriptor.m_StrideX, descriptor.m_DilationX,
160*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
163*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
166*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("CONV2D",
167*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
168*89c4ff92SAndroid Build Coastguard Worker IsConvolution2dSupported,
169*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
170*89c4ff92SAndroid Build Coastguard Worker isSupported,
171*89c4ff92SAndroid Build Coastguard Worker setBackend,
172*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
173*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
174*89c4ff92SAndroid Build Coastguard Worker descriptor,
175*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo,
176*89c4ff92SAndroid Build Coastguard Worker optionalBiasInfo);
177*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker // Set up filter and biases
181*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddConvolution2dLayer(descriptor);
182*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker if(filterTensorInfo.IsConstant())
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker auto filter =
187*89c4ff92SAndroid Build Coastguard Worker CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[1]],
188*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo);
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer *weightsLayer = delegateData.m_Network->AddConstantLayer(filter);
191*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
192*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(filterTensorInfo);
193*89c4ff92SAndroid Build Coastguard Worker }
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker if (biasEnabled)
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
198*89c4ff92SAndroid Build Coastguard Worker if(biasTensorInfo.IsConstant())
199*89c4ff92SAndroid Build Coastguard Worker {
200*89c4ff92SAndroid Build Coastguard Worker auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, biasTensorInfo);
201*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
202*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(biasLayer != nullptr);
203*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
204*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker // The data input can also be constant, so we must check that this is also allocated to an input slot
209*89c4ff92SAndroid Build Coastguard Worker if(inputTensorInfo.IsConstant())
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker auto input =
212*89c4ff92SAndroid Build Coastguard Worker CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
213*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo);
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
216*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
217*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
223*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker if (!tfLiteNodeParameters)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker // No Activation
233*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker // Check and Create activation
236*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker }
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker // Conv3d is only correctly supported for external delegates from TF Lite v2.6, as there was a breaking bug in v2.5.
241*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_5)
VisitConv3dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)242*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitConv3dOperator(DelegateData& delegateData,
243*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
244*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
245*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
246*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker auto numInputs = tfLiteNode->inputs->size;
249*89c4ff92SAndroid Build Coastguard Worker if (numInputs < 2)
250*89c4ff92SAndroid Build Coastguard Worker {
251*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
252*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
253*89c4ff92SAndroid Build Coastguard Worker 2, numInputs, nodeIndex);
254*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
255*89c4ff92SAndroid Build Coastguard Worker }
256*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Worker armnn::Convolution3dDescriptor descriptor;
259*89c4ff92SAndroid Build Coastguard Worker const auto params = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data);
260*89c4ff92SAndroid Build Coastguard Worker
261*89c4ff92SAndroid Build Coastguard Worker bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
262*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = biasEnabled;
263*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NDHWC;
264*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = NonNegative(params->stride_width, nodeIndex);
265*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = NonNegative(params->stride_height, nodeIndex);
266*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideZ = NonNegative(params->stride_depth, nodeIndex);
267*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = NonNegative(params->dilation_width_factor, nodeIndex);
268*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = NonNegative(params->dilation_height_factor, nodeIndex);
269*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationZ = NonNegative(params->dilation_depth_factor, nodeIndex);
270*89c4ff92SAndroid Build Coastguard Worker
271*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
272*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
273*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker
278*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
279*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
280*89c4ff92SAndroid Build Coastguard Worker {
281*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
282*89c4ff92SAndroid Build Coastguard Worker }
283*89c4ff92SAndroid Build Coastguard Worker
284*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteFilterTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
285*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteFilterTensor, operatorCode, nodeIndex))
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker
290*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
291*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
292*89c4ff92SAndroid Build Coastguard Worker
293*89c4ff92SAndroid Build Coastguard Worker auto* tfLiteNodeParameters = reinterpret_cast<TfLiteConv3DParams*>(tfLiteNode->builtin_data);
294*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType=kTfLiteActNone;
295*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNodeParameters)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker activationType = tfLiteNodeParameters->activation;
298*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
299*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
300*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
303*89c4ff92SAndroid Build Coastguard Worker }
304*89c4ff92SAndroid Build Coastguard Worker
305*89c4ff92SAndroid Build Coastguard Worker }
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo;
310*89c4ff92SAndroid Build Coastguard Worker if(biasEnabled)
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
313*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteBiasTensor, operatorCode, nodeIndex))
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
316*89c4ff92SAndroid Build Coastguard Worker }
317*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
318*89c4ff92SAndroid Build Coastguard Worker }
319*89c4ff92SAndroid Build Coastguard Worker else
320*89c4ff92SAndroid Build Coastguard Worker {
321*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
322*89c4ff92SAndroid Build Coastguard Worker }
323*89c4ff92SAndroid Build Coastguard Worker
324*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::TensorInfo> optionalBiasInfo(biasTensorInfo);
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker // TfLite uses NDHWC tensors
327*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputDepth = inputTensorInfo.GetShape()[1];
328*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputHeight = inputTensorInfo.GetShape()[2];
329*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputWidth = inputTensorInfo.GetShape()[3];
330*89c4ff92SAndroid Build Coastguard Worker
331*89c4ff92SAndroid Build Coastguard Worker // Assuming the filter is DHWIO : Depth, Height, Width, OutputChannels, InputChannels
332*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterDepth = filterTensorInfo.GetShape()[0];
333*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterHeight = filterTensorInfo.GetShape()[1];
334*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterWidth = filterTensorInfo.GetShape()[2];
335*89c4ff92SAndroid Build Coastguard Worker
336*89c4ff92SAndroid Build Coastguard Worker // Calculate padding
337*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputDepth, filterDepth, descriptor.m_StrideZ, descriptor.m_DilationZ,
338*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront, descriptor.m_PadBack, params->padding);
339*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, filterHeight, descriptor.m_StrideY, descriptor.m_DilationY,
340*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
341*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, filterWidth, descriptor.m_StrideX, descriptor.m_DilationX,
342*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
343*89c4ff92SAndroid Build Coastguard Worker
344*89c4ff92SAndroid Build Coastguard Worker // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
345*89c4ff92SAndroid Build Coastguard Worker // support for the operator
346*89c4ff92SAndroid Build Coastguard Worker // If supported, VisitConvolutionOperator will be called again to add the layer to the network as seen below.
347*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
348*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
349*89c4ff92SAndroid Build Coastguard Worker {
350*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
351*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("CONV3D",
352*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
353*89c4ff92SAndroid Build Coastguard Worker IsConvolution3dSupported,
354*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
355*89c4ff92SAndroid Build Coastguard Worker isSupported,
356*89c4ff92SAndroid Build Coastguard Worker setBackend,
357*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
358*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
359*89c4ff92SAndroid Build Coastguard Worker descriptor,
360*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo,
361*89c4ff92SAndroid Build Coastguard Worker optionalBiasInfo);
362*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
363*89c4ff92SAndroid Build Coastguard Worker }
364*89c4ff92SAndroid Build Coastguard Worker
365*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddConvolution3dLayer(descriptor);
366*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
367*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
368*89c4ff92SAndroid Build Coastguard Worker
369*89c4ff92SAndroid Build Coastguard Worker // Add a constant layer for weights and biases if inputs are constant,
370*89c4ff92SAndroid Build Coastguard Worker // which are connected to the Convolution3d layer as inputs.
371*89c4ff92SAndroid Build Coastguard Worker if (filterTensorInfo.IsConstant())
372*89c4ff92SAndroid Build Coastguard Worker {
373*89c4ff92SAndroid Build Coastguard Worker auto filter = CreateConstTensor(&tfLiteFilterTensor,
374*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo);
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(filter);
377*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(weightsLayer != nullptr);
378*89c4ff92SAndroid Build Coastguard Worker
379*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
380*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(filterTensorInfo);
381*89c4ff92SAndroid Build Coastguard Worker }
382*89c4ff92SAndroid Build Coastguard Worker
383*89c4ff92SAndroid Build Coastguard Worker if(biasEnabled)
384*89c4ff92SAndroid Build Coastguard Worker {
385*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
386*89c4ff92SAndroid Build Coastguard Worker if(biasTensorInfo.IsConstant())
387*89c4ff92SAndroid Build Coastguard Worker {
388*89c4ff92SAndroid Build Coastguard Worker auto biases = CreateConstTensor(&tfLiteBiasTensor,
389*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo);
390*89c4ff92SAndroid Build Coastguard Worker
391*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biases);
392*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(biasLayer != nullptr);
393*89c4ff92SAndroid Build Coastguard Worker
394*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
395*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker }
398*89c4ff92SAndroid Build Coastguard Worker
399*89c4ff92SAndroid Build Coastguard Worker // The data input can also be constant, so we must check that this is also allocated to an input slot
400*89c4ff92SAndroid Build Coastguard Worker if(inputTensorInfo.IsConstant())
401*89c4ff92SAndroid Build Coastguard Worker {
402*89c4ff92SAndroid Build Coastguard Worker auto input =
403*89c4ff92SAndroid Build Coastguard Worker CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
404*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo);
405*89c4ff92SAndroid Build Coastguard Worker
406*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
407*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
408*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
409*89c4ff92SAndroid Build Coastguard Worker }
410*89c4ff92SAndroid Build Coastguard Worker
411*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
412*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
413*89c4ff92SAndroid Build Coastguard Worker
414*89c4ff92SAndroid Build Coastguard Worker if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
415*89c4ff92SAndroid Build Coastguard Worker {
416*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
417*89c4ff92SAndroid Build Coastguard Worker }
418*89c4ff92SAndroid Build Coastguard Worker
419*89c4ff92SAndroid Build Coastguard Worker if (!tfLiteNodeParameters)
420*89c4ff92SAndroid Build Coastguard Worker {
421*89c4ff92SAndroid Build Coastguard Worker // No Activation
422*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
423*89c4ff92SAndroid Build Coastguard Worker }
424*89c4ff92SAndroid Build Coastguard Worker
425*89c4ff92SAndroid Build Coastguard Worker // Check and create activation
426*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
427*89c4ff92SAndroid Build Coastguard Worker }
428*89c4ff92SAndroid Build Coastguard Worker #endif
429*89c4ff92SAndroid Build Coastguard Worker
VisitDepthwiseConv2dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)430*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitDepthwiseConv2dOperator(DelegateData& delegateData,
431*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
432*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
433*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
434*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
435*89c4ff92SAndroid Build Coastguard Worker {
436*89c4ff92SAndroid Build Coastguard Worker auto numInputs = tfLiteNode->inputs->size;
437*89c4ff92SAndroid Build Coastguard Worker if (numInputs < 2)
438*89c4ff92SAndroid Build Coastguard Worker {
439*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
440*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
441*89c4ff92SAndroid Build Coastguard Worker 2, numInputs, nodeIndex);
442*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
443*89c4ff92SAndroid Build Coastguard Worker }
444*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
445*89c4ff92SAndroid Build Coastguard Worker
446*89c4ff92SAndroid Build Coastguard Worker bool biasEnabled = IsOptionalOperandPresent(tfLiteNode, 2);
447*89c4ff92SAndroid Build Coastguard Worker
448*89c4ff92SAndroid Build Coastguard Worker armnn::DepthwiseConvolution2dDescriptor descriptor;
449*89c4ff92SAndroid Build Coastguard Worker const auto params = reinterpret_cast<TfLiteDepthwiseConvParams*>(tfLiteNode->builtin_data);
450*89c4ff92SAndroid Build Coastguard Worker
451*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = biasEnabled;
452*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = NonNegative(params->stride_width, nodeIndex);
453*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = NonNegative(params->stride_height, nodeIndex);
454*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NHWC;
455*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX = NonNegative(params->dilation_width_factor, nodeIndex);
456*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY = NonNegative(params->dilation_height_factor, nodeIndex);
457*89c4ff92SAndroid Build Coastguard Worker
458*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
459*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
460*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteInputTensor))
461*89c4ff92SAndroid Build Coastguard Worker {
462*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
463*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
464*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
465*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
466*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
469*89c4ff92SAndroid Build Coastguard Worker {
470*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
471*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
472*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
473*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
474*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
475*89c4ff92SAndroid Build Coastguard Worker }
476*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
477*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteOutputTensor))
478*89c4ff92SAndroid Build Coastguard Worker {
479*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
480*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
481*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
482*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
483*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
484*89c4ff92SAndroid Build Coastguard Worker }
485*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
488*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
489*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
490*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
491*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
492*89c4ff92SAndroid Build Coastguard Worker }
493*89c4ff92SAndroid Build Coastguard Worker
494*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteFilterTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
495*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteFilterTensor))
496*89c4ff92SAndroid Build Coastguard Worker {
497*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
498*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
499*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid filter tensor in operator #%d node #%d: ",
500*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
501*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
502*89c4ff92SAndroid Build Coastguard Worker }
503*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteFilterTensor))
504*89c4ff92SAndroid Build Coastguard Worker {
505*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
506*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
507*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic filter tensors are not supported in node #%d: ",
508*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
509*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
510*89c4ff92SAndroid Build Coastguard Worker }
511*89c4ff92SAndroid Build Coastguard Worker
512*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
513*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
514*89c4ff92SAndroid Build Coastguard Worker
515*89c4ff92SAndroid Build Coastguard Worker auto* tfLiteNodeParameters = reinterpret_cast<TfLiteDepthwiseConvParams *>(tfLiteNode->builtin_data);
516*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType = kTfLiteActNone;
517*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNodeParameters)
518*89c4ff92SAndroid Build Coastguard Worker {
519*89c4ff92SAndroid Build Coastguard Worker activationType = tfLiteNodeParameters->activation;
520*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
521*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
522*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
523*89c4ff92SAndroid Build Coastguard Worker {
524*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
525*89c4ff92SAndroid Build Coastguard Worker }
526*89c4ff92SAndroid Build Coastguard Worker
527*89c4ff92SAndroid Build Coastguard Worker }
528*89c4ff92SAndroid Build Coastguard Worker
529*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
530*89c4ff92SAndroid Build Coastguard Worker
531*89c4ff92SAndroid Build Coastguard Worker // Assuming input is NHWC
532*89c4ff92SAndroid Build Coastguard Worker unsigned int inputHeight = inputTensorInfo.GetShape()[1];
533*89c4ff92SAndroid Build Coastguard Worker unsigned int inputWidth = inputTensorInfo.GetShape()[2];
534*89c4ff92SAndroid Build Coastguard Worker
535*89c4ff92SAndroid Build Coastguard Worker // TensorflowLite weights come in the format [1, H, W, I * M]
536*89c4ff92SAndroid Build Coastguard Worker unsigned int filterHeight = filterTensorInfo.GetShape()[1];
537*89c4ff92SAndroid Build Coastguard Worker unsigned int filterWidth = filterTensorInfo.GetShape()[2];
538*89c4ff92SAndroid Build Coastguard Worker
539*89c4ff92SAndroid Build Coastguard Worker // Calculate padding
540*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, filterHeight, descriptor.m_StrideY, descriptor.m_DilationY,
541*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
542*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, filterWidth, descriptor.m_StrideX, descriptor.m_DilationX,
543*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
544*89c4ff92SAndroid Build Coastguard Worker
545*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo biasTensorInfo;
546*89c4ff92SAndroid Build Coastguard Worker if(biasEnabled)
547*89c4ff92SAndroid Build Coastguard Worker {
548*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
549*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteBiasTensor))
550*89c4ff92SAndroid Build Coastguard Worker {
551*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
552*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
553*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid bias tensor in operator #%d node #%d: ",
554*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
555*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
556*89c4ff92SAndroid Build Coastguard Worker }
557*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteBiasTensor))
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
560*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
561*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic bias tensors are not supported in node #%d: ",
562*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
563*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
564*89c4ff92SAndroid Build Coastguard Worker }
565*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteBiasTensor);
566*89c4ff92SAndroid Build Coastguard Worker }
567*89c4ff92SAndroid Build Coastguard Worker else
568*89c4ff92SAndroid Build Coastguard Worker {
569*89c4ff92SAndroid Build Coastguard Worker biasTensorInfo = armnn::TensorInfo(armnn::TensorShape({1}), GetDataType(tfLiteInputTensor));
570*89c4ff92SAndroid Build Coastguard Worker }
571*89c4ff92SAndroid Build Coastguard Worker
572*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
573*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
574*89c4ff92SAndroid Build Coastguard Worker {
575*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
576*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("DEPTHWISE_CONV2D",
577*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
578*89c4ff92SAndroid Build Coastguard Worker IsDepthwiseConvolutionSupported,
579*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
580*89c4ff92SAndroid Build Coastguard Worker isSupported,
581*89c4ff92SAndroid Build Coastguard Worker setBackend,
582*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
583*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
584*89c4ff92SAndroid Build Coastguard Worker descriptor,
585*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo,
586*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<armnn::TensorInfo>(biasTensorInfo));
587*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
588*89c4ff92SAndroid Build Coastguard Worker }
589*89c4ff92SAndroid Build Coastguard Worker
590*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddDepthwiseConvolution2dLayer(descriptor);
591*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
592*89c4ff92SAndroid Build Coastguard Worker
593*89c4ff92SAndroid Build Coastguard Worker if(filterTensorInfo.IsConstant())
594*89c4ff92SAndroid Build Coastguard Worker {
595*89c4ff92SAndroid Build Coastguard Worker // For depthwise the weights layout is the same as for tflite [1, H, W, I*M]. No permutation required.
596*89c4ff92SAndroid Build Coastguard Worker auto filter = CreateConstTensor(&tfLiteFilterTensor, filterTensorInfo);
597*89c4ff92SAndroid Build Coastguard Worker
598*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* weightsLayer = delegateData.m_Network->AddConstantLayer(filter);
599*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
600*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(filterTensorInfo);
601*89c4ff92SAndroid Build Coastguard Worker }
602*89c4ff92SAndroid Build Coastguard Worker
603*89c4ff92SAndroid Build Coastguard Worker if (biasEnabled)
604*89c4ff92SAndroid Build Coastguard Worker {
605*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteBiasTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
606*89c4ff92SAndroid Build Coastguard Worker if(biasTensorInfo.IsConstant())
607*89c4ff92SAndroid Build Coastguard Worker {
608*89c4ff92SAndroid Build Coastguard Worker auto biasTensor = CreateConstTensor(&tfLiteBiasTensor, biasTensorInfo);
609*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* biasLayer = delegateData.m_Network->AddConstantLayer(biasTensor);
610*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(biasLayer != nullptr);
611*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
612*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensorInfo);
613*89c4ff92SAndroid Build Coastguard Worker }
614*89c4ff92SAndroid Build Coastguard Worker }
615*89c4ff92SAndroid Build Coastguard Worker
616*89c4ff92SAndroid Build Coastguard Worker // The data input can also be constant, so we must check that this is also allocated to an input slot
617*89c4ff92SAndroid Build Coastguard Worker if(inputTensorInfo.IsConstant())
618*89c4ff92SAndroid Build Coastguard Worker {
619*89c4ff92SAndroid Build Coastguard Worker auto input =
620*89c4ff92SAndroid Build Coastguard Worker CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[0]],
621*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo);
622*89c4ff92SAndroid Build Coastguard Worker
623*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
624*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
625*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
626*89c4ff92SAndroid Build Coastguard Worker }
627*89c4ff92SAndroid Build Coastguard Worker
628*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
629*89c4ff92SAndroid Build Coastguard Worker
630*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
631*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
632*89c4ff92SAndroid Build Coastguard Worker
633*89c4ff92SAndroid Build Coastguard Worker if(Connect(layer, tfLiteNode, delegateData) != kTfLiteOk)
634*89c4ff92SAndroid Build Coastguard Worker {
635*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
636*89c4ff92SAndroid Build Coastguard Worker }
637*89c4ff92SAndroid Build Coastguard Worker
638*89c4ff92SAndroid Build Coastguard Worker if (!tfLiteNodeParameters)
639*89c4ff92SAndroid Build Coastguard Worker {
640*89c4ff92SAndroid Build Coastguard Worker // No Activation
641*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
642*89c4ff92SAndroid Build Coastguard Worker }
643*89c4ff92SAndroid Build Coastguard Worker // Check and create activation
644*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, layer, 0, delegateData);
645*89c4ff92SAndroid Build Coastguard Worker }
646*89c4ff92SAndroid Build Coastguard Worker
VisitTransposeConv2dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)647*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitTransposeConv2dOperator(DelegateData& delegateData,
648*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
649*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
650*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
651*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
652*89c4ff92SAndroid Build Coastguard Worker {
653*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
654*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
655*89c4ff92SAndroid Build Coastguard Worker
656*89c4ff92SAndroid Build Coastguard Worker armnn::TransposeConvolution2dDescriptor descriptor;
657*89c4ff92SAndroid Build Coastguard Worker auto* parameters = reinterpret_cast<TfLiteTransposeConvParams*>(tfLiteNode->builtin_data);
658*89c4ff92SAndroid Build Coastguard Worker descriptor.m_BiasEnabled = false;
659*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = NonNegative(parameters->stride_width, nodeIndex);
660*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = NonNegative(parameters->stride_height, nodeIndex);
661*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NHWC;
662*89c4ff92SAndroid Build Coastguard Worker
663*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
664*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputShapeTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
665*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteOutputShapeTensor))
666*89c4ff92SAndroid Build Coastguard Worker {
667*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
668*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
669*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
670*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
671*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
672*89c4ff92SAndroid Build Coastguard Worker }
673*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputShapeTensor))
674*89c4ff92SAndroid Build Coastguard Worker {
675*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
676*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
677*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
678*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
679*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
680*89c4ff92SAndroid Build Coastguard Worker }
681*89c4ff92SAndroid Build Coastguard Worker
682*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputShapeTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputShapeTensor);
683*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape(outputShapeTensorInfo.GetNumElements());
684*89c4ff92SAndroid Build Coastguard Worker if (outputShapeTensorInfo.GetDataType() == armnn::DataType::Signed32)
685*89c4ff92SAndroid Build Coastguard Worker {
686*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i=0; i < outputShapeTensorInfo.GetNumElements(); i++)
687*89c4ff92SAndroid Build Coastguard Worker {
688*89c4ff92SAndroid Build Coastguard Worker outputShape[i] = ::tflite::GetTensorData<int32_t>(&tfLiteOutputShapeTensor)[i];
689*89c4ff92SAndroid Build Coastguard Worker }
690*89c4ff92SAndroid Build Coastguard Worker }
691*89c4ff92SAndroid Build Coastguard Worker
692*89c4ff92SAndroid Build Coastguard Worker if (outputShapeTensorInfo.GetDataType() == armnn::DataType::QAsymmU8)
693*89c4ff92SAndroid Build Coastguard Worker {
694*89c4ff92SAndroid Build Coastguard Worker for(unsigned int i=0; i < outputShapeTensorInfo.GetNumElements(); i++)
695*89c4ff92SAndroid Build Coastguard Worker {
696*89c4ff92SAndroid Build Coastguard Worker outputShape[i] = ::tflite::GetTensorData<uint8_t>(&tfLiteOutputShapeTensor)[i];
697*89c4ff92SAndroid Build Coastguard Worker }
698*89c4ff92SAndroid Build Coastguard Worker }
699*89c4ff92SAndroid Build Coastguard Worker // Change from signed to unsigned int to store in TransposeConvolution2dDescriptor.
700*89c4ff92SAndroid Build Coastguard Worker for (int dimension : outputShape)
701*89c4ff92SAndroid Build Coastguard Worker {
702*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShape.push_back(static_cast<unsigned int>(dimension));
703*89c4ff92SAndroid Build Coastguard Worker }
704*89c4ff92SAndroid Build Coastguard Worker descriptor.m_OutputShapeEnabled = true;
705*89c4ff92SAndroid Build Coastguard Worker
706*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
707*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteInputTensor))
708*89c4ff92SAndroid Build Coastguard Worker {
709*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
710*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
711*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
712*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
713*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
714*89c4ff92SAndroid Build Coastguard Worker }
715*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
716*89c4ff92SAndroid Build Coastguard Worker {
717*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
718*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
719*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
720*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
721*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
722*89c4ff92SAndroid Build Coastguard Worker }
723*89c4ff92SAndroid Build Coastguard Worker
724*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
725*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteOutputTensor))
726*89c4ff92SAndroid Build Coastguard Worker {
727*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
728*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
729*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
730*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
731*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
732*89c4ff92SAndroid Build Coastguard Worker }
733*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
734*89c4ff92SAndroid Build Coastguard Worker {
735*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
736*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
737*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
738*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
739*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
740*89c4ff92SAndroid Build Coastguard Worker }
741*89c4ff92SAndroid Build Coastguard Worker
742*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteFilterTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
743*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteFilterTensor))
744*89c4ff92SAndroid Build Coastguard Worker {
745*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
746*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
747*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid filter tensor in operator #%d node #%d: ",
748*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
749*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
750*89c4ff92SAndroid Build Coastguard Worker }
751*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteFilterTensor))
752*89c4ff92SAndroid Build Coastguard Worker {
753*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
754*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
755*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic filter tensors are not supported in operator #%d node #%d: ",
756*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
757*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
758*89c4ff92SAndroid Build Coastguard Worker }
759*89c4ff92SAndroid Build Coastguard Worker
760*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
761*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
762*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& filterTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteFilterTensor);
763*89c4ff92SAndroid Build Coastguard Worker
764*89c4ff92SAndroid Build Coastguard Worker // TfLite uses NHWC tensors
765*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputHeight = inputTensorInfo.GetShape()[1];
766*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputWidth = inputTensorInfo.GetShape()[2];
767*89c4ff92SAndroid Build Coastguard Worker
768*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterHeight = filterTensorInfo.GetShape()[1];
769*89c4ff92SAndroid Build Coastguard Worker const unsigned int filterWidth = filterTensorInfo.GetShape()[2];
770*89c4ff92SAndroid Build Coastguard Worker
771*89c4ff92SAndroid Build Coastguard Worker // Calculate padding
772*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight,
773*89c4ff92SAndroid Build Coastguard Worker filterHeight,
774*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY,
775*89c4ff92SAndroid Build Coastguard Worker 1, // dilation y
776*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop,
777*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadBottom,
778*89c4ff92SAndroid Build Coastguard Worker parameters->padding);
779*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth,
780*89c4ff92SAndroid Build Coastguard Worker filterWidth,
781*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX,
782*89c4ff92SAndroid Build Coastguard Worker 1, // dilation x
783*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft,
784*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadRight,
785*89c4ff92SAndroid Build Coastguard Worker parameters->padding);
786*89c4ff92SAndroid Build Coastguard Worker
787*89c4ff92SAndroid Build Coastguard Worker // Set up filter
788*89c4ff92SAndroid Build Coastguard Worker auto filterTensor = CreateConstTensor(&tfLiteFilterTensor,
789*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo);
790*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
791*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
792*89c4ff92SAndroid Build Coastguard Worker {
793*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
794*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("TRANSPOSE_CONV2D",
795*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
796*89c4ff92SAndroid Build Coastguard Worker IsTransposeConvolution2dSupported,
797*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
798*89c4ff92SAndroid Build Coastguard Worker isSupported,
799*89c4ff92SAndroid Build Coastguard Worker setBackend,
800*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
801*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
802*89c4ff92SAndroid Build Coastguard Worker descriptor,
803*89c4ff92SAndroid Build Coastguard Worker filterTensorInfo,
804*89c4ff92SAndroid Build Coastguard Worker armnn::EmptyOptional());
805*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
806*89c4ff92SAndroid Build Coastguard Worker }
807*89c4ff92SAndroid Build Coastguard Worker
808*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddTransposeConvolution2dLayer(descriptor,
809*89c4ff92SAndroid Build Coastguard Worker filterTensor,
810*89c4ff92SAndroid Build Coastguard Worker armnn::EmptyOptional());
811*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
812*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker // The data input can be constant, so we must check that this is allocated to an input slot
815*89c4ff92SAndroid Build Coastguard Worker if(inputTensorInfo.IsConstant())
816*89c4ff92SAndroid Build Coastguard Worker {
817*89c4ff92SAndroid Build Coastguard Worker auto input =
818*89c4ff92SAndroid Build Coastguard Worker CreateConstTensor(&tfLiteContext->tensors[tfLiteNode->inputs->data[2]],
819*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo);
820*89c4ff92SAndroid Build Coastguard Worker
821*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer *inputLayer = delegateData.m_Network->AddConstantLayer(input);
822*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
823*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
824*89c4ff92SAndroid Build Coastguard Worker }
825*89c4ff92SAndroid Build Coastguard Worker
826*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
827*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
828*89c4ff92SAndroid Build Coastguard Worker
829*89c4ff92SAndroid Build Coastguard Worker // Connect
830*89c4ff92SAndroid Build Coastguard Worker if (delegateData.m_OutputSlotForNode[static_cast<unsigned int>(tfLiteNode->inputs->data[2])] != nullptr)
831*89c4ff92SAndroid Build Coastguard Worker {
832*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned int>(tfLiteNode->inputs->data[2])]->
833*89c4ff92SAndroid Build Coastguard Worker Connect(layer->GetInputSlot(0));
834*89c4ff92SAndroid Build Coastguard Worker }
835*89c4ff92SAndroid Build Coastguard Worker
836*89c4ff92SAndroid Build Coastguard Worker // Prepare output slots
837*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
838*89c4ff92SAndroid Build Coastguard Worker {
839*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
840*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned int>(tfLiteNode->outputs->data[outputIndex])] =
841*89c4ff92SAndroid Build Coastguard Worker &outputSlot;
842*89c4ff92SAndroid Build Coastguard Worker }
843*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
844*89c4ff92SAndroid Build Coastguard Worker }
845*89c4ff92SAndroid Build Coastguard Worker
VisitConvolutionOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)846*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitConvolutionOperator(DelegateData& delegateData,
847*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
848*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
849*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
850*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
851*89c4ff92SAndroid Build Coastguard Worker {
852*89c4ff92SAndroid Build Coastguard Worker switch(operatorCode)
853*89c4ff92SAndroid Build Coastguard Worker {
854*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConv2d:
855*89c4ff92SAndroid Build Coastguard Worker return VisitConv2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
856*89c4ff92SAndroid Build Coastguard Worker // Conv3d is only correctly supported for external delegates from TF Lite v2.6, as there was a breaking bug in v2.5.
857*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_5)
858*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConv3d:
859*89c4ff92SAndroid Build Coastguard Worker return VisitConv3dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
860*89c4ff92SAndroid Build Coastguard Worker #endif
861*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDepthwiseConv2d:
862*89c4ff92SAndroid Build Coastguard Worker return VisitDepthwiseConv2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
863*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinTransposeConv:
864*89c4ff92SAndroid Build Coastguard Worker return VisitTransposeConv2dOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
865*89c4ff92SAndroid Build Coastguard Worker default:
866*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
867*89c4ff92SAndroid Build Coastguard Worker }
868*89c4ff92SAndroid Build Coastguard Worker }
869*89c4ff92SAndroid Build Coastguard Worker
870*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
871