1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
17*89c4ff92SAndroid Build Coastguard Worker #include <iterator>
18*89c4ff92SAndroid Build Coastguard Worker #include <string>
19*89c4ff92SAndroid Build Coastguard Worker #include <vector>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker
SetupConcatViewOrigin(const armnn::TensorInfo & inputTensorInfo,armnn::OriginsDescriptor & concatDescriptor,const unsigned int concatAxis,unsigned int inputIndex,unsigned int & mergeDimOrigin)24*89c4ff92SAndroid Build Coastguard Worker void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo,
25*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor& concatDescriptor,
26*89c4ff92SAndroid Build Coastguard Worker const unsigned int concatAxis,
27*89c4ff92SAndroid Build Coastguard Worker unsigned int inputIndex,
28*89c4ff92SAndroid Build Coastguard Worker unsigned int& mergeDimOrigin)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker const uint32_t inputRank = concatDescriptor.GetNumDimensions();
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker // double check dimensions of the tensors
33*89c4ff92SAndroid Build Coastguard Worker if (inputTensorInfo.GetNumDimensions() != inputRank)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker throw armnn::ParseException("The number of dimensions for input tensors "
36*89c4ff92SAndroid Build Coastguard Worker "of the concatenation operator should be: " + std::to_string(inputRank));
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < concatAxis; ++j)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
45*89c4ff92SAndroid Build Coastguard Worker mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker
VisitConcatenationOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteConcatOperatorCode)53*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitConcatenationOperator(DelegateData& delegateData,
54*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
55*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
56*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
57*89c4ff92SAndroid Build Coastguard Worker int32_t tfLiteConcatOperatorCode)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker unsigned int numInputs = tfLiteNode->inputs->size;
60*89c4ff92SAndroid Build Coastguard Worker if (numInputs < 2)
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
63*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: Minimum number of inputs (%d != %d) in node #%d",
64*89c4ff92SAndroid Build Coastguard Worker 2, numInputs, nodeIndex);
65*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> inputTensorInfos;
72*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; ++i)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[i]];
75*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteConcatOperatorCode, nodeIndex))
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
81*89c4ff92SAndroid Build Coastguard Worker inputTensorInfos.emplace_back(inputTensorInfo);
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker // Convert input tensors to const armnn::TensorInfo* type for FORWARD_LAYER_SUPPORT_FUNC.
85*89c4ff92SAndroid Build Coastguard Worker std::vector<const armnn::TensorInfo*> inputConstTensorInfos;
86*89c4ff92SAndroid Build Coastguard Worker std::transform(inputTensorInfos.begin(),
87*89c4ff92SAndroid Build Coastguard Worker inputTensorInfos.end(),
88*89c4ff92SAndroid Build Coastguard Worker std::back_inserter(inputConstTensorInfos),
89*89c4ff92SAndroid Build Coastguard Worker [](armnn::TensorInfo& t)->const armnn::TensorInfo*{ return &t; });
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
92*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteConcatOperatorCode, nodeIndex))
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker // Setup OriginsDescriptor, axis and view origin
98*89c4ff92SAndroid Build Coastguard Worker unsigned int numConcatView = static_cast<unsigned int>(numInputs);
99*89c4ff92SAndroid Build Coastguard Worker uint32_t inputRank = tfLiteTensors[tfLiteNode->inputs->data[0]].dims->size;
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker auto* concatenationParameters = reinterpret_cast<TfLiteConcatenationParams*>(tfLiteNode->builtin_data);
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker if(!concatenationParameters)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(&"TfLiteArmnnDelegate: Concat parameters are null in: " [ nodeIndex]);
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker
108*89c4ff92SAndroid Build Coastguard Worker const unsigned int concatDimInput = static_cast<unsigned int>(
109*89c4ff92SAndroid Build Coastguard Worker (static_cast<int>(inputRank) + concatenationParameters->axis) % static_cast<int>(inputRank));
110*89c4ff92SAndroid Build Coastguard Worker
111*89c4ff92SAndroid Build Coastguard Worker armnn::OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), inputRank);
112*89c4ff92SAndroid Build Coastguard Worker concatDescriptor.SetConcatAxis(concatDimInput);
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker unsigned int mergeDimOrigin = 0;
115*89c4ff92SAndroid Build Coastguard Worker for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteTensor(
118*89c4ff92SAndroid Build Coastguard Worker tfLiteTensors[tfLiteNode->inputs->data[viewIndex]]);
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker // Sets up concatDescriptor view origin
121*89c4ff92SAndroid Build Coastguard Worker SetupConcatViewOrigin(inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
122*89c4ff92SAndroid Build Coastguard Worker }
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker // Verify we support the fused activation before attempting to create a layer
127*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType = concatenationParameters->activation;
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus activationStatus = ValidateFusedActivationOperator(delegateData, tfLiteContext, outputTensorInfo,
130*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo, activationType);
131*89c4ff92SAndroid Build Coastguard Worker if(activationStatus != kTfLiteOk)
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker // Check if supported
137*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
138*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
139*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("CONCATENATION",
142*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
143*89c4ff92SAndroid Build Coastguard Worker IsConcatSupported,
144*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
145*89c4ff92SAndroid Build Coastguard Worker isSupported,
146*89c4ff92SAndroid Build Coastguard Worker setBackend,
147*89c4ff92SAndroid Build Coastguard Worker inputConstTensorInfos,
148*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
149*89c4ff92SAndroid Build Coastguard Worker concatDescriptor);
150*89c4ff92SAndroid Build Coastguard Worker };
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputTensorInfo, isSupported);
155*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker // Setup layer and connect.
159*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* concatenationLayer = delegateData.m_Network->AddConcatLayer(concatDescriptor);
160*89c4ff92SAndroid Build Coastguard Worker concatenationLayer->SetBackendId(setBackend);
161*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(concatenationLayer != nullptr);
162*89c4ff92SAndroid Build Coastguard Worker
163*89c4ff92SAndroid Build Coastguard Worker // Connect the Constant Inputs
164*89c4ff92SAndroid Build Coastguard Worker auto inputsTensorsProcess = ProcessInputs(concatenationLayer,
165*89c4ff92SAndroid Build Coastguard Worker delegateData,
166*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
167*89c4ff92SAndroid Build Coastguard Worker tfLiteNode);
168*89c4ff92SAndroid Build Coastguard Worker if (inputsTensorsProcess == kTfLiteError)
169*89c4ff92SAndroid Build Coastguard Worker {
170*89c4ff92SAndroid Build Coastguard Worker return inputsTensorsProcess;
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker
173*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = concatenationLayer->GetOutputSlot(0);
174*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
175*89c4ff92SAndroid Build Coastguard Worker if(Connect(concatenationLayer, tfLiteNode, delegateData) != kTfLiteOk)
176*89c4ff92SAndroid Build Coastguard Worker {
177*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
180*89c4ff92SAndroid Build Coastguard Worker if (activationType == kTfLiteActNone)
181*89c4ff92SAndroid Build Coastguard Worker {
182*89c4ff92SAndroid Build Coastguard Worker // No Activation
183*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
184*89c4ff92SAndroid Build Coastguard Worker }
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker // Check and Create activation
187*89c4ff92SAndroid Build Coastguard Worker return FusedActivation(tfLiteContext, tfLiteNode, activationType, concatenationLayer, 0, delegateData);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker
VisitMeanOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteMeanOperatorCode)190*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitMeanOperator(DelegateData& delegateData,
191*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
192*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
193*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
194*89c4ff92SAndroid Build Coastguard Worker int32_t tfLiteMeanOperatorCode)
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
197*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
198*89c4ff92SAndroid Build Coastguard Worker
199*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
200*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
201*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteInputTensor))
202*89c4ff92SAndroid Build Coastguard Worker {
203*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
204*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
205*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid input tensor in operator #%d node #%d: ",
206*89c4ff92SAndroid Build Coastguard Worker tfLiteMeanOperatorCode, nodeIndex);
207*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
208*89c4ff92SAndroid Build Coastguard Worker }
209*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteInputTensor))
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
212*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
213*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
214*89c4ff92SAndroid Build Coastguard Worker tfLiteMeanOperatorCode, nodeIndex);
215*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker
218*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
219*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteAxisTensor))
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
222*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
223*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid axis tensor in operator #%d node #%d: ",
224*89c4ff92SAndroid Build Coastguard Worker tfLiteMeanOperatorCode, nodeIndex);
225*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
226*89c4ff92SAndroid Build Coastguard Worker }
227*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteAxisTensor))
228*89c4ff92SAndroid Build Coastguard Worker {
229*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
230*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
231*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic axis tensors are not supported in operator #%d node #%d: ",
232*89c4ff92SAndroid Build Coastguard Worker tfLiteMeanOperatorCode, nodeIndex);
233*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker
236*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
237*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(&tfLiteOutputTensor))
238*89c4ff92SAndroid Build Coastguard Worker {
239*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
240*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
241*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid output tensor in operator #%d node #%d: ",
242*89c4ff92SAndroid Build Coastguard Worker tfLiteAxisTensor, nodeIndex);
243*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
244*89c4ff92SAndroid Build Coastguard Worker }
245*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteOutputTensor))
246*89c4ff92SAndroid Build Coastguard Worker {
247*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
248*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
249*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
250*89c4ff92SAndroid Build Coastguard Worker tfLiteMeanOperatorCode, nodeIndex);
251*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
252*89c4ff92SAndroid Build Coastguard Worker }
253*89c4ff92SAndroid Build Coastguard Worker
254*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
255*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& axisTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAxisTensor);
256*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Worker auto* axisTensorData = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
259*89c4ff92SAndroid Build Coastguard Worker
260*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axis;
261*89c4ff92SAndroid Build Coastguard Worker // Add axis data to vector to be converter to unsigned int and assigned to descriptor axis.
262*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < axisTensorInfo.GetNumElements(); ++i)
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker axis.emplace_back(axisTensorData[i]);
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker // Convert the axis to unsigned int and remove duplicates.
268*89c4ff92SAndroid Build Coastguard Worker unsigned int rank = inputTensorInfo.GetNumDimensions();
269*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> uniqueAxis;
270*89c4ff92SAndroid Build Coastguard Worker std::transform(axis.begin(),
271*89c4ff92SAndroid Build Coastguard Worker axis.end(),
272*89c4ff92SAndroid Build Coastguard Worker std::inserter(uniqueAxis, uniqueAxis.begin()),
273*89c4ff92SAndroid Build Coastguard Worker [rank](int i)->unsigned int{ return (i + rank) % rank; });
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker // Setup MeanDescriptor and assign axis and keepDims
276*89c4ff92SAndroid Build Coastguard Worker armnn::MeanDescriptor desc;
277*89c4ff92SAndroid Build Coastguard Worker desc.m_Axis.assign(uniqueAxis.begin(), uniqueAxis.end());
278*89c4ff92SAndroid Build Coastguard Worker desc.m_KeepDims = inputTensorInfo.GetNumDimensions() == outputTensorInfo.GetNumDimensions() ? true : false;
279*89c4ff92SAndroid Build Coastguard Worker
280*89c4ff92SAndroid Build Coastguard Worker // Check if supported
281*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
282*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
283*89c4ff92SAndroid Build Coastguard Worker auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
284*89c4ff92SAndroid Build Coastguard Worker {
285*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("MEAN",
286*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
287*89c4ff92SAndroid Build Coastguard Worker IsMeanSupported,
288*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
289*89c4ff92SAndroid Build Coastguard Worker isSupported,
290*89c4ff92SAndroid Build Coastguard Worker setBackend,
291*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
292*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
293*89c4ff92SAndroid Build Coastguard Worker desc);
294*89c4ff92SAndroid Build Coastguard Worker };
295*89c4ff92SAndroid Build Coastguard Worker
296*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
297*89c4ff92SAndroid Build Coastguard Worker {
298*89c4ff92SAndroid Build Coastguard Worker validateFunc(outputTensorInfo, isSupported);
299*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
300*89c4ff92SAndroid Build Coastguard Worker }
301*89c4ff92SAndroid Build Coastguard Worker
302*89c4ff92SAndroid Build Coastguard Worker // Setup layer and connect.
303*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* meanLayer = delegateData.m_Network->AddMeanLayer(desc);
304*89c4ff92SAndroid Build Coastguard Worker meanLayer->SetBackendId(setBackend);
305*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(meanLayer != nullptr);
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = meanLayer->GetOutputSlot(0);
308*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(outputTensorInfo);
309*89c4ff92SAndroid Build Coastguard Worker
310*89c4ff92SAndroid Build Coastguard Worker // try to connect the Constant Inputs if there are any
311*89c4ff92SAndroid Build Coastguard Worker if(ProcessInputs(meanLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
312*89c4ff92SAndroid Build Coastguard Worker {
313*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
314*89c4ff92SAndroid Build Coastguard Worker }
315*89c4ff92SAndroid Build Coastguard Worker
316*89c4ff92SAndroid Build Coastguard Worker return Connect(meanLayer, tfLiteNode, delegateData);
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker
VisitControlOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)319*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitControlOperator(DelegateData& delegateData,
320*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
321*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
322*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
323*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
324*89c4ff92SAndroid Build Coastguard Worker {
325*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(delegateData,
326*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
327*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
328*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
329*89c4ff92SAndroid Build Coastguard Worker operatorCode);
330*89c4ff92SAndroid Build Coastguard Worker
331*89c4ff92SAndroid Build Coastguard Worker switch(operatorCode)
332*89c4ff92SAndroid Build Coastguard Worker {
333*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConcatenation:
334*89c4ff92SAndroid Build Coastguard Worker return VisitConcatenationOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
335*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMean:
336*89c4ff92SAndroid Build Coastguard Worker return VisitMeanOperator(delegateData, tfLiteContext, tfLiteNode, nodeIndex, operatorCode);
337*89c4ff92SAndroid Build Coastguard Worker default:
338*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
339*89c4ff92SAndroid Build Coastguard Worker }
340*89c4ff92SAndroid Build Coastguard Worker }
341*89c4ff92SAndroid Build Coastguard Worker
342*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
343