1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
11*89c4ff92SAndroid Build Coastguard Worker #include <iterator>
12*89c4ff92SAndroid Build Coastguard Worker #include <vector>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker
/// Upper bound on the rank of the input tensor accepted by the split visitors in this file;
/// inputs with more dimensions are rejected with kTfLiteError.
constexpr unsigned int MaxNumOfTensorDimensions{5U};
18*89c4ff92SAndroid Build Coastguard Worker
VisitSplitOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitOperatorCode)19*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
21*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
22*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
23*89c4ff92SAndroid Build Coastguard Worker int32_t tfLiteSplitOperatorCode)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28*89c4ff92SAndroid Build Coastguard Worker const unsigned int numSplits = NonNegative(splitParameters->num_splits, nodeIndex);
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
43*89c4ff92SAndroid Build Coastguard Worker }
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
48*89c4ff92SAndroid Build Coastguard Worker auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
50*89c4ff92SAndroid Build Coastguard Worker int32_t axis = axisTensorData[0];
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
53*89c4ff92SAndroid Build Coastguard Worker if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker // Square bracket denotes inclusive n while parenthesis denotes exclusive n
56*89c4ff92SAndroid Build Coastguard Worker // E.g. Rank 4 tensor can have axis in range [-4, 3)
57*89c4ff92SAndroid Build Coastguard Worker // -1 == 3, -2 == 2, -3 == 1, -4 == 0
58*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
59*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
60*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
61*89c4ff92SAndroid Build Coastguard Worker axis, nodeIndex);
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> outputs;
66*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numSplits; ++i)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
69*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker auto inputDimSize = inputTensorInfo.GetNumDimensions();
78*89c4ff92SAndroid Build Coastguard Worker if (inputDimSize > MaxNumOfTensorDimensions)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
81*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
82*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
83*89c4ff92SAndroid Build Coastguard Worker "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
84*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> splitterDimSizes(inputDimSize);
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker // Add current input shape to splitterDimSizes
90*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputDimSize; ++i)
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker if (splitterDimSizes[splitDim] % numSplits != 0)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
98*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
99*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
100*89c4ff92SAndroid Build Coastguard Worker numSplits, splitterDimSizes[splitDim], nodeIndex);
101*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker splitterDimSizes[splitDim] /= numSplits;
104*89c4ff92SAndroid Build Coastguard Worker
105*89c4ff92SAndroid Build Coastguard Worker armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
106*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < numSplits; ++j)
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker // Set the size of the views.
109*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
117*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker // Check if supported
120*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
121*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
122*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
123*89c4ff92SAndroid Build Coastguard Worker IsSplitterSupported,
124*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
125*89c4ff92SAndroid Build Coastguard Worker isSupported,
126*89c4ff92SAndroid Build Coastguard Worker setBackend,
127*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
128*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos,
129*89c4ff92SAndroid Build Coastguard Worker splitDescriptor);
130*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
134*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
135*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker
142*89c4ff92SAndroid Build Coastguard Worker // Connect the input slots
143*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker // Prepare output slots
146*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
149*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[
150*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
154*89c4ff92SAndroid Build Coastguard Worker }
155*89c4ff92SAndroid Build Coastguard Worker
VisitSplitVOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitVOperatorCode)156*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
157*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
158*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
159*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
160*89c4ff92SAndroid Build Coastguard Worker int32_t tfLiteSplitVOperatorCode)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
165*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
166*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
169*89c4ff92SAndroid Build Coastguard Worker }
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
172*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
178*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
184*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
185*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
186*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
189*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
190*89c4ff92SAndroid Build Coastguard Worker int32_t axis = axisTensorData[0];
191*89c4ff92SAndroid Build Coastguard Worker
192*89c4ff92SAndroid Build Coastguard Worker auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
193*89c4ff92SAndroid Build Coastguard Worker if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
196*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
197*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
198*89c4ff92SAndroid Build Coastguard Worker axis, nodeIndex);
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
203*89c4ff92SAndroid Build Coastguard Worker unsigned int numSplits = 0;
204*89c4ff92SAndroid Build Coastguard Worker if (splitVParameters)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
207*89c4ff92SAndroid Build Coastguard Worker }
208*89c4ff92SAndroid Build Coastguard Worker else
209*89c4ff92SAndroid Build Coastguard Worker {
210*89c4ff92SAndroid Build Coastguard Worker numSplits = splitsTensorInfo.GetNumElements();
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker
213*89c4ff92SAndroid Build Coastguard Worker if (numSplits <= 0)
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
216*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d in node #%d",
217*89c4ff92SAndroid Build Coastguard Worker numSplits, nodeIndex);
218*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker
221*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
222*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> outputs;
223*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numSplits; ++i)
224*89c4ff92SAndroid Build Coastguard Worker {
225*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
226*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
231*89c4ff92SAndroid Build Coastguard Worker }
232*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
233*89c4ff92SAndroid Build Coastguard Worker
234*89c4ff92SAndroid Build Coastguard Worker auto inputDimSize = inputTensorInfo.GetNumDimensions();
235*89c4ff92SAndroid Build Coastguard Worker if (inputDimSize > MaxNumOfTensorDimensions)
236*89c4ff92SAndroid Build Coastguard Worker {
237*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
238*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
239*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
240*89c4ff92SAndroid Build Coastguard Worker "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
241*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
242*89c4ff92SAndroid Build Coastguard Worker }
243*89c4ff92SAndroid Build Coastguard Worker
244*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> splitsTensorData(numSplits);
245*89c4ff92SAndroid Build Coastguard Worker std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker
248*89c4ff92SAndroid Build Coastguard Worker unsigned int index = 0;
249*89c4ff92SAndroid Build Coastguard Worker unsigned int inferredIndex = 0;
250*89c4ff92SAndroid Build Coastguard Worker int numberOfInferred = 0;
251*89c4ff92SAndroid Build Coastguard Worker int splitSum = 0;
252*89c4ff92SAndroid Build Coastguard Worker
253*89c4ff92SAndroid Build Coastguard Worker for (auto splitData : splitsTensorData)
254*89c4ff92SAndroid Build Coastguard Worker {
255*89c4ff92SAndroid Build Coastguard Worker if (splitData < 0)
256*89c4ff92SAndroid Build Coastguard Worker {
257*89c4ff92SAndroid Build Coastguard Worker ++numberOfInferred;
258*89c4ff92SAndroid Build Coastguard Worker inferredIndex = index;
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker else
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker splitSum += splitData;
263*89c4ff92SAndroid Build Coastguard Worker }
264*89c4ff92SAndroid Build Coastguard Worker ++index;
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker // Check for inferred axis
268*89c4ff92SAndroid Build Coastguard Worker if (numberOfInferred == 0)
269*89c4ff92SAndroid Build Coastguard Worker {
270*89c4ff92SAndroid Build Coastguard Worker if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
273*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
274*89c4ff92SAndroid Build Coastguard Worker " split_dim in node #%d", nodeIndex);
275*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker }
278*89c4ff92SAndroid Build Coastguard Worker else if (numberOfInferred == 1)
279*89c4ff92SAndroid Build Coastguard Worker {
280*89c4ff92SAndroid Build Coastguard Worker splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
281*89c4ff92SAndroid Build Coastguard Worker }
282*89c4ff92SAndroid Build Coastguard Worker else
283*89c4ff92SAndroid Build Coastguard Worker {
284*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
285*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
286*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
287*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
288*89c4ff92SAndroid Build Coastguard Worker }
289*89c4ff92SAndroid Build Coastguard Worker
290*89c4ff92SAndroid Build Coastguard Worker armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
291*89c4ff92SAndroid Build Coastguard Worker unsigned int accumSplit = 0;
292*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < numSplits; ++j)
293*89c4ff92SAndroid Build Coastguard Worker {
294*89c4ff92SAndroid Build Coastguard Worker unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
295*89c4ff92SAndroid Build Coastguard Worker
296*89c4ff92SAndroid Build Coastguard Worker // Set the size of the views.
297*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
298*89c4ff92SAndroid Build Coastguard Worker {
299*89c4ff92SAndroid Build Coastguard Worker unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
300*89c4ff92SAndroid Build Coastguard Worker if (dimIdx == splitDim)
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker dimSize = splitSize;
303*89c4ff92SAndroid Build Coastguard Worker }
304*89c4ff92SAndroid Build Coastguard Worker splitDescriptor.SetViewSize(j, dimIdx, dimSize);
305*89c4ff92SAndroid Build Coastguard Worker }
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
308*89c4ff92SAndroid Build Coastguard Worker accumSplit += splitSize;
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker
311*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
312*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
313*89c4ff92SAndroid Build Coastguard Worker {
314*89c4ff92SAndroid Build Coastguard Worker // Check if supported
315*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
316*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
317*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
318*89c4ff92SAndroid Build Coastguard Worker IsSplitterSupported,
319*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
320*89c4ff92SAndroid Build Coastguard Worker isSupported,
321*89c4ff92SAndroid Build Coastguard Worker setBackend,
322*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
323*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos,
324*89c4ff92SAndroid Build Coastguard Worker splitDescriptor);
325*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
326*89c4ff92SAndroid Build Coastguard Worker }
327*89c4ff92SAndroid Build Coastguard Worker
328*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
329*89c4ff92SAndroid Build Coastguard Worker layer->SetBackendId(setBackend);
330*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
331*89c4ff92SAndroid Build Coastguard Worker
332*89c4ff92SAndroid Build Coastguard Worker for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
333*89c4ff92SAndroid Build Coastguard Worker {
334*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
335*89c4ff92SAndroid Build Coastguard Worker }
336*89c4ff92SAndroid Build Coastguard Worker
337*89c4ff92SAndroid Build Coastguard Worker // try to connect the Constant Inputs if there are any
338*89c4ff92SAndroid Build Coastguard Worker if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
339*89c4ff92SAndroid Build Coastguard Worker {
340*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
341*89c4ff92SAndroid Build Coastguard Worker }
342*89c4ff92SAndroid Build Coastguard Worker
343*89c4ff92SAndroid Build Coastguard Worker // Connect
344*89c4ff92SAndroid Build Coastguard Worker return Connect(layer, tfLiteNode, delegateData);
345*89c4ff92SAndroid Build Coastguard Worker }
346*89c4ff92SAndroid Build Coastguard Worker
347*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate