xref: /aosp_15_r20/external/armnn/delegate/classic/src/Split.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020,2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <algorithm>
11 #include <iterator>
12 #include <vector>
13 
14 namespace armnnDelegate
15 {
16 
17 constexpr unsigned int MaxNumOfTensorDimensions = 5U;
18 
VisitSplitOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitOperatorCode)19 TfLiteStatus VisitSplitOperator(DelegateData& delegateData,
20                                 TfLiteContext* tfLiteContext,
21                                 TfLiteNode* tfLiteNode,
22                                 int nodeIndex,
23                                 int32_t tfLiteSplitOperatorCode)
24 {
25     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
26 
27     auto* splitParameters = reinterpret_cast<TfLiteSplitParams*>(tfLiteNode->builtin_data);
28     const unsigned int numSplits =  NonNegative(splitParameters->num_splits, nodeIndex);
29 
30     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
31 
32     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
33     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
34     if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitOperatorCode, nodeIndex))
35     {
36         return kTfLiteError;
37     }
38 
39     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
40     if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitOperatorCode, nodeIndex))
41     {
42         return kTfLiteError;
43     }
44 
45     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
46 
47     ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
48     auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
49     std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
50     int32_t axis = axisTensorData[0];
51 
52     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
53     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
54     {
55         // Square bracket denotes inclusive n while parenthesis denotes exclusive n
56         // E.g. Rank 4 tensor can have axis in range [-4, 3)
57         // -1 == 3, -2 == 2, -3 == 1, -4 == 0
58         TF_LITE_MAYBE_KERNEL_LOG(
59                 tfLiteContext,
60                 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
61                 axis, nodeIndex);
62     }
63     const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
64 
65     std::vector<armnn::TensorInfo> outputs;
66     for (unsigned int i = 0; i < numSplits; ++i)
67     {
68         const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
69         if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitOperatorCode, nodeIndex))
70         {
71             return kTfLiteError;
72         }
73         outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
74     }
75     const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
76 
77     auto inputDimSize = inputTensorInfo.GetNumDimensions();
78     if (inputDimSize > MaxNumOfTensorDimensions)
79     {
80         TF_LITE_MAYBE_KERNEL_LOG(
81             tfLiteContext,
82             "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
83             "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
84         return kTfLiteError;
85     }
86 
87     std::vector<unsigned int> splitterDimSizes(inputDimSize);
88 
89     // Add current input shape to splitterDimSizes
90     for (unsigned int i = 0; i < inputDimSize; ++i)
91     {
92         splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
93     }
94 
95     if (splitterDimSizes[splitDim] % numSplits != 0)
96     {
97         TF_LITE_MAYBE_KERNEL_LOG(
98             tfLiteContext,
99             "TfLiteArmnnDelegate: Number of splits #%d must evenly divide the dimension #%d in node #%d: ",
100             numSplits, splitterDimSizes[splitDim], nodeIndex);
101         return kTfLiteError;
102     }
103     splitterDimSizes[splitDim] /= numSplits;
104 
105     armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
106     for (unsigned int j = 0; j < numSplits; ++j)
107     {
108         // Set the size of the views.
109         for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
110         {
111             splitDescriptor.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
112         }
113         splitDescriptor.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
114     }
115 
116     armnn::BackendId setBackend;
117     if (!delegateData.m_Network)
118     {
119         // Check if supported
120         bool isSupported = false;
121         FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
122                                    tfLiteContext,
123                                    IsSplitterSupported,
124                                    delegateData.m_Backends,
125                                    isSupported,
126                                    setBackend,
127                                    inputTensorInfo,
128                                    outputTensorInfos,
129                                    splitDescriptor);
130         return isSupported ? kTfLiteOk : kTfLiteError;
131     }
132 
133     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
134     layer->SetBackendId(setBackend);
135     ARMNN_ASSERT(layer != nullptr);
136 
137     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
138     {
139         layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
140     }
141 
142     // Connect the input slots
143     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[1]]->Connect(layer->GetInputSlot(0));
144 
145     // Prepare output slots
146     for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
147     {
148         armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
149         delegateData.m_OutputSlotForNode[
150             static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &outputSlot;
151     }
152 
153     return kTfLiteOk;
154 }
155 
VisitSplitVOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteSplitVOperatorCode)156 TfLiteStatus VisitSplitVOperator(DelegateData& delegateData,
157                                  TfLiteContext* tfLiteContext,
158                                  TfLiteNode* tfLiteNode,
159                                  int nodeIndex,
160                                  int32_t tfLiteSplitVOperatorCode)
161 {
162     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 3, nodeIndex));
163 
164     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
165     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
166     if (!IsValid(tfLiteContext, tfLiteInputTensor, tfLiteSplitVOperatorCode, nodeIndex))
167     {
168         return kTfLiteError;
169     }
170 
171     const TfLiteTensor& tfLiteSplitsTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
172     if (!IsValid(tfLiteContext, tfLiteSplitsTensor, tfLiteSplitVOperatorCode, nodeIndex))
173     {
174         return kTfLiteError;
175     }
176 
177     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[2]];
178     if (!IsValid(tfLiteContext, tfLiteAxisTensor, tfLiteSplitVOperatorCode, nodeIndex))
179     {
180         return kTfLiteError;
181     }
182 
183     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
184     const armnn::TensorInfo& splitsTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteSplitsTensor);
185     ARMNN_ASSERT(splitsTensorInfo.GetNumDimensions() == 1);
186     ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
187 
188     auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
189     std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
190     int32_t axis = axisTensorData[0];
191 
192     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
193     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
194     {
195         TF_LITE_MAYBE_KERNEL_LOG(
196                 tfLiteContext,
197                 "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
198                 axis, nodeIndex);
199     }
200     const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
201 
202     auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
203     unsigned int numSplits = 0;
204     if (splitVParameters)
205     {
206         numSplits = NonNegative(splitVParameters->num_splits, nodeIndex);
207     }
208     else
209     {
210         numSplits = splitsTensorInfo.GetNumElements();
211     }
212 
213     if (numSplits <= 0)
214     {
215         TF_LITE_MAYBE_KERNEL_LOG(
216             tfLiteContext, "TfLiteArmnnDelegate: Invalid number of splits %d  in node #%d",
217             numSplits, nodeIndex);
218         return kTfLiteError;
219     }
220 
221     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, numSplits, nodeIndex));
222     std::vector<armnn::TensorInfo> outputs;
223     for (unsigned int i = 0; i < numSplits; ++i)
224     {
225         const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
226         if (!IsValid(tfLiteContext, tfLiteOutputTensor, tfLiteSplitVOperatorCode, nodeIndex))
227         {
228             return kTfLiteError;
229         }
230         outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
231     }
232     const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
233 
234     auto inputDimSize = inputTensorInfo.GetNumDimensions();
235     if (inputDimSize > MaxNumOfTensorDimensions)
236     {
237         TF_LITE_MAYBE_KERNEL_LOG(
238             tfLiteContext,
239             "TfLiteArmnnDelegate: The number of dimensions: #%d for input tensors of the split op cannot be greater "
240             "than #%d in node #%d: ", inputDimSize, MaxNumOfTensorDimensions, nodeIndex);
241         return kTfLiteError;
242     }
243 
244     std::vector<int32_t> splitsTensorData(numSplits);
245     std::memcpy(splitsTensorData.data(), tfLiteSplitsTensor.data.data, splitsTensorInfo.GetNumBytes());
246 
247 
248     unsigned int index         = 0;
249     unsigned int inferredIndex = 0;
250     int numberOfInferred       = 0;
251     int splitSum = 0;
252 
253     for (auto splitData : splitsTensorData)
254     {
255         if (splitData < 0)
256         {
257             ++numberOfInferred;
258             inferredIndex = index;
259         }
260         else
261         {
262             splitSum += splitData;
263         }
264         ++index;
265     }
266 
267     // Check for inferred axis
268     if (numberOfInferred == 0)
269     {
270         if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
271         {
272             TF_LITE_MAYBE_KERNEL_LOG(
273                 tfLiteContext, "TfLiteArmnnDelegate: SplitV split_sizes does not sum to the dimension of value along"
274                                " split_dim in node #%d", nodeIndex);
275             return kTfLiteError;
276         }
277     }
278     else if (numberOfInferred == 1)
279     {
280         splitsTensorData[inferredIndex] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
281     }
282     else
283     {
284         TF_LITE_MAYBE_KERNEL_LOG(
285             tfLiteContext, "TfLiteArmnnDelegate: SplitV cannot infer split size for more than one split in node #%d",
286             nodeIndex);
287         return kTfLiteError;
288     }
289 
290     armnn::SplitterDescriptor splitDescriptor(numSplits, inputDimSize);
291     unsigned int accumSplit = 0;
292     for (unsigned int j = 0; j < numSplits; ++j)
293     {
294         unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsTensorData[j]);
295 
296         // Set the size of the views.
297         for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
298         {
299             unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
300             if (dimIdx == splitDim)
301             {
302                 dimSize = splitSize;
303             }
304             splitDescriptor.SetViewSize(j, dimIdx, dimSize);
305         }
306 
307         splitDescriptor.SetViewOriginCoord(j, splitDim, accumSplit);
308         accumSplit += splitSize;
309     }
310 
311     armnn::BackendId setBackend;
312     if (!delegateData.m_Network)
313     {
314         // Check if supported
315         bool isSupported = false;
316         FORWARD_LAYER_SUPPORT_FUNC("SPLIT",
317                                    tfLiteContext,
318                                    IsSplitterSupported,
319                                    delegateData.m_Backends,
320                                    isSupported,
321                                    setBackend,
322                                    inputTensorInfo,
323                                    outputTensorInfos,
324                                    splitDescriptor);
325         return isSupported ? kTfLiteOk : kTfLiteError;
326     }
327 
328     armnn::IConnectableLayer* layer = delegateData.m_Network->AddSplitterLayer(splitDescriptor);
329     layer->SetBackendId(setBackend);
330     ARMNN_ASSERT(layer != nullptr);
331 
332     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
333     {
334         layer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
335     }
336 
337     // try to connect the Constant Inputs if there are any
338     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
339     {
340         return kTfLiteError;
341     }
342 
343     // Connect
344     return Connect(layer, tfLiteNode, delegateData);
345 }
346 
347 } // namespace armnnDelegate