xref: /aosp_15_r20/external/armnn/delegate/classic/src/Control.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
15*89c4ff92SAndroid Build Coastguard Worker 
#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker 
SetupConcatViewOrigin(const armnn::TensorInfo & inputTensorInfo,armnn::OriginsDescriptor & concatDescriptor,const unsigned int concatAxis,unsigned int inputIndex,unsigned int & mergeDimOrigin)24*89c4ff92SAndroid Build Coastguard Worker void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo,
25*89c4ff92SAndroid Build Coastguard Worker                            armnn::OriginsDescriptor& concatDescriptor,
26*89c4ff92SAndroid Build Coastguard Worker                            const unsigned int concatAxis,
27*89c4ff92SAndroid Build Coastguard Worker                            unsigned int inputIndex,
28*89c4ff92SAndroid Build Coastguard Worker                            unsigned int& mergeDimOrigin)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker     const uint32_t inputRank = concatDescriptor.GetNumDimensions();
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker     // double check dimensions of the tensors
33*89c4ff92SAndroid Build Coastguard Worker     if (inputTensorInfo.GetNumDimensions() != inputRank)
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         throw armnn::ParseException("The number of dimensions for input tensors "
36*89c4ff92SAndroid Build Coastguard Worker                                     "of the concatenation operator should be: " + std::to_string(inputRank));
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < concatAxis; ++j)
40*89c4ff92SAndroid Build Coastguard Worker     {
41*89c4ff92SAndroid Build Coastguard Worker         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
42*89c4ff92SAndroid Build Coastguard Worker     }
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
45*89c4ff92SAndroid Build Coastguard Worker     mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
48*89c4ff92SAndroid Build Coastguard Worker     {
49*89c4ff92SAndroid Build Coastguard Worker         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
50*89c4ff92SAndroid Build Coastguard Worker     }
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
VisitConcatenationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteConcatOperatorCode)53*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitConcatenationOperator(DelegateData& delegateData,
54*89c4ff92SAndroid Build Coastguard Worker                                         TfLiteContext* tfLiteContext,
55*89c4ff92SAndroid Build Coastguard Worker                                         TfLiteNode* tfLiteNode,
56*89c4ff92SAndroid Build Coastguard Worker                                         int nodeIndex,
57*89c4ff92SAndroid Build Coastguard Worker                                         int32_t tfLiteConcatOperatorCode)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker     unsigned int numInputs = tfLiteNode->inputs->size;
60*89c4ff92SAndroid Build Coastguard Worker     if (numInputs < 2)
61*89c4ff92SAndroid Build Coastguard Worker     {
62*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
63*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
64*89c4ff92SAndroid Build Coastguard Worker             2, numInputs, nodeIndex);
65*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
66*89c4ff92SAndroid Build Coastguard Worker     }
67*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> inputTensorInfos;
72*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numInputs; ++i)
73*89c4ff92SAndroid Build Coastguard Worker     {
74*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[i]];
75*89c4ff92SAndroid Build Coastguard Worker         if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteConcatOperatorCode, nodeIndex))
76*89c4ff92SAndroid Build Coastguard Worker         {
77*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
78*89c4ff92SAndroid Build Coastguard Worker         }
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
81*89c4ff92SAndroid Build Coastguard Worker         inputTensorInfos.emplace_back(inputTensorInfo);
82*89c4ff92SAndroid Build Coastguard Worker     }
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     // Convert input tensors to const armnn::TensorInfo* type for FORWARD_LAYER_SUPPORT_FUNC.
85*89c4ff92SAndroid Build Coastguard Worker     std::vector<const armnn::TensorInfo*> inputConstTensorInfos;
86*89c4ff92SAndroid Build Coastguard Worker     std::transform(inputTensorInfos.begin(),
87*89c4ff92SAndroid Build Coastguard Worker                    inputTensorInfos.end(),
88*89c4ff92SAndroid Build Coastguard Worker                    std::back_inserter(inputConstTensorInfos),
89*89c4ff92SAndroid Build Coastguard Worker                    [](armnn::TensorInfo& t)->const armnn::TensorInfo*{ return &t; });
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
92*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteConcatOperatorCode, nodeIndex))
93*89c4ff92SAndroid Build Coastguard Worker     {
94*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     // Setup OriginsDescriptor, axis and view origin
98*89c4ff92SAndroid Build Coastguard Worker     unsigned int numConcatView = static_cast<unsigned int>(numInputs);
99*89c4ff92SAndroid Build Coastguard Worker     uint32_t inputRank = tfLiteTensors[tfLiteNode->inputs->data[0]].dims->size;
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     auto* concatenationParameters = reinterpret_cast<TfLiteConcatenationParams*>(tfLiteNode->builtin_data);
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     if(!concatenationParameters)
104*89c4ff92SAndroid Build Coastguard Worker     {
105*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(&"TfLiteArmnnDelegate: Concat parameters are null in: " [ nodeIndex]);
106*89c4ff92SAndroid Build Coastguard Worker     }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     const unsigned int concatDimInput = static_cast<unsigned int>(
109*89c4ff92SAndroid Build Coastguard Worker             (static_cast<int>(inputRank) + concatenationParameters->axis) % static_cast<int>(inputRank));
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     armnn::OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), inputRank);
112*89c4ff92SAndroid Build Coastguard Worker     concatDescriptor.SetConcatAxis(concatDimInput);
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     unsigned int mergeDimOrigin = 0;
115*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
116*89c4ff92SAndroid Build Coastguard Worker     {
117*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteTensor(
118*89c4ff92SAndroid Build Coastguard Worker                 tfLiteTensors[tfLiteNode->inputs->data[viewIndex]]);
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker         // Sets up concatDescriptor view origin
121*89c4ff92SAndroid Build Coastguard Worker         SetupConcatViewOrigin(inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
122*89c4ff92SAndroid Build Coastguard Worker     }
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     // Verify we support the fused activation before attempting to create a layer
127*89c4ff92SAndroid Build Coastguard Worker     TfLiteFusedActivation activationType = concatenationParameters->activation;
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
130*89c4ff92SAndroid Build Coastguard Worker                                                                     outputTensorInfo, activationType);
131*89c4ff92SAndroid Build Coastguard Worker     if(activationStatus != kTfLiteOk)
132*89c4ff92SAndroid Build Coastguard Worker     {
133*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
134*89c4ff92SAndroid Build Coastguard Worker     }
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker     // Check if supported
137*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
138*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
139*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
140*89c4ff92SAndroid Build Coastguard Worker     {
141*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("CONCATENATION",
142*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
143*89c4ff92SAndroid Build Coastguard Worker                                    IsConcatSupported,
144*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
145*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
146*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
147*89c4ff92SAndroid Build Coastguard Worker                                    inputConstTensorInfos,
148*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
149*89c4ff92SAndroid Build Coastguard Worker                                    concatDescriptor);
150*89c4ff92SAndroid Build Coastguard Worker     };
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
153*89c4ff92SAndroid Build Coastguard Worker     {
154*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputTensorInfo, isSupported);
155*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
156*89c4ff92SAndroid Build Coastguard Worker     }
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     // Setup layer and connect.
159*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* concatenationLayer = delegateData.m_Network->AddConcatLayer(concatDescriptor);
160*89c4ff92SAndroid Build Coastguard Worker     concatenationLayer->SetBackendId(setBackend);
161*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(concatenationLayer != nullptr);
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     // Connect the Constant Inputs
164*89c4ff92SAndroid Build Coastguard Worker     auto inputsTensorsProcess = ProcessInputs(concatenationLayer,
165*89c4ff92SAndroid Build Coastguard Worker                                               delegateData,
166*89c4ff92SAndroid Build Coastguard Worker                                               tfLiteContext,
167*89c4ff92SAndroid Build Coastguard Worker                                               tfLiteNode);
168*89c4ff92SAndroid Build Coastguard Worker     if (inputsTensorsProcess == kTfLiteError)
169*89c4ff92SAndroid Build Coastguard Worker     {
170*89c4ff92SAndroid Build Coastguard Worker         return inputsTensorsProcess;
171*89c4ff92SAndroid Build Coastguard Worker     }
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = concatenationLayer->GetOutputSlot(0);
174*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
175*89c4ff92SAndroid Build Coastguard Worker     if(Connect(concatenationLayer, tfLiteNode, delegateData) != kTfLiteOk)
176*89c4ff92SAndroid Build Coastguard Worker     {
177*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
178*89c4ff92SAndroid Build Coastguard Worker     }
179*89c4ff92SAndroid Build Coastguard Worker 
180*89c4ff92SAndroid Build Coastguard Worker     if (activationType == kTfLiteActNone)
181*89c4ff92SAndroid Build Coastguard Worker     {
182*89c4ff92SAndroid Build Coastguard Worker         // No Activation
183*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteOk;
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     // Check and Create activation
187*89c4ff92SAndroid Build Coastguard Worker     return FusedActivation(tfLiteContext, tfLiteNode, activationType, concatenationLayer, 0, delegateData);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker 
VisitMeanOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteMeanOperatorCode)190*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitMeanOperator(DelegateData& delegateData,
191*89c4ff92SAndroid Build Coastguard Worker                                TfLiteContext* tfLiteContext,
192*89c4ff92SAndroid Build Coastguard Worker                                TfLiteNode* tfLiteNode,
193*89c4ff92SAndroid Build Coastguard Worker                                int nodeIndex,
194*89c4ff92SAndroid Build Coastguard Worker                                int32_t tfLiteMeanOperatorCode)
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
197*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
200*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
201*89c4ff92SAndroid Build Coastguard Worker     if(!IsValid(&tfLiteInputTensor))
202*89c4ff92SAndroid Build Coastguard Worker     {
203*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
204*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
205*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
206*89c4ff92SAndroid Build Coastguard Worker             tfLiteMeanOperatorCode, nodeIndex);
207*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
208*89c4ff92SAndroid Build Coastguard Worker     }
209*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteInputTensor))
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
212*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
213*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
214*89c4ff92SAndroid Build Coastguard Worker             tfLiteMeanOperatorCode, nodeIndex);
215*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
216*89c4ff92SAndroid Build Coastguard Worker     }
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
219*89c4ff92SAndroid Build Coastguard Worker     if(!IsValid(&tfLiteAxisTensor))
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
222*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
223*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Invalid axis tensor in operator #%d node #%d: ",
224*89c4ff92SAndroid Build Coastguard Worker             tfLiteMeanOperatorCode, nodeIndex);
225*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
226*89c4ff92SAndroid Build Coastguard Worker     }
227*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteAxisTensor))
228*89c4ff92SAndroid Build Coastguard Worker     {
229*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
230*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
231*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic axis tensors are not supported in operator #%d node #%d: ",
232*89c4ff92SAndroid Build Coastguard Worker             tfLiteMeanOperatorCode, nodeIndex);
233*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
234*89c4ff92SAndroid Build Coastguard Worker     }
235*89c4ff92SAndroid Build Coastguard Worker 
236*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
237*89c4ff92SAndroid Build Coastguard Worker     if(!IsValid(&tfLiteOutputTensor))
238*89c4ff92SAndroid Build Coastguard Worker     {
239*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
240*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
241*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
242*89c4ff92SAndroid Build Coastguard Worker             tfLiteAxisTensor, nodeIndex);
243*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
244*89c4ff92SAndroid Build Coastguard Worker     }
245*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteOutputTensor))
246*89c4ff92SAndroid Build Coastguard Worker     {
247*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
248*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
249*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
250*89c4ff92SAndroid Build Coastguard Worker             tfLiteMeanOperatorCode, nodeIndex);
251*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
252*89c4ff92SAndroid Build Coastguard Worker     }
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo =  GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
255*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& axisTensorInfo =   GetTensorInfoForTfLiteTensor(tfLiteAxisTensor);
256*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
257*89c4ff92SAndroid Build Coastguard Worker 
258*89c4ff92SAndroid Build Coastguard Worker     auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axis;
261*89c4ff92SAndroid Build Coastguard Worker     // Add axis data to vector to be converter to unsigned int and assigned to descriptor axis.
262*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < axisTensorInfo.GetNumElements(); ++i)
263*89c4ff92SAndroid Build Coastguard Worker     {
264*89c4ff92SAndroid Build Coastguard Worker         axis.emplace_back(axisTensorData[i]);
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     // Convert the axis to unsigned int and remove duplicates.
268*89c4ff92SAndroid Build Coastguard Worker     unsigned int rank = inputTensorInfo.GetNumDimensions();
269*89c4ff92SAndroid Build Coastguard Worker     std::set<unsigned int> uniqueAxis;
270*89c4ff92SAndroid Build Coastguard Worker     std::transform(axis.begin(),
271*89c4ff92SAndroid Build Coastguard Worker                    axis.end(),
272*89c4ff92SAndroid Build Coastguard Worker                    std::inserter(uniqueAxis, uniqueAxis.begin()),
273*89c4ff92SAndroid Build Coastguard Worker                    [rank](int i)->unsigned int{ return (i + rank) % rank; });
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker     // Setup MeanDescriptor and assign axis and keepDims
276*89c4ff92SAndroid Build Coastguard Worker     armnn::MeanDescriptor desc;
277*89c4ff92SAndroid Build Coastguard Worker     desc.m_Axis.assign(uniqueAxis.begin(), uniqueAxis.end());
278*89c4ff92SAndroid Build Coastguard Worker     desc.m_KeepDims = inputTensorInfo.GetNumDimensions() == outputTensorInfo.GetNumDimensions() ? true : false;
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     // Check if supported
281*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
282*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
283*89c4ff92SAndroid Build Coastguard Worker     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
284*89c4ff92SAndroid Build Coastguard Worker     {
285*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("MEAN",
286*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
287*89c4ff92SAndroid Build Coastguard Worker                                    IsMeanSupported,
288*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
289*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
290*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
291*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
292*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfo,
293*89c4ff92SAndroid Build Coastguard Worker                                    desc);
294*89c4ff92SAndroid Build Coastguard Worker     };
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
297*89c4ff92SAndroid Build Coastguard Worker     {
298*89c4ff92SAndroid Build Coastguard Worker         validateFunc(outputTensorInfo, isSupported);
299*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
300*89c4ff92SAndroid Build Coastguard Worker     }
301*89c4ff92SAndroid Build Coastguard Worker 
302*89c4ff92SAndroid Build Coastguard Worker     // Setup layer and connect.
303*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* meanLayer = delegateData.m_Network->AddMeanLayer(desc);
304*89c4ff92SAndroid Build Coastguard Worker     meanLayer->SetBackendId(setBackend);
305*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(meanLayer != nullptr);
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot& outputSlot = meanLayer->GetOutputSlot(0);
308*89c4ff92SAndroid Build Coastguard Worker     outputSlot.SetTensorInfo(outputTensorInfo);
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker     // try to connect the Constant Inputs if there are any
311*89c4ff92SAndroid Build Coastguard Worker     if(ProcessInputs(meanLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
312*89c4ff92SAndroid Build Coastguard Worker     {
313*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
314*89c4ff92SAndroid Build Coastguard Worker     }
315*89c4ff92SAndroid Build Coastguard Worker 
316*89c4ff92SAndroid Build Coastguard Worker     return Connect(meanLayer, tfLiteNode, delegateData);
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker 
VisitControlOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)319*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitControlOperator(DelegateData& delegateData,
320*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteContext* tfLiteContext,
321*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteNode* tfLiteNode,
322*89c4ff92SAndroid Build Coastguard Worker                                   int nodeIndex,
323*89c4ff92SAndroid Build Coastguard Worker                                   int32_t operatorCode)
324*89c4ff92SAndroid Build Coastguard Worker {
325*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(delegateData,
326*89c4ff92SAndroid Build Coastguard Worker                         tfLiteContext,
327*89c4ff92SAndroid Build Coastguard Worker                         tfLiteNode,
328*89c4ff92SAndroid Build Coastguard Worker                         nodeIndex,
329*89c4ff92SAndroid Build Coastguard Worker                         operatorCode);
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker     switch(operatorCode)
332*89c4ff92SAndroid Build Coastguard Worker     {
333*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinConcatenation:
334*89c4ff92SAndroid Build Coastguard Worker             return VisitConcatenationOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
335*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinMean:
336*89c4ff92SAndroid Build Coastguard Worker             return VisitMeanOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
337*89c4ff92SAndroid Build Coastguard Worker         default:
338*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
339*89c4ff92SAndroid Build Coastguard Worker     }
340*89c4ff92SAndroid Build Coastguard Worker }
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
343