1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flexbuffers.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker
VisitPooling2dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLitePoolingOperatorCode)19*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitPooling2dOperator(DelegateData& delegateData,
20*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
21*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
22*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
23*89c4ff92SAndroid Build Coastguard Worker int32_t tfLitePoolingOperatorCode)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
33*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
34*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
35*89c4ff92SAndroid Build Coastguard Worker tfLitePoolingOperatorCode, nodeIndex);
36*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
40*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
43*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
44*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
45*89c4ff92SAndroid Build Coastguard Worker tfLitePoolingOperatorCode, nodeIndex);
46*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
50*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker auto* tfLiteNodeParameters = reinterpret_cast<TfLitePoolParams*>(tfLiteNode->builtin_data);
53*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType = kTfLiteActNone;
54*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNodeParameters)
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker activationType = tfLiteNodeParameters->activation;
57*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
58*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
59*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker armnn::PoolingAlgorithm poolingAlgorithm;
67*89c4ff92SAndroid Build Coastguard Worker switch(tfLitePoolingOperatorCode)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAveragePool2d:
70*89c4ff92SAndroid Build Coastguard Worker poolingAlgorithm = armnn::PoolingAlgorithm::Average;
71*89c4ff92SAndroid Build Coastguard Worker break;
72*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinL2Pool2d:
73*89c4ff92SAndroid Build Coastguard Worker poolingAlgorithm = armnn::PoolingAlgorithm::L2;
74*89c4ff92SAndroid Build Coastguard Worker break;
75*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMaxPool2d:
76*89c4ff92SAndroid Build Coastguard Worker poolingAlgorithm = armnn::PoolingAlgorithm::Max;
77*89c4ff92SAndroid Build Coastguard Worker break;
78*89c4ff92SAndroid Build Coastguard Worker default:
79*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling2dDescriptor descriptor;
83*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolType = poolingAlgorithm;
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolWidth = tfLiteNodeParameters->filter_width;
86*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolHeight = tfLiteNodeParameters->filter_height;
87*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = tfLiteNodeParameters->stride_width;
88*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = tfLiteNodeParameters->stride_height;
89*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NHWC;
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker unsigned int inputHeight = inputTensorInfo.GetShape()[1];
92*89c4ff92SAndroid Build Coastguard Worker unsigned int inputWidth = inputTensorInfo.GetShape()[2];
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
95*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom, tfLiteNodeParameters->padding);
96*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
97*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft, descriptor.m_PadRight, tfLiteNodeParameters->padding);
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
100*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
101*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("POOLING_2D",
104*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
105*89c4ff92SAndroid Build Coastguard Worker IsPooling2dSupported,
106*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
107*89c4ff92SAndroid Build Coastguard Worker isSupported,
108*89c4ff92SAndroid Build Coastguard Worker setBackend,
109*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
110*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
111*89c4ff92SAndroid Build Coastguard Worker descriptor);
112*89c4ff92SAndroid Build Coastguard Worker };
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputTensorInfo, isSupported);
117*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling2dLayer(descriptor);
121*89c4ff92SAndroid Build Coastguard Worker poolingLayer->SetBackendId(setBackend);
122*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(poolingLayer != nullptr);
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
125*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker // try to connect the Constant Inputs if there are any
128*89c4ff92SAndroid Build Coastguard Worker if(ProcessInputs(poolingLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker if(Connect(poolingLayer, tfLiteNode, delegateData) != kTfLiteOk)
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker // Check and create activation
139*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker
VisitPooling3dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,std::string customOperatorName)142*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitPooling3dOperator(DelegateData& delegateData,
143*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
144*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
145*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
146*89c4ff92SAndroid Build Coastguard Worker std::string customOperatorName)
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
149*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
152*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
153*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
156*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
157*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
158*89c4ff92SAndroid Build Coastguard Worker customOperatorName.c_str(), nodeIndex);
159*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
163*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
166*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
167*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
168*89c4ff92SAndroid Build Coastguard Worker customOperatorName.c_str(), nodeIndex);
169*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker // Set the input and output info
172*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
173*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker // Custom Operators are defined by the name string associated to the operator. Use this to determine
176*89c4ff92SAndroid Build Coastguard Worker // which pooling algorithm to create the armnn operator with. L2 Pooling3D is unsupported in TfLite.
177*89c4ff92SAndroid Build Coastguard Worker armnn::PoolingAlgorithm poolingAlgorithm;
178*89c4ff92SAndroid Build Coastguard Worker if (customOperatorName == "MaxPool3D")
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker poolingAlgorithm = armnn::PoolingAlgorithm::Max;
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker else if (customOperatorName == "AveragePool3D")
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker poolingAlgorithm = armnn::PoolingAlgorithm::Average;
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker else
187*89c4ff92SAndroid Build Coastguard Worker {
188*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker // Create the armnn pool3d descriptor and set the algorithm parsed above.
191*89c4ff92SAndroid Build Coastguard Worker armnn::Pooling3dDescriptor descriptor;
192*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolType = poolingAlgorithm;
193*89c4ff92SAndroid Build Coastguard Worker
194*89c4ff92SAndroid Build Coastguard Worker // custom_initial_data and custom_initial_data_size are void* variables defined in the tflite registration
195*89c4ff92SAndroid Build Coastguard Worker // used to access the custom option buffer for the operator.
196*89c4ff92SAndroid Build Coastguard Worker auto custom_data = tfLiteNode->custom_initial_data;
197*89c4ff92SAndroid Build Coastguard Worker auto custom_data_size = tfLiteNode->custom_initial_data_size;
198*89c4ff92SAndroid Build Coastguard Worker // Reinterpret the void* to a byte buffer to access the options data in the flexbuffers map.
199*89c4ff92SAndroid Build Coastguard Worker const flexbuffers::Map& m = flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(custom_data),
200*89c4ff92SAndroid Build Coastguard Worker custom_data_size).AsMap();
201*89c4ff92SAndroid Build Coastguard Worker // poolDims is a vector of [ 1, Depth, Height, Width, 1 ]
202*89c4ff92SAndroid Build Coastguard Worker const auto poolDims = m["ksize"].AsTypedVector();
203*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolWidth = poolDims[3].AsInt32();
204*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolHeight = poolDims[2].AsInt32();
205*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PoolDepth = poolDims[1].AsInt32();
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker // strideDimes is a vector of [ 1, Z, Y, X, 1]
208*89c4ff92SAndroid Build Coastguard Worker const auto strideDims = m["strides"].AsTypedVector();
209*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideX = strideDims[3].AsInt32();
210*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideY = strideDims[2].AsInt32();
211*89c4ff92SAndroid Build Coastguard Worker descriptor.m_StrideZ = strideDims[1].AsInt32();
212*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DataLayout = armnn::DataLayout::NDHWC;
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker unsigned int inputDepth = inputTensorInfo.GetShape()[1];
215*89c4ff92SAndroid Build Coastguard Worker unsigned int inputHeight = inputTensorInfo.GetShape()[2];
216*89c4ff92SAndroid Build Coastguard Worker unsigned int inputWidth = inputTensorInfo.GetShape()[3];
217*89c4ff92SAndroid Build Coastguard Worker
218*89c4ff92SAndroid Build Coastguard Worker // CalcPadding expects a TfLitePadding type. Parse flexbuffers to extract padding string and create TfLitePadding.
219*89c4ff92SAndroid Build Coastguard Worker std::string paddingStr = m["padding"].AsString().str();
220*89c4ff92SAndroid Build Coastguard Worker TfLitePadding padding;
221*89c4ff92SAndroid Build Coastguard Worker if (paddingStr == "VALID")
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker padding = kTfLitePaddingValid;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker else if (paddingStr == "SAME")
226*89c4ff92SAndroid Build Coastguard Worker {
227*89c4ff92SAndroid Build Coastguard Worker padding = kTfLitePaddingSame;
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker else
230*89c4ff92SAndroid Build Coastguard Worker {
231*89c4ff92SAndroid Build Coastguard Worker padding = kTfLitePaddingUnknown;
232*89c4ff92SAndroid Build Coastguard Worker }
233*89c4ff92SAndroid Build Coastguard Worker // Calculates padding for each pooling dimension separately
234*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
235*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadTop, descriptor.m_PadBottom, padding);
236*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
237*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadLeft, descriptor.m_PadRight, padding);
238*89c4ff92SAndroid Build Coastguard Worker CalcPadding(inputDepth, descriptor.m_PoolDepth, descriptor.m_StrideZ, 1u,
239*89c4ff92SAndroid Build Coastguard Worker descriptor.m_PadFront, descriptor.m_PadBack, padding);
240*89c4ff92SAndroid Build Coastguard Worker
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker // Check activation by parsing the string from the flexbuffer map
243*89c4ff92SAndroid Build Coastguard Worker std::string activationTypeStr = m["activation"].AsString().str();
244*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType = kTfLiteActNone;
245*89c4ff92SAndroid Build Coastguard Worker
246*89c4ff92SAndroid Build Coastguard Worker if (activationTypeStr == "kTfLiteActRelu")
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActRelu;
249*89c4ff92SAndroid Build Coastguard Worker }
250*89c4ff92SAndroid Build Coastguard Worker else if (activationTypeStr == "kTfLiteActReluN1To1")
251*89c4ff92SAndroid Build Coastguard Worker {
252*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActReluN1To1;
253*89c4ff92SAndroid Build Coastguard Worker }
254*89c4ff92SAndroid Build Coastguard Worker else if (activationTypeStr == "kTfLiteActRelu6")
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActRelu6;
257*89c4ff92SAndroid Build Coastguard Worker }
258*89c4ff92SAndroid Build Coastguard Worker else if (activationTypeStr == "kTfLiteActTanh")
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActTanh;
261*89c4ff92SAndroid Build Coastguard Worker }
262*89c4ff92SAndroid Build Coastguard Worker else if (activationTypeStr == "kTfLiteActSignBit")
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActSignBit;
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker else if (activationTypeStr == "kTfLiteActSigmoid")
267*89c4ff92SAndroid Build Coastguard Worker {
268*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActSigmoid;
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker else
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker activationType = kTfLiteActNone;
273*89c4ff92SAndroid Build Coastguard Worker }
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
276*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
277*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker
283*89c4ff92SAndroid Build Coastguard Worker // Validate the output info.
284*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
285*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
286*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) {
287*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("POOLING_3D",
288*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
289*89c4ff92SAndroid Build Coastguard Worker IsPooling3dSupported,
290*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
291*89c4ff92SAndroid Build Coastguard Worker isSupported,
292*89c4ff92SAndroid Build Coastguard Worker setBackend,
293*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
294*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
295*89c4ff92SAndroid Build Coastguard Worker descriptor);
296*89c4ff92SAndroid Build Coastguard Worker };
297*89c4ff92SAndroid Build Coastguard Worker
298*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
299*89c4ff92SAndroid Build Coastguard Worker {
300*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputTensorInfo, isSupported);
301*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
302*89c4ff92SAndroid Build Coastguard Worker }
303*89c4ff92SAndroid Build Coastguard Worker
304*89c4ff92SAndroid Build Coastguard Worker // Create the Layer
305*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling3dLayer(descriptor);
306*89c4ff92SAndroid Build Coastguard Worker poolingLayer->SetBackendId(setBackend);
307*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(poolingLayer != nullptr);
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker // Create and set output slots
310*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
311*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
312*89c4ff92SAndroid Build Coastguard Worker
313*89c4ff92SAndroid Build Coastguard Worker // try to connect the Constant Inputs if there are any
314*89c4ff92SAndroid Build Coastguard Worker if(ProcessInputs(poolingLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
315*89c4ff92SAndroid Build Coastguard Worker {
316*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker if(Connect(poolingLayer, tfLiteNode, delegateData) != kTfLiteOk)
320*89c4ff92SAndroid Build Coastguard Worker {
321*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
322*89c4ff92SAndroid Build Coastguard Worker }
323*89c4ff92SAndroid Build Coastguard Worker
324*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
325*89c4ff92SAndroid Build Coastguard Worker }
326*89c4ff92SAndroid Build Coastguard Worker
327*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
328