1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <ClassicDelegateUtils.hpp>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker
VisitUnpackOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)21*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus VisitUnpackOperator(DelegateData& delegateData,
22*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
23*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
24*89c4ff92SAndroid Build Coastguard Worker int nodeIndex,
25*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker // Get Unpack Axis
38*89c4ff92SAndroid Build Coastguard Worker const auto params = reinterpret_cast<TfLiteUnpackParams*>(tfLiteNode->builtin_data);
39*89c4ff92SAndroid Build Coastguard Worker
40*89c4ff92SAndroid Build Coastguard Worker const unsigned int unpackAxis = NonNegative(params->axis, nodeIndex);
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker if (unpackAxis >= inputTensorInfo.GetNumDimensions())
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
47*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
48*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: The unpack axis #%d cannot be greater than or equal to "
49*89c4ff92SAndroid Build Coastguard Worker "the number of input dimensions #%d in operator #%d node #%d",
50*89c4ff92SAndroid Build Coastguard Worker unpackAxis, inputTensorInfo.GetNumDimensions(), operatorCode, nodeIndex);
51*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker // Get Unpack Num
55*89c4ff92SAndroid Build Coastguard Worker unsigned int unpackNum = NonNegative(params->num, nodeIndex);
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker // If num is not defined, automatically infer from the length of the dimension axis.
58*89c4ff92SAndroid Build Coastguard Worker if(unpackNum == 0)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker unpackNum = inputTensorInfo.GetShape()[unpackAxis];
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker // If unpack number cannot be inferred and is still zero, return kTfLiteError.
64*89c4ff92SAndroid Build Coastguard Worker if(unpackNum == 0)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
67*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
68*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Number to unpack must greater than zero in operator #%d node #%d: ",
69*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
70*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker // Check outputs
74*89c4ff92SAndroid Build Coastguard Worker TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, unpackNum, nodeIndex));
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker auto inputDimSize = inputTensorInfo.GetNumDimensions();
78*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> unpackDimSizes(inputDimSize);
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker // Add current input shape to unpackDimSizes
81*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputDimSize; ++i)
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker unpackDimSizes[i] = inputTensorInfo.GetShape()[i];
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker if (unpackDimSizes[unpackAxis] != unpackNum)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker TF_LITE_MAYBE_KERNEL_LOG(
89*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
90*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Number to unpack must be the same as length "
91*89c4ff92SAndroid Build Coastguard Worker "of the dimension to unpack along in operator #%d node #%d: ",
92*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
93*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker unpackDimSizes[unpackAxis] /= unpackNum;
97*89c4ff92SAndroid Build Coastguard Worker
98*89c4ff92SAndroid Build Coastguard Worker armnn::SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size()));
99*89c4ff92SAndroid Build Coastguard Worker for (unsigned int j = 0; j < unpackNum; ++j)
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker // Set the size of the views.
102*89c4ff92SAndroid Build Coastguard Worker for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx)
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]);
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j);
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> outputs;
110*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < unpackNum; ++i)
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[i]];
113*89c4ff92SAndroid Build Coastguard Worker if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker outputs.push_back(GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true));
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker const std::vector<std::reference_wrapper<armnn::TensorInfo>> outputTensorInfos(outputs.begin(), outputs.end());
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker // Determine the shape of the Splitter layer outputs for validation
122*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape splitOutShape = armnn::TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
123*89c4ff92SAndroid Build Coastguard Worker unpackDimSizes.data());
124*89c4ff92SAndroid Build Coastguard Worker
125*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorInfo> splitterOutputs;
126*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < outputTensorInfos.size(); ++outputIndex)
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker splitterOutputs.push_back(armnn::TensorInfo(splitOutShape,
129*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos[outputIndex].get().GetDataType(),
130*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos[outputIndex].get().GetQuantizationScale(),
131*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos[outputIndex].get().GetQuantizationOffset()));
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker std::vector<std::reference_wrapper<armnn::TensorInfo>> splitterOutputTensorInfos(splitterOutputs.begin(),
134*89c4ff92SAndroid Build Coastguard Worker splitterOutputs.end());
135*89c4ff92SAndroid Build Coastguard Worker
136*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackendSplit;
137*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker // Check if splitter is supported
140*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
141*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("UNPACK",
142*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
143*89c4ff92SAndroid Build Coastguard Worker IsSplitterSupported,
144*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
145*89c4ff92SAndroid Build Coastguard Worker isSupported,
146*89c4ff92SAndroid Build Coastguard Worker setBackendSplit,
147*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo,
148*89c4ff92SAndroid Build Coastguard Worker splitterOutputTensorInfos,
149*89c4ff92SAndroid Build Coastguard Worker splitDesc);
150*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker // Create Reshape descriptor from the first outputTensorInfo to validate a single Reshape layer
154*89c4ff92SAndroid Build Coastguard Worker // Use this descriptor later when creating every ReshapeLayer as all Reshape Layers should be the same
155*89c4ff92SAndroid Build Coastguard Worker armnn::ReshapeDescriptor reshapeDescriptor;
156*89c4ff92SAndroid Build Coastguard Worker reshapeDescriptor.m_TargetShape = outputTensorInfos[0].get().GetShape();
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackendReshape;
159*89c4ff92SAndroid Build Coastguard Worker if (!delegateData.m_Network)
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
162*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC("RESHAPE",
163*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
164*89c4ff92SAndroid Build Coastguard Worker IsReshapeSupported,
165*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
166*89c4ff92SAndroid Build Coastguard Worker isSupported,
167*89c4ff92SAndroid Build Coastguard Worker setBackendReshape,
168*89c4ff92SAndroid Build Coastguard Worker splitterOutputTensorInfos[0],
169*89c4ff92SAndroid Build Coastguard Worker outputTensorInfos[0],
170*89c4ff92SAndroid Build Coastguard Worker reshapeDescriptor);
171*89c4ff92SAndroid Build Coastguard Worker return isSupported ? kTfLiteOk : kTfLiteError;
172*89c4ff92SAndroid Build Coastguard Worker };
173*89c4ff92SAndroid Build Coastguard Worker
174*89c4ff92SAndroid Build Coastguard Worker std::string splitterLayerName("Unpack Splitter");
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* splitterLayer = delegateData.m_Network->AddSplitterLayer(splitDesc,
177*89c4ff92SAndroid Build Coastguard Worker splitterLayerName.c_str());
178*89c4ff92SAndroid Build Coastguard Worker splitterLayer->SetBackendId(setBackendSplit);
179*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(splitterLayer != nullptr);
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker for (unsigned int k = 0; k < splitterLayer->GetNumOutputSlots(); ++k)
182*89c4ff92SAndroid Build Coastguard Worker {
183*89c4ff92SAndroid Build Coastguard Worker splitterLayer->GetOutputSlot(k).SetTensorInfo(outputs[k]);
184*89c4ff92SAndroid Build Coastguard Worker }
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker // Connect the input slots
187*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[0]]->Connect(splitterLayer->GetInputSlot(0));
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
190*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < splitterLayer->GetNumOutputSlots(); ++outputIndex)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker std::string reshapeLayerName("Unpack Reshape");
193*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* reshapeLayer = delegateData.m_Network->AddReshapeLayer(reshapeDescriptor,
194*89c4ff92SAndroid Build Coastguard Worker reshapeLayerName.c_str());
195*89c4ff92SAndroid Build Coastguard Worker reshapeLayer->SetBackendId(setBackendReshape);
196*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(reshapeLayer != nullptr);
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker splitterLayer->GetOutputSlot(outputIndex).SetTensorInfo(splitterOutputTensorInfos[outputIndex]);
199*89c4ff92SAndroid Build Coastguard Worker splitterLayer->GetOutputSlot(outputIndex).Connect(reshapeLayer->GetInputSlot(0));
200*89c4ff92SAndroid Build Coastguard Worker
201*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = outputTensorInfos[outputIndex];
202*89c4ff92SAndroid Build Coastguard Worker reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
203*89c4ff92SAndroid Build Coastguard Worker
204*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& slot = reshapeLayer->GetOutputSlot(0);
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[
207*89c4ff92SAndroid Build Coastguard Worker static_cast<unsigned long>(tfLiteNode->outputs->data[outputIndex])] = &slot;
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
212*89c4ff92SAndroid Build Coastguard Worker }
213*89c4ff92SAndroid Build Coastguard Worker
214*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
215