xref: /aosp_15_r20/external/armnn/delegate/classic/src/Split.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
#include <ClassicDelegateUtils.hpp>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <vector>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker 
/// Highest input rank accepted by the Split/SplitV visitors in this file; inputs of larger rank are rejected.
constexpr unsigned int MaxNumOfTensorDimensions{5U};
18*89c4ff92SAndroid Build Coastguard Worker 
VisitSplitOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitOperatorCode)19*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20*89c4ff92SAndroid Build Coastguard Worker                                 TfLiteContext* tfLiteContext,
21*89c4ff92SAndroid Build Coastguard Worker                                 TfLiteNode* tfLiteNode,
22*89c4ff92SAndroid Build Coastguard Worker                                 int nodeIndex,
23*89c4ff92SAndroid Build Coastguard Worker                                 int32_t tfLiteSplitOperatorCode)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker     auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numSplits =  NonNegative(splitParameters->num_splits, nodeIndex);
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35*89c4ff92SAndroid Build Coastguard Worker     {
36*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41*89c4ff92SAndroid Build Coastguard Worker     {
42*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
43*89c4ff92SAndroid Build Coastguard Worker     }
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
48*89c4ff92SAndroid Build Coastguard Worker     auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
50*89c4ff92SAndroid Build Coastguard Worker     int32_t axis = axisTensorData[0];
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
53*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
54*89c4ff92SAndroid Build Coastguard Worker     {
55*89c4ff92SAndroid Build Coastguard Worker         // Square bracket denotes inclusive n while parenthesis denotes exclusive n
56*89c4ff92SAndroid Build Coastguard Worker         // E.g. Rank 4 tensor can have axis in range [-4, 3)
57*89c4ff92SAndroid Build Coastguard Worker         // -1 == 3, -2 == 2, -3 == 1, -4 == 0
58*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
59*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext,
60*89c4ff92SAndroid Build Coastguard Worker                 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
61*89c4ff92SAndroid Build Coastguard Worker                 axis, nodeIndex);
62*89c4ff92SAndroid Build Coastguard Worker     }
63*89c4ff92SAndroid Build Coastguard Worker     const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> outputs;
66*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numSplits; ++i)
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
69*89c4ff92SAndroid Build Coastguard Worker         if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
70*89c4ff92SAndroid Build Coastguard Worker         {
71*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
72*89c4ff92SAndroid Build Coastguard Worker         }
73*89c4ff92SAndroid Build Coastguard Worker         outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker     const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
78*89c4ff92SAndroid Build Coastguard Worker     if (inputDimSize > MaxNumOfTensorDimensions)
79*89c4ff92SAndroid Build Coastguard Worker     {
80*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
81*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
82*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
83*89c4ff92SAndroid Build Coastguard Worker             "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
84*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> splitterDimSizes(inputDimSize);
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     // Add current input shape to splitterDimSizes
90*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputDimSize; ++i)
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker         splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
93*89c4ff92SAndroid Build Coastguard Worker     }
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     if (splitterDimSizes[splitDim] % numSplits != 0)
96*89c4ff92SAndroid Build Coastguard Worker     {
97*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
98*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
99*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
100*89c4ff92SAndroid Build Coastguard Worker             numSplits, splitterDimSizes[splitDim], nodeIndex);
101*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker     splitterDimSizes[splitDim] /= numSplits;
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
106*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < numSplits; ++j)
107*89c4ff92SAndroid Build Coastguard Worker     {
108*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
109*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
110*89c4ff92SAndroid Build Coastguard Worker         {
111*89c4ff92SAndroid Build Coastguard Worker             splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
112*89c4ff92SAndroid Build Coastguard Worker         }
113*89c4ff92SAndroid Build Coastguard Worker         splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
117*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         // Check if supported
120*89c4ff92SAndroid Build Coastguard Worker         bool isSupported = false;
121*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
122*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
123*89c4ff92SAndroid Build Coastguard Worker                                    IsSplitterSupported,
124*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
125*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
126*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
127*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
128*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfos,
129*89c4ff92SAndroid Build Coastguard Worker                                    splitDescriptor);
130*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
134*89c4ff92SAndroid Build Coastguard Worker     layer->SetBackendId(setBackend);
135*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
138*89c4ff92SAndroid Build Coastguard Worker     {
139*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
140*89c4ff92SAndroid Build Coastguard Worker     }
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker     // Connect the input slots
143*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     // Prepare output slots
146*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
147*89c4ff92SAndroid Build Coastguard Worker     {
148*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
149*89c4ff92SAndroid Build Coastguard Worker         delegateData.m_OutputSlotForNode[
150*89c4ff92SAndroid Build Coastguard Worker             static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
151*89c4ff92SAndroid Build Coastguard Worker     }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker 
VisitSplitVOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitVOperatorCode)156*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
157*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
158*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteNode* tfLiteNode,
159*89c4ff92SAndroid Build Coastguard Worker                                  int nodeIndex,
160*89c4ff92SAndroid Build Coastguard Worker                                  int32_t tfLiteSplitVOperatorCode)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
165*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
166*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
167*89c4ff92SAndroid Build Coastguard Worker     {
168*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
172*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
173*89c4ff92SAndroid Build Coastguard Worker     {
174*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
175*89c4ff92SAndroid Build Coastguard Worker     }
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
178*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
179*89c4ff92SAndroid Build Coastguard Worker     {
180*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
181*89c4ff92SAndroid Build Coastguard Worker     }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
184*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
185*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
186*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
189*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
190*89c4ff92SAndroid Build Coastguard Worker     int32_t axis = axisTensorData[0];
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
193*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
194*89c4ff92SAndroid Build Coastguard Worker     {
195*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
196*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext,
197*89c4ff92SAndroid Build Coastguard Worker                 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
198*89c4ff92SAndroid Build Coastguard Worker                 axis, nodeIndex);
199*89c4ff92SAndroid Build Coastguard Worker     }
200*89c4ff92SAndroid Build Coastguard Worker     const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
203*89c4ff92SAndroid Build Coastguard Worker     unsigned int numSplits = 0;
204*89c4ff92SAndroid Build Coastguard Worker     if (splitVParameters)
205*89c4ff92SAndroid Build Coastguard Worker     {
206*89c4ff92SAndroid Build Coastguard Worker         numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
207*89c4ff92SAndroid Build Coastguard Worker     }
208*89c4ff92SAndroid Build Coastguard Worker     else
209*89c4ff92SAndroid Build Coastguard Worker     {
210*89c4ff92SAndroid Build Coastguard Worker         numSplits = splitsTensorInfo.GetNumElements();
211*89c4ff92SAndroid Build Coastguard Worker     }
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker     if (numSplits <= 0)
214*89c4ff92SAndroid Build Coastguard Worker     {
215*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
216*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d  in node #%d",
217*89c4ff92SAndroid Build Coastguard Worker             numSplits, nodeIndex);
218*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
219*89c4ff92SAndroid Build Coastguard Worker     }
220*89c4ff92SAndroid Build Coastguard Worker 
221*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
222*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> outputs;
223*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numSplits; ++i)
224*89c4ff92SAndroid Build Coastguard Worker     {
225*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
226*89c4ff92SAndroid Build Coastguard Worker         if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
227*89c4ff92SAndroid Build Coastguard Worker         {
228*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
229*89c4ff92SAndroid Build Coastguard Worker         }
230*89c4ff92SAndroid Build Coastguard Worker         outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
231*89c4ff92SAndroid Build Coastguard Worker     }
232*89c4ff92SAndroid Build Coastguard Worker     const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
235*89c4ff92SAndroid Build Coastguard Worker     if (inputDimSize > MaxNumOfTensorDimensions)
236*89c4ff92SAndroid Build Coastguard Worker     {
237*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
238*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
239*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
240*89c4ff92SAndroid Build Coastguard Worker             "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
241*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
242*89c4ff92SAndroid Build Coastguard Worker     }
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> splitsTensorData(numSplits);
245*89c4ff92SAndroid Build Coastguard Worker     std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
246*89c4ff92SAndroid Build Coastguard Worker 
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker     unsigned int index         = 0;
249*89c4ff92SAndroid Build Coastguard Worker     unsigned int inferredIndex = 0;
250*89c4ff92SAndroid Build Coastguard Worker     int numberOfInferred       = 0;
251*89c4ff92SAndroid Build Coastguard Worker     int splitSum = 0;
252*89c4ff92SAndroid Build Coastguard Worker 
253*89c4ff92SAndroid Build Coastguard Worker     for (auto splitData : splitsTensorData)
254*89c4ff92SAndroid Build Coastguard Worker     {
255*89c4ff92SAndroid Build Coastguard Worker         if (splitData < 0)
256*89c4ff92SAndroid Build Coastguard Worker         {
257*89c4ff92SAndroid Build Coastguard Worker             ++numberOfInferred;
258*89c4ff92SAndroid Build Coastguard Worker             inferredIndex = index;
259*89c4ff92SAndroid Build Coastguard Worker         }
260*89c4ff92SAndroid Build Coastguard Worker         else
261*89c4ff92SAndroid Build Coastguard Worker         {
262*89c4ff92SAndroid Build Coastguard Worker             splitSum += splitData;
263*89c4ff92SAndroid Build Coastguard Worker         }
264*89c4ff92SAndroid Build Coastguard Worker         ++index;
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     // Check for inferred axis
268*89c4ff92SAndroid Build Coastguard Worker     if (numberOfInferred == 0)
269*89c4ff92SAndroid Build Coastguard Worker     {
270*89c4ff92SAndroid Build Coastguard Worker         if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
271*89c4ff92SAndroid Build Coastguard Worker         {
272*89c4ff92SAndroid Build Coastguard Worker             TF_LITE_MAYBE_KERNEL_LOG(
273*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
274*89c4ff92SAndroid Build Coastguard Worker                                " split_dim in node #%d", nodeIndex);
275*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
276*89c4ff92SAndroid Build Coastguard Worker         }
277*89c4ff92SAndroid Build Coastguard Worker     }
278*89c4ff92SAndroid Build Coastguard Worker     else if (numberOfInferred == 1)
279*89c4ff92SAndroid Build Coastguard Worker     {
280*89c4ff92SAndroid Build Coastguard Worker         splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
281*89c4ff92SAndroid Build Coastguard Worker     }
282*89c4ff92SAndroid Build Coastguard Worker     else
283*89c4ff92SAndroid Build Coastguard Worker     {
284*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
285*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
286*89c4ff92SAndroid Build Coastguard Worker             nodeIndex);
287*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
288*89c4ff92SAndroid Build Coastguard Worker     }
289*89c4ff92SAndroid Build Coastguard Worker 
290*89c4ff92SAndroid Build Coastguard Worker     armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
291*89c4ff92SAndroid Build Coastguard Worker     unsigned int accumSplit = 0;
292*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < numSplits; ++j)
293*89c4ff92SAndroid Build Coastguard Worker     {
294*89c4ff92SAndroid Build Coastguard Worker         unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
297*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
298*89c4ff92SAndroid Build Coastguard Worker         {
299*89c4ff92SAndroid Build Coastguard Worker             unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
300*89c4ff92SAndroid Build Coastguard Worker             if (dimIdx == splitDim)
301*89c4ff92SAndroid Build Coastguard Worker             {
302*89c4ff92SAndroid Build Coastguard Worker                 dimSize = splitSize;
303*89c4ff92SAndroid Build Coastguard Worker             }
304*89c4ff92SAndroid Build Coastguard Worker             splitDescriptor.SetViewSize(j, dimIdx, dimSize);
305*89c4ff92SAndroid Build Coastguard Worker         }
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker         splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
308*89c4ff92SAndroid Build Coastguard Worker         accumSplit += splitSize;
309*89c4ff92SAndroid Build Coastguard Worker     }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
312*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
313*89c4ff92SAndroid Build Coastguard Worker     {
314*89c4ff92SAndroid Build Coastguard Worker         // Check if supported
315*89c4ff92SAndroid Build Coastguard Worker         bool isSupported = false;
316*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
317*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
318*89c4ff92SAndroid Build Coastguard Worker                                    IsSplitterSupported,
319*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
320*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
321*89c4ff92SAndroid Build Coastguard Worker                                    setBackend,
322*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
323*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfos,
324*89c4ff92SAndroid Build Coastguard Worker                                    splitDescriptor);
325*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
326*89c4ff92SAndroid Build Coastguard Worker     }
327*89c4ff92SAndroid Build Coastguard Worker 
328*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
329*89c4ff92SAndroid Build Coastguard Worker     layer->SetBackendId(setBackend);
330*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
331*89c4ff92SAndroid Build Coastguard Worker 
332*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
333*89c4ff92SAndroid Build Coastguard Worker     {
334*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker     // try to connect the Constant Inputs if there are any
338*89c4ff92SAndroid Build Coastguard Worker     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
339*89c4ff92SAndroid Build Coastguard Worker     {
340*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
341*89c4ff92SAndroid Build Coastguard Worker     }
342*89c4ff92SAndroid Build Coastguard Worker 
343*89c4ff92SAndroid Build Coastguard Worker     // Connect
344*89c4ff92SAndroid Build Coastguard Worker     return Connect(layer, tfLiteNode, delegateData);
345*89c4ff92SAndroid Build Coastguard Worker }
346*89c4ff92SAndroid Build Coastguard Worker 
347*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate