xref: /aosp_15_r20/external/armnn/delegate/classic/src/Pooling.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flexbuffers.h>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
VisitPooling2dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLitePoolingOperatorCode)19*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitPooling2dOperator(DelegateData& delegateData,
20*89c4ff92SAndroid Build Coastguard Worker                                     TfLiteContext* tfLiteContext,
21*89c4ff92SAndroid Build Coastguard Worker                                     TfLiteNode* tfLiteNode,
22*89c4ff92SAndroid Build Coastguard Worker                                     int nodeIndex,
23*89c4ff92SAndroid Build Coastguard Worker                                     int32_t tfLitePoolingOperatorCode)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteInputTensor))
31*89c4ff92SAndroid Build Coastguard Worker     {
32*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
33*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
34*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
35*89c4ff92SAndroid Build Coastguard Worker             tfLitePoolingOperatorCode, nodeIndex);
36*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
40*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteOutputTensor))
41*89c4ff92SAndroid Build Coastguard Worker     {
42*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
43*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
44*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
45*89c4ff92SAndroid Build Coastguard Worker             tfLitePoolingOperatorCode, nodeIndex);
46*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
47*89c4ff92SAndroid Build Coastguard Worker     }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
50*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     auto* tfLiteNodeParameters = reinterpret_cast<TfLitePoolParams*>(tfLiteNode->builtin_data);
53*89c4ff92SAndroid Build Coastguard Worker     TfLiteFusedActivation activationType = kTfLiteActNone;
54*89c4ff92SAndroid Build Coastguard Worker     if (tfLiteNodeParameters)
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         activationType = tfLiteNodeParameters->activation;
57*89c4ff92SAndroid Build Coastguard Worker         TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
58*89c4ff92SAndroid Build Coastguard Worker                                                                         outputTensorInfo, activationType);
59*89c4ff92SAndroid Build Coastguard Worker         if(activationStatus != kTfLiteOk)
60*89c4ff92SAndroid Build Coastguard Worker         {
61*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
62*89c4ff92SAndroid Build Coastguard Worker         }
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     }
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     armnn::PoolingAlgorithm poolingAlgorithm;
67*89c4ff92SAndroid Build Coastguard Worker     switch(tfLitePoolingOperatorCode)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinAveragePool2d:
70*89c4ff92SAndroid Build Coastguard Worker             poolingAlgorithm = armnn::PoolingAlgorithm::Average;
71*89c4ff92SAndroid Build Coastguard Worker             break;
72*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinL2Pool2d:
73*89c4ff92SAndroid Build Coastguard Worker             poolingAlgorithm = armnn::PoolingAlgorithm::L2;
74*89c4ff92SAndroid Build Coastguard Worker             break;
75*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinMaxPool2d:
76*89c4ff92SAndroid Build Coastguard Worker             poolingAlgorithm = armnn::PoolingAlgorithm::Max;
77*89c4ff92SAndroid Build Coastguard Worker             break;
78*89c4ff92SAndroid Build Coastguard Worker         default:
79*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
80*89c4ff92SAndroid Build Coastguard Worker     }
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     armnn::Pooling2dDescriptor descriptor;
83*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolType = poolingAlgorithm;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolWidth = tfLiteNodeParameters->filter_width;
86*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolHeight = tfLiteNodeParameters->filter_height;
87*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX = tfLiteNodeParameters->stride_width;
88*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY = tfLiteNodeParameters->stride_height;
89*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
92*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth  = inputTensorInfo.GetShape()[2];
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
95*89c4ff92SAndroid Build Coastguard Worker                 descriptor.m_PadTop, descriptor.m_PadBottom, tfLiteNodeParameters->padding);
96*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
97*89c4ff92SAndroid Build Coastguard Worker                 descriptor.m_PadLeft, descriptor.m_PadRight, tfLiteNodeParameters->padding);
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
100*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
101*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
102*89c4ff92SAndroid Build Coastguard Worker     {
103*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("POOLING_2D",
104*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
105*89c4ff92SAndroid Build Coastguard Worker                                    IsPooling2dSupported,
106*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
107*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
108*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
109*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
110*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
111*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
112*89c4ff92SAndroid Build Coastguard Worker     };
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
115*89c4ff92SAndroid Build Coastguard Worker     {
116*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputTensorInfo, isSupported);
117*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
118*89c4ff92SAndroid Build Coastguard Worker     }
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling2dLayer(descriptor);
121*89c4ff92SAndroid Build Coastguard Worker     poolingLayer->SetBackendId(setBackend);
122*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(poolingLayer != nullptr);
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
125*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     // try to connect the Constant Inputs if there are any
128*89c4ff92SAndroid Build Coastguard Worker     if(ProcessInputs(poolingLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     if(Connect(poolingLayer, tfLiteNode, delegateData) != kTfLiteOk)
134*89c4ff92SAndroid Build Coastguard Worker     {
135*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
136*89c4ff92SAndroid Build Coastguard Worker     }
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker     // Check and create activation
139*89c4ff92SAndroid Build Coastguard Worker     return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker 
VisitPooling3dOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,std::string customOperatorName)142*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitPooling3dOperator(DelegateData& delegateData,
143*89c4ff92SAndroid Build Coastguard Worker                                     TfLiteContext* tfLiteContext,
144*89c4ff92SAndroid Build Coastguard Worker                                     TfLiteNode* tfLiteNode,
145*89c4ff92SAndroid Build Coastguard Worker                                     int nodeIndex,
146*89c4ff92SAndroid Build Coastguard Worker                                     std::string customOperatorName)
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
149*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
152*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
153*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteInputTensor))
154*89c4ff92SAndroid Build Coastguard Worker     {
155*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
156*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
157*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
158*89c4ff92SAndroid Build Coastguard Worker             customOperatorName.c_str(), nodeIndex);
159*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
160*89c4ff92SAndroid Build Coastguard Worker     }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
163*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteOutputTensor))
164*89c4ff92SAndroid Build Coastguard Worker     {
165*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
166*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
167*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
168*89c4ff92SAndroid Build Coastguard Worker             customOperatorName.c_str(), nodeIndex);
169*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
170*89c4ff92SAndroid Build Coastguard Worker     }
171*89c4ff92SAndroid Build Coastguard Worker     // Set the input and output info
172*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
173*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker     // Custom Operators are defined by the name string associated to the operator. Use this to determine
176*89c4ff92SAndroid Build Coastguard Worker     // which pooling algorithm to create the armnn operator with. L2 Pooling3D is unsupported in TfLite.
177*89c4ff92SAndroid Build Coastguard Worker     armnn::PoolingAlgorithm poolingAlgorithm;
178*89c4ff92SAndroid Build Coastguard Worker     if (customOperatorName == "MaxPool3D")
179*89c4ff92SAndroid Build Coastguard Worker     {
180*89c4ff92SAndroid Build Coastguard Worker         poolingAlgorithm = armnn::PoolingAlgorithm::Max;
181*89c4ff92SAndroid Build Coastguard Worker     }
182*89c4ff92SAndroid Build Coastguard Worker     else if (customOperatorName == "AveragePool3D")
183*89c4ff92SAndroid Build Coastguard Worker     {
184*89c4ff92SAndroid Build Coastguard Worker         poolingAlgorithm = armnn::PoolingAlgorithm::Average;
185*89c4ff92SAndroid Build Coastguard Worker     }
186*89c4ff92SAndroid Build Coastguard Worker     else
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
189*89c4ff92SAndroid Build Coastguard Worker     }
190*89c4ff92SAndroid Build Coastguard Worker     // Create the armnn pool3d descriptor and set the algorithm parsed above.
191*89c4ff92SAndroid Build Coastguard Worker     armnn::Pooling3dDescriptor descriptor;
192*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolType = poolingAlgorithm;
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     // custom_initial_data and custom_initial_data_size are void* variables defined in the tflite registration
195*89c4ff92SAndroid Build Coastguard Worker     // used to access the custom option buffer for the operator.
196*89c4ff92SAndroid Build Coastguard Worker     auto custom_data = tfLiteNode->custom_initial_data;
197*89c4ff92SAndroid Build Coastguard Worker     auto custom_data_size = tfLiteNode->custom_initial_data_size;
198*89c4ff92SAndroid Build Coastguard Worker     // Reinterpret the void* to a byte buffer to access the options data in the flexbuffers map.
199*89c4ff92SAndroid Build Coastguard Worker     const flexbuffers::Map& m = flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(custom_data),
200*89c4ff92SAndroid Build Coastguard Worker                                                      custom_data_size).AsMap();
201*89c4ff92SAndroid Build Coastguard Worker     // poolDims is a vector of [ 1, Depth, Height, Width, 1 ]
202*89c4ff92SAndroid Build Coastguard Worker     const auto poolDims = m["ksize"].AsTypedVector();
203*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolWidth = poolDims[3].AsInt32();
204*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolHeight = poolDims[2].AsInt32();
205*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_PoolDepth = poolDims[1].AsInt32();
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     // strideDimes is a vector of [ 1, Z, Y, X, 1]
208*89c4ff92SAndroid Build Coastguard Worker     const auto strideDims = m["strides"].AsTypedVector();
209*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideX = strideDims[3].AsInt32();
210*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideY = strideDims[2].AsInt32();
211*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_StrideZ = strideDims[1].AsInt32();
212*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout = armnn::DataLayout::NDHWC;
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputDepth = inputTensorInfo.GetShape()[1];
215*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[2];
216*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth = inputTensorInfo.GetShape()[3];
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     // CalcPadding expects a TfLitePadding type. Parse flexbuffers to extract padding string and create TfLitePadding.
219*89c4ff92SAndroid Build Coastguard Worker     std::string paddingStr = m["padding"].AsString().str();
220*89c4ff92SAndroid Build Coastguard Worker     TfLitePadding padding;
221*89c4ff92SAndroid Build Coastguard Worker     if (paddingStr == "VALID")
222*89c4ff92SAndroid Build Coastguard Worker     {
223*89c4ff92SAndroid Build Coastguard Worker         padding = kTfLitePaddingValid;
224*89c4ff92SAndroid Build Coastguard Worker     }
225*89c4ff92SAndroid Build Coastguard Worker     else if (paddingStr == "SAME")
226*89c4ff92SAndroid Build Coastguard Worker     {
227*89c4ff92SAndroid Build Coastguard Worker         padding = kTfLitePaddingSame;
228*89c4ff92SAndroid Build Coastguard Worker     }
229*89c4ff92SAndroid Build Coastguard Worker     else
230*89c4ff92SAndroid Build Coastguard Worker     {
231*89c4ff92SAndroid Build Coastguard Worker         padding = kTfLitePaddingUnknown;
232*89c4ff92SAndroid Build Coastguard Worker     }
233*89c4ff92SAndroid Build Coastguard Worker     // Calculates padding for each pooling dimension separately
234*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
235*89c4ff92SAndroid Build Coastguard Worker                 descriptor.m_PadTop, descriptor.m_PadBottom, padding);
236*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
237*89c4ff92SAndroid Build Coastguard Worker                 descriptor.m_PadLeft, descriptor.m_PadRight, padding);
238*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputDepth, descriptor.m_PoolDepth, descriptor.m_StrideZ, 1u,
239*89c4ff92SAndroid Build Coastguard Worker                 descriptor.m_PadFront, descriptor.m_PadBack, padding);
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     // Check activation by parsing the string from the flexbuffer map
243*89c4ff92SAndroid Build Coastguard Worker     std::string activationTypeStr = m["activation"].AsString().str();
244*89c4ff92SAndroid Build Coastguard Worker     TfLiteFusedActivation activationType = kTfLiteActNone;
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker     if (activationTypeStr == "kTfLiteActRelu")
247*89c4ff92SAndroid Build Coastguard Worker     {
248*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActRelu;
249*89c4ff92SAndroid Build Coastguard Worker     }
250*89c4ff92SAndroid Build Coastguard Worker     else if (activationTypeStr == "kTfLiteActReluN1To1")
251*89c4ff92SAndroid Build Coastguard Worker     {
252*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActReluN1To1;
253*89c4ff92SAndroid Build Coastguard Worker     }
254*89c4ff92SAndroid Build Coastguard Worker     else if (activationTypeStr == "kTfLiteActRelu6")
255*89c4ff92SAndroid Build Coastguard Worker     {
256*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActRelu6;
257*89c4ff92SAndroid Build Coastguard Worker     }
258*89c4ff92SAndroid Build Coastguard Worker     else if (activationTypeStr == "kTfLiteActTanh")
259*89c4ff92SAndroid Build Coastguard Worker     {
260*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActTanh;
261*89c4ff92SAndroid Build Coastguard Worker     }
262*89c4ff92SAndroid Build Coastguard Worker     else if (activationTypeStr == "kTfLiteActSignBit")
263*89c4ff92SAndroid Build Coastguard Worker     {
264*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActSignBit;
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker     else if (activationTypeStr == "kTfLiteActSigmoid")
267*89c4ff92SAndroid Build Coastguard Worker     {
268*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActSigmoid;
269*89c4ff92SAndroid Build Coastguard Worker     }
270*89c4ff92SAndroid Build Coastguard Worker     else
271*89c4ff92SAndroid Build Coastguard Worker     {
272*89c4ff92SAndroid Build Coastguard Worker         activationType = kTfLiteActNone;
273*89c4ff92SAndroid Build Coastguard Worker     }
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
276*89c4ff92SAndroid Build Coastguard Worker                                                                     outputTensorInfo, activationType);
277*89c4ff92SAndroid Build Coastguard Worker     if(activationStatus != kTfLiteOk)
278*89c4ff92SAndroid Build Coastguard Worker     {
279*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
280*89c4ff92SAndroid Build Coastguard Worker     }
281*89c4ff92SAndroid Build Coastguard Worker 
282*89c4ff92SAndroid Build Coastguard Worker 
283*89c4ff92SAndroid Build Coastguard Worker     // Validate the output info.
284*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
285*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
286*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported) {
287*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("POOLING_3D",
288*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
289*89c4ff92SAndroid Build Coastguard Worker                                    IsPooling3dSupported,
290*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
291*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
292*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
293*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
294*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
295*89c4ff92SAndroid Build Coastguard Worker                                    descriptor);
296*89c4ff92SAndroid Build Coastguard Worker     };
297*89c4ff92SAndroid Build Coastguard Worker 
298*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
299*89c4ff92SAndroid Build Coastguard Worker     {
300*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputTensorInfo, isSupported);
301*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
302*89c4ff92SAndroid Build Coastguard Worker     }
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker     // Create the Layer
305*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling3dLayer(descriptor);
306*89c4ff92SAndroid Build Coastguard Worker     poolingLayer->SetBackendId(setBackend);
307*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(poolingLayer != nullptr);
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker     // Create and set output slots
310*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
311*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
312*89c4ff92SAndroid Build Coastguard Worker 
313*89c4ff92SAndroid Build Coastguard Worker     // try to connect the Constant Inputs if there are any
314*89c4ff92SAndroid Build Coastguard Worker     if(ProcessInputs(poolingLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
315*89c4ff92SAndroid Build Coastguard Worker     {
316*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
317*89c4ff92SAndroid Build Coastguard Worker     }
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker     if(Connect(poolingLayer, tfLiteNode, delegateData) != kTfLiteOk)
320*89c4ff92SAndroid Build Coastguard Worker     {
321*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
322*89c4ff92SAndroid Build Coastguard Worker     }
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker     return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
325*89c4ff92SAndroid Build Coastguard Worker }
326*89c4ff92SAndroid Build Coastguard Worker 
327*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
328