xref: /aosp_15_r20/external/armnn/delegate/classic/src/Unpack.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
VisitUnpackOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)21*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitUnpackOperator(DelegateData& delegateData,
22*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteContext* tfLiteContext,
23*89c4ff92SAndroid Build Coastguard Worker                                  TfLiteNode* tfLiteNode,
24*89c4ff92SAndroid Build Coastguard Worker                                  int nodeIndex,
25*89c4ff92SAndroid Build Coastguard Worker                                  int32_t operatorCode)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30*89c4ff92SAndroid Build Coastguard Worker     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
35*89c4ff92SAndroid Build Coastguard Worker     }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker     // Get Unpack Axis
38*89c4ff92SAndroid Build Coastguard Worker     const auto params = reinterpret_cast<TfLiteUnpackParams*>(tfLiteNode->builtin_data);
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     const unsigned int unpackAxis = NonNegative(params->axis, nodeIndex);
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     if (unpackAxis >= inputTensorInfo.GetNumDimensions())
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
47*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
48*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: The unpack axis #%d cannot be greater than or equal to "
49*89c4ff92SAndroid Build Coastguard Worker             "the number of input dimensions #%d in operator #%d node #%d",
50*89c4ff92SAndroid Build Coastguard Worker             unpackAxis, inputTensorInfo.GetNumDimensions(), operatorCode, nodeIndex);
51*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
52*89c4ff92SAndroid Build Coastguard Worker     }
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     // Get Unpack Num
55*89c4ff92SAndroid Build Coastguard Worker     unsigned int unpackNum = NonNegative(params->num, nodeIndex);
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     // If num is not defined, automatically infer from the length of the dimension axis.
58*89c4ff92SAndroid Build Coastguard Worker     if(unpackNum == 0)
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker         unpackNum = inputTensorInfo.GetShape()[unpackAxis];
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     // If unpack number cannot be inferred and is still zero, return kTfLiteError.
64*89c4ff92SAndroid Build Coastguard Worker     if(unpackNum == 0)
65*89c4ff92SAndroid Build Coastguard Worker     {
66*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
67*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
68*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Number to unpack must greater than zero in operator #%d node #%d: ",
69*89c4ff92SAndroid Build Coastguard Worker             operatorCode, nodeIndex);
70*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
71*89c4ff92SAndroid Build Coastguard Worker     }
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker     // Check outputs
74*89c4ff92SAndroid Build Coastguard Worker     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, unpackNum, nodeIndex));
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
78*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> unpackDimSizes(inputDimSize);
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Add current input shape to unpackDimSizes
81*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputDimSize; ++i)
82*89c4ff92SAndroid Build Coastguard Worker     {
83*89c4ff92SAndroid Build Coastguard Worker         unpackDimSizes[i] = inputTensorInfo.GetShape()[i];
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     if (unpackDimSizes[unpackAxis] != unpackNum)
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_MAYBE_KERNEL_LOG(
89*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext,
90*89c4ff92SAndroid Build Coastguard Worker             "TfLiteArmnnDelegate: Number to unpack must be the same as length "
91*89c4ff92SAndroid Build Coastguard Worker             "of the dimension to unpack along in operator #%d node #%d: ",
92*89c4ff92SAndroid Build Coastguard Worker             operatorCode, nodeIndex);
93*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     unpackDimSizes[unpackAxis] /= unpackNum;
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     armnn::SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size()));
99*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < unpackNum; ++j)
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
102*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx)
103*89c4ff92SAndroid Build Coastguard Worker         {
104*89c4ff92SAndroid Build Coastguard Worker             splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]);
105*89c4ff92SAndroid Build Coastguard Worker         }
106*89c4ff92SAndroid Build Coastguard Worker         splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j);
107*89c4ff92SAndroid Build Coastguard Worker     }
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> outputs;
110*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < unpackNum; ++i)
111*89c4ff92SAndroid Build Coastguard Worker     {
112*89c4ff92SAndroid Build Coastguard Worker         const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
113*89c4ff92SAndroid Build Coastguard Worker         if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
114*89c4ff92SAndroid Build Coastguard Worker         {
115*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
116*89c4ff92SAndroid Build Coastguard Worker         }
117*89c4ff92SAndroid Build Coastguard Worker         outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
118*89c4ff92SAndroid Build Coastguard Worker     }
119*89c4ff92SAndroid Build Coastguard Worker     const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     // Determine the shape of the Splitter layer outputs for validation
122*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape splitOutShape = armnn::TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
123*89c4ff92SAndroid Build Coastguard Worker                                                           unpackDimSizes.data());
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> splitterOutputs;
126*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < outputTensorInfos.size(); ++outputIndex)
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         splitterOutputs.push_back(armnn::TensorInfo(splitOutShape,
129*89c4ff92SAndroid Build Coastguard Worker                                                     outputTensorInfos[outputIndex].get().GetDataType(),
130*89c4ff92SAndroid Build Coastguard Worker                                                     outputTensorInfos[outputIndex].get().GetQuantizationScale(),
131*89c4ff92SAndroid Build Coastguard Worker                                                     outputTensorInfos[outputIndex].get().GetQuantizationOffset()));
132*89c4ff92SAndroid Build Coastguard Worker     }
133*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::reference_wrapper<armnn::TensorInfo>> splitterOutputTensorInfos(splitterOutputs.begin(),
134*89c4ff92SAndroid Build Coastguard Worker                                                                                      splitterOutputs.end());
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackendSplit;
137*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
138*89c4ff92SAndroid Build Coastguard Worker     {
139*89c4ff92SAndroid Build Coastguard Worker         // Check if splitter is supported
140*89c4ff92SAndroid Build Coastguard Worker         bool isSupported = false;
141*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("UNPACK",
142*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
143*89c4ff92SAndroid Build Coastguard Worker                                    IsSplitterSupported,
144*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
145*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
146*89c4ff92SAndroid Build Coastguard Worker                                    setBackendSplit,
147*89c4ff92SAndroid Build Coastguard Worker                                    inputTensorInfo,
148*89c4ff92SAndroid Build Coastguard Worker                                    splitterOutputTensorInfos,
149*89c4ff92SAndroid Build Coastguard Worker                                    splitDesc);
150*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
151*89c4ff92SAndroid Build Coastguard Worker     }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker     // Create Reshape descriptor from the first outputTensorInfo to validate a single Reshape layer
154*89c4ff92SAndroid Build Coastguard Worker     // Use this descriptor later when creating every ReshapeLayer as all Reshape Layers should be the same
155*89c4ff92SAndroid Build Coastguard Worker     armnn::ReshapeDescriptor reshapeDescriptor;
156*89c4ff92SAndroid Build Coastguard Worker     reshapeDescriptor.m_TargetShape = outputTensorInfos[0].get().GetShape();
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackendReshape;
159*89c4ff92SAndroid Build Coastguard Worker     if (!delegateData.m_Network)
160*89c4ff92SAndroid Build Coastguard Worker     {
161*89c4ff92SAndroid Build Coastguard Worker         bool isSupported = false;
162*89c4ff92SAndroid Build Coastguard Worker         FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
163*89c4ff92SAndroid Build Coastguard Worker                                    tfLiteContext,
164*89c4ff92SAndroid Build Coastguard Worker                                    IsReshapeSupported,
165*89c4ff92SAndroid Build Coastguard Worker                                    delegateData.m_Backends,
166*89c4ff92SAndroid Build Coastguard Worker                                    isSupported,
167*89c4ff92SAndroid Build Coastguard Worker                                    setBackendReshape,
168*89c4ff92SAndroid Build Coastguard Worker                                    splitterOutputTensorInfos[0],
169*89c4ff92SAndroid Build Coastguard Worker                                    outputTensorInfos[0],
170*89c4ff92SAndroid Build Coastguard Worker                                    reshapeDescriptor);
171*89c4ff92SAndroid Build Coastguard Worker         return isSupported ? kTfLiteOk : kTfLiteError;
172*89c4ff92SAndroid Build Coastguard Worker     };
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     std::string splitterLayerName("Unpack Splitter");
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* splitterLayer = delegateData.m_Network->AddSplitterLayer(splitDesc,
177*89c4ff92SAndroid Build Coastguard Worker                                                                                        splitterLayerName.c_str());
178*89c4ff92SAndroid Build Coastguard Worker     splitterLayer->SetBackendId(setBackendSplit);
179*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(splitterLayer != nullptr);
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < splitterLayer->GetNumOutputSlots(); ++k)
182*89c4ff92SAndroid Build Coastguard Worker     {
183*89c4ff92SAndroid Build Coastguard Worker         splitterLayer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     // Connect the input slots
187*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(splitterLayer->GetInputSlot(0));
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker     // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
190*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < splitterLayer->GetNumOutputSlots(); ++outputIndex)
191*89c4ff92SAndroid Build Coastguard Worker     {
192*89c4ff92SAndroid Build Coastguard Worker         std::string reshapeLayerName("Unpack Reshape");
193*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor,
194*89c4ff92SAndroid Build Coastguard Worker                                                                                          reshapeLayerName.c_str());
195*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer->SetBackendId(setBackendReshape);
196*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(reshapeLayer != nullptr);
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker         splitterLayer->GetOutputSlot(outputIndex).SetTensorInfo(splitterOutputTensorInfos[outputIndex]);
199*89c4ff92SAndroid Build Coastguard Worker         splitterLayer->GetOutputSlot(outputIndex).Connect(reshapeLayer->GetInputSlot(0));
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTensorInfo  = outputTensorInfos[outputIndex];
202*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& slot = reshapeLayer->GetOutputSlot(0);
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker         delegateData.m_OutputSlotForNode[
207*89c4ff92SAndroid Build Coastguard Worker             static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &slot;
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     }
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
215