xref: /aosp_15_r20/external/armnn/delegate/opaque/src/OpaqueDelegateUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <DelegateUtils.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
20*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
21*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
22*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/c_api_opaque.h>
23*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
24*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/kernel_util.h>
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker namespace
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker 
// Macro to call an Is<layer_name>Supported function on each backend in turn and log the caller name together with
// the reason for lack of support.
//
// Parameters:
//   opName        - C string naming the operator; only used in log messages.
//   tfLiteContext - TfLiteOpaqueContext* used for kernel logging.
//   func          - name of the layer-support member function to call (e.g. IsActivationSupported).
//   backends      - iterable of armnn::BackendId, tried in order.
//   supported     - bool lvalue; assigned the result of each registered backend queried, so it ends up true as soon
//                   as one backend accepts the layer. It is not assigned when no listed backend is registered, so
//                   callers should initialise it to false.
//   setBackend    - assignable armnn::BackendId expression; receives the first backend that supports the layer.
//   ...           - arguments forwarded to func, ahead of the trailing reasonIfUnsupported argument.
//
// The expansion is wrapped in do { ... } while (0) so that `FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(...);` is exactly one
// statement and composes safely with unbraced if/else. An armnn::InvalidArgumentException thrown while checking
// support is rethrown with additional context.
#define FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(opName, tfLiteContext, func, backends, supported, setBackend, ...) \
do \
{ \
    try \
    { \
        for (auto&& backendId : backends) \
        { \
            auto layerSupportObject = armnn::GetILayerSupportByBackendId(backendId); \
            if (layerSupportObject.IsBackendRegistered()) \
            { \
                std::string reasonIfUnsupported; \
                supported = \
                    layerSupportObject.func(__VA_ARGS__, armnn::Optional<std::string&>(reasonIfUnsupported)); \
                if (supported) \
                { \
                    setBackend = backendId; \
                    break; \
                } \
                else \
                { \
                    if (reasonIfUnsupported.size() > 0) \
                    { \
                        TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING, \
                                        "%s: not supported by armnn: %s", opName, reasonIfUnsupported.c_str()); \
                    } \
                    else \
                    { \
                        TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING, \
                                        "%s: not supported by armnn", opName); \
                    } \
                } \
            } \
            else \
            { \
                TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "%s: backend not registered: %s", \
                                          opName, backendId.Get().c_str()); \
            } \
        } \
        if (!supported) \
        { \
            TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "%s: not supported by any specified backend", opName); \
        } \
    } \
    catch (const armnn::InvalidArgumentException &e) \
    { \
        throw armnn::InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \
    } \
} while (0)
75*89c4ff92SAndroid Build Coastguard Worker 
ValidateNumInputs(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,const unsigned int expectedSize,int nodeIndex)76*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateNumInputs(TfLiteOpaqueContext* tfLiteContext,
77*89c4ff92SAndroid Build Coastguard Worker                                TfLiteOpaqueNode* tfLiteNode,
78*89c4ff92SAndroid Build Coastguard Worker                                const unsigned int expectedSize,
79*89c4ff92SAndroid Build Coastguard Worker                                int nodeIndex)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     int numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
82*89c4ff92SAndroid Build Coastguard Worker     if (static_cast<unsigned int>(numInputs) != expectedSize)
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
85*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unexpected number of inputs (%d != %d) in node #%d",
86*89c4ff92SAndroid Build Coastguard Worker                 numInputs, expectedSize, nodeIndex);
87*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker 
ValidateNumOutputs(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,const unsigned int expectedSize,int nodeIndex)92*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateNumOutputs(TfLiteOpaqueContext* tfLiteContext,
93*89c4ff92SAndroid Build Coastguard Worker                                 TfLiteOpaqueNode* tfLiteNode,
94*89c4ff92SAndroid Build Coastguard Worker                                 const unsigned int expectedSize,
95*89c4ff92SAndroid Build Coastguard Worker                                 int nodeIndex)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker     auto numOutputs = TfLiteOpaqueNodeNumberOfOutputs(tfLiteNode);
98*89c4ff92SAndroid Build Coastguard Worker     if (static_cast<unsigned int>(numOutputs) != expectedSize)
99*89c4ff92SAndroid Build Coastguard Worker     {
100*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
101*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unexpected number of outputs (%d != %d) in node #%d",
102*89c4ff92SAndroid Build Coastguard Worker                 numOutputs, expectedSize, nodeIndex);
103*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
IsConstantTensor(const TfLiteOpaqueTensor * tfLiteTensor)108*89c4ff92SAndroid Build Coastguard Worker bool IsConstantTensor(const TfLiteOpaqueTensor* tfLiteTensor)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     auto tensorAllocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);
111*89c4ff92SAndroid Build Coastguard Worker     if (tensorAllocationType == kTfLiteMmapRo)
112*89c4ff92SAndroid Build Coastguard Worker     {
113*89c4ff92SAndroid Build Coastguard Worker         return true;
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker     return false;
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
IsDynamicTensor(const TfLiteOpaqueTensor * tfLiteTensor)118*89c4ff92SAndroid Build Coastguard Worker bool IsDynamicTensor(const TfLiteOpaqueTensor* tfLiteTensor)
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker     auto tensorAllocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);
121*89c4ff92SAndroid Build Coastguard Worker     if (tensorAllocationType == kTfLiteDynamic)
122*89c4ff92SAndroid Build Coastguard Worker     {
123*89c4ff92SAndroid Build Coastguard Worker         return true;
124*89c4ff92SAndroid Build Coastguard Worker     }
125*89c4ff92SAndroid Build Coastguard Worker     return false;
126*89c4ff92SAndroid Build Coastguard Worker }
127*89c4ff92SAndroid Build Coastguard Worker 
// A tensor pointer is considered valid exactly when it is non-null.
bool IsValid(const TfLiteOpaqueTensor* tfLiteTensor)
{
    return tfLiteTensor != nullptr;
}
132*89c4ff92SAndroid Build Coastguard Worker 
IsValid(TfLiteOpaqueContext * tfLiteContext,const TfLiteOpaqueTensor * tfLiteTensor,int32_t operatorCode,int32_t nodeIndex)133*89c4ff92SAndroid Build Coastguard Worker bool IsValid(TfLiteOpaqueContext* tfLiteContext,
134*89c4ff92SAndroid Build Coastguard Worker              const TfLiteOpaqueTensor* tfLiteTensor,
135*89c4ff92SAndroid Build Coastguard Worker              int32_t operatorCode,
136*89c4ff92SAndroid Build Coastguard Worker              int32_t nodeIndex)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker     if(!IsValid(tfLiteTensor))
139*89c4ff92SAndroid Build Coastguard Worker     {
140*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
141*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext,
142*89c4ff92SAndroid Build Coastguard Worker                 "TfLiteArmnnDelegate: Invalid TfLite tensor in operator #%d node #%d: ",
143*89c4ff92SAndroid Build Coastguard Worker                 operatorCode, nodeIndex);
144*89c4ff92SAndroid Build Coastguard Worker         return false;
145*89c4ff92SAndroid Build Coastguard Worker     }
146*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamicTensor(tfLiteTensor))
147*89c4ff92SAndroid Build Coastguard Worker     {
148*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
149*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext,
150*89c4ff92SAndroid Build Coastguard Worker                 "TfLiteArmnnDelegate: Dynamic tensors are not supported in operator #%d node #%d: ",
151*89c4ff92SAndroid Build Coastguard Worker                 operatorCode, nodeIndex);
152*89c4ff92SAndroid Build Coastguard Worker         return false;
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker     return true;
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker 
IsAffineQuantization(const TfLiteOpaqueTensor & tfLiteTensor)157*89c4ff92SAndroid Build Coastguard Worker bool IsAffineQuantization(const TfLiteOpaqueTensor& tfLiteTensor)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker     auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(&tfLiteTensor);
160*89c4ff92SAndroid Build Coastguard Worker     if (quantizationInfo.type == kTfLiteAffineQuantization)
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         return true;
163*89c4ff92SAndroid Build Coastguard Worker     }
164*89c4ff92SAndroid Build Coastguard Worker     return false;
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker // Connects the layer to the graph
Connect(armnn::IConnectableLayer * layer,TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,armnnOpaqueDelegate::DelegateData & data)168*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Connect(armnn::IConnectableLayer* layer,
169*89c4ff92SAndroid Build Coastguard Worker                      TfLiteOpaqueContext* tfLiteContext,
170*89c4ff92SAndroid Build Coastguard Worker                      TfLiteOpaqueNode* tfLiteNode,
171*89c4ff92SAndroid Build Coastguard Worker                      armnnOpaqueDelegate::DelegateData& data)
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker     // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
174*89c4ff92SAndroid Build Coastguard Worker     // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
175*89c4ff92SAndroid Build Coastguard Worker     // tensors for each input slot in the node.
176*89c4ff92SAndroid Build Coastguard Worker     const int* inputIndexArray;
177*89c4ff92SAndroid Build Coastguard Worker     int numInputs;
178*89c4ff92SAndroid Build Coastguard Worker     if(TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs) != kTfLiteOk)
179*89c4ff92SAndroid Build Coastguard Worker     {
180*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
181*89c4ff92SAndroid Build Coastguard Worker     }
182*89c4ff92SAndroid Build Coastguard Worker     // numInputs is set from TfLiteOpaqueNodeInputs.
183*89c4ff92SAndroid Build Coastguard Worker     if(numInputs != static_cast<int>(layer->GetNumInputSlots()))
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Layer: " << layer->GetName() << ": Expected number of input slots does not match actual "
186*89c4ff92SAndroid Build Coastguard Worker                                                           "number of input slots.";
187*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
188*89c4ff92SAndroid Build Coastguard Worker     }
189*89c4ff92SAndroid Build Coastguard Worker     // Connect the input slots.
190*89c4ff92SAndroid Build Coastguard Worker     // For each input slot, get the index of the opaque tensor that was allocated for it.
191*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int inputIndex = 0; inputIndex < layer->GetNumInputSlots(); ++inputIndex)
192*89c4ff92SAndroid Build Coastguard Worker     {
193*89c4ff92SAndroid Build Coastguard Worker         if (data.m_OutputSlotForNode[inputIndexArray[inputIndex]] != nullptr)
194*89c4ff92SAndroid Build Coastguard Worker         {
195*89c4ff92SAndroid Build Coastguard Worker             data.m_OutputSlotForNode[inputIndexArray[inputIndex]]->Connect(layer->GetInputSlot(inputIndex));
196*89c4ff92SAndroid Build Coastguard Worker         }
197*89c4ff92SAndroid Build Coastguard Worker     }
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker     // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
200*89c4ff92SAndroid Build Coastguard Worker     // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
201*89c4ff92SAndroid Build Coastguard Worker     // each output slot in the node.
202*89c4ff92SAndroid Build Coastguard Worker     const int* outputIndexArray;
203*89c4ff92SAndroid Build Coastguard Worker     int numOutputs;
204*89c4ff92SAndroid Build Coastguard Worker     if(TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs) != kTfLiteOk)
205*89c4ff92SAndroid Build Coastguard Worker     {
206*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
207*89c4ff92SAndroid Build Coastguard Worker     }
208*89c4ff92SAndroid Build Coastguard Worker     // numOutputs is set from TfLiteOpaqueNodeOutputs.
209*89c4ff92SAndroid Build Coastguard Worker     if(numOutputs != static_cast<int>(layer->GetNumOutputSlots()))
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "Layer: " << layer->GetName() << ": Expected number of output slots does not match actual "
212*89c4ff92SAndroid Build Coastguard Worker                                                              "number of output slots.";
213*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
214*89c4ff92SAndroid Build Coastguard Worker     }
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker     // Prepare output slots
217*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
218*89c4ff92SAndroid Build Coastguard Worker     {
219*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
220*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndexArray[outputIndex])] = &outputSlot;
221*89c4ff92SAndroid Build Coastguard Worker     }
222*89c4ff92SAndroid Build Coastguard Worker 
223*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker 
FusedActivation(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,TfLiteFusedActivation activationType,armnn::IConnectableLayer * prevLayer,unsigned int outputSlotIndex,armnnOpaqueDelegate::DelegateData & data)226*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus FusedActivation(TfLiteOpaqueContext* tfLiteContext,
227*89c4ff92SAndroid Build Coastguard Worker                              TfLiteOpaqueNode* tfLiteNode,
228*89c4ff92SAndroid Build Coastguard Worker                              TfLiteFusedActivation activationType,
229*89c4ff92SAndroid Build Coastguard Worker                              armnn::IConnectableLayer* prevLayer,
230*89c4ff92SAndroid Build Coastguard Worker                              unsigned int outputSlotIndex,
231*89c4ff92SAndroid Build Coastguard Worker                              armnnOpaqueDelegate::DelegateData& data)
232*89c4ff92SAndroid Build Coastguard Worker {
233*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& activationOutputInfo = prevLayer->GetOutputSlot(outputSlotIndex).GetTensorInfo();
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     armnn::ActivationDescriptor activationDesc;
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     switch (activationType)
238*89c4ff92SAndroid Build Coastguard Worker     {
239*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActNone:
240*89c4ff92SAndroid Build Coastguard Worker         {
241*89c4ff92SAndroid Build Coastguard Worker             // No Activation
242*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteOk;
243*89c4ff92SAndroid Build Coastguard Worker         }
244*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActRelu:
245*89c4ff92SAndroid Build Coastguard Worker         {
246*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = armnn::ActivationFunction::ReLu;
247*89c4ff92SAndroid Build Coastguard Worker             break;
248*89c4ff92SAndroid Build Coastguard Worker         }
249*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActReluN1To1:
250*89c4ff92SAndroid Build Coastguard Worker         {
251*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
252*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 1.0f;
253*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = -1.0f;
254*89c4ff92SAndroid Build Coastguard Worker             break;
255*89c4ff92SAndroid Build Coastguard Worker         }
256*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActRelu6:
257*89c4ff92SAndroid Build Coastguard Worker         {
258*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
259*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 6.0f;
260*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 0.0f;
261*89c4ff92SAndroid Build Coastguard Worker             break;
262*89c4ff92SAndroid Build Coastguard Worker         }
263*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActSigmoid:
264*89c4ff92SAndroid Build Coastguard Worker         {
265*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
266*89c4ff92SAndroid Build Coastguard Worker             break;
267*89c4ff92SAndroid Build Coastguard Worker         }
268*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteActTanh:
269*89c4ff92SAndroid Build Coastguard Worker         {
270*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = armnn::ActivationFunction::TanH;
271*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 1.0f;
272*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 1.0f;
273*89c4ff92SAndroid Build Coastguard Worker             break;
274*89c4ff92SAndroid Build Coastguard Worker         }
275*89c4ff92SAndroid Build Coastguard Worker         default:
276*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
277*89c4ff92SAndroid Build Coastguard Worker     }
278*89c4ff92SAndroid Build Coastguard Worker 
279*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
280*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
281*89c4ff92SAndroid Build Coastguard Worker     FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("ACTIVATION",
282*89c4ff92SAndroid Build Coastguard Worker                                       tfLiteContext,
283*89c4ff92SAndroid Build Coastguard Worker                                       IsActivationSupported,
284*89c4ff92SAndroid Build Coastguard Worker                                       data.m_Backends,
285*89c4ff92SAndroid Build Coastguard Worker                                       isSupported,
286*89c4ff92SAndroid Build Coastguard Worker                                       setBackend,
287*89c4ff92SAndroid Build Coastguard Worker                                       activationOutputInfo,
288*89c4ff92SAndroid Build Coastguard Worker                                       activationOutputInfo,
289*89c4ff92SAndroid Build Coastguard Worker                                       activationDesc);
290*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
291*89c4ff92SAndroid Build Coastguard Worker     {
292*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
293*89c4ff92SAndroid Build Coastguard Worker     }
294*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* activationLayer = data.m_Network->AddActivationLayer(activationDesc);
295*89c4ff92SAndroid Build Coastguard Worker     activationLayer->SetBackendId(setBackend);
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(activationLayer != nullptr);
298*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(activationOutputInfo);
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker     // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
301*89c4ff92SAndroid Build Coastguard Worker     // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
302*89c4ff92SAndroid Build Coastguard Worker     // each output slot in the node.
303*89c4ff92SAndroid Build Coastguard Worker     const int* outputIndexArray;
304*89c4ff92SAndroid Build Coastguard Worker     int numOutputs;
305*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus outputStatus = TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs);
306*89c4ff92SAndroid Build Coastguard Worker     if(outputStatus != kTfLiteOk)
307*89c4ff92SAndroid Build Coastguard Worker     {
308*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
309*89c4ff92SAndroid Build Coastguard Worker     }
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     // Connect and prepare output slots
312*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < activationLayer->GetNumOutputSlots(); ++outputIndex)
313*89c4ff92SAndroid Build Coastguard Worker     {
314*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForNode[static_cast<unsigned long>(
315*89c4ff92SAndroid Build Coastguard Worker                 outputIndexArray[outputIndex])]->Connect(activationLayer->GetInputSlot(0));
316*89c4ff92SAndroid Build Coastguard Worker 
317*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& outputSlot = activationLayer->GetOutputSlot(outputIndex);
318*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndexArray[outputIndex])] = &outputSlot;
319*89c4ff92SAndroid Build Coastguard Worker     }
320*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
321*89c4ff92SAndroid Build Coastguard Worker }
322*89c4ff92SAndroid Build Coastguard Worker 
AddReshapeLayer(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,armnn::IConnectableLayer * prevLayer,armnn::TensorInfo reshapedOutputTensorInfo,armnn::TensorInfo outputTensorInfo,armnnOpaqueDelegate::DelegateData & data)323*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* AddReshapeLayer(TfLiteOpaqueContext* tfLiteContext,
324*89c4ff92SAndroid Build Coastguard Worker                                           TfLiteOpaqueNode* tfLiteNode,
325*89c4ff92SAndroid Build Coastguard Worker                                           armnn::IConnectableLayer* prevLayer,
326*89c4ff92SAndroid Build Coastguard Worker                                           armnn::TensorInfo reshapedOutputTensorInfo,
327*89c4ff92SAndroid Build Coastguard Worker                                           armnn::TensorInfo outputTensorInfo,
328*89c4ff92SAndroid Build Coastguard Worker                                           armnnOpaqueDelegate::DelegateData& data)
329*89c4ff92SAndroid Build Coastguard Worker {
330*89c4ff92SAndroid Build Coastguard Worker     armnn::ReshapeDescriptor desc;
331*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetShape = outputTensorInfo.GetShape();
332*89c4ff92SAndroid Build Coastguard Worker 
333*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
334*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
335*89c4ff92SAndroid Build Coastguard Worker     FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("RESHAPE",
336*89c4ff92SAndroid Build Coastguard Worker                                       tfLiteContext,
337*89c4ff92SAndroid Build Coastguard Worker                                       IsReshapeSupported,
338*89c4ff92SAndroid Build Coastguard Worker                                       data.m_Backends,
339*89c4ff92SAndroid Build Coastguard Worker                                       isSupported,
340*89c4ff92SAndroid Build Coastguard Worker                                       setBackend,
341*89c4ff92SAndroid Build Coastguard Worker                                       reshapedOutputTensorInfo,
342*89c4ff92SAndroid Build Coastguard Worker                                       outputTensorInfo,
343*89c4ff92SAndroid Build Coastguard Worker                                       desc);
344*89c4ff92SAndroid Build Coastguard Worker 
345*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
346*89c4ff92SAndroid Build Coastguard Worker     {
347*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
348*89c4ff92SAndroid Build Coastguard Worker     }
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc);
351*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->SetBackendId(setBackend);
352*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(reshapeLayer != nullptr);
353*89c4ff92SAndroid Build Coastguard Worker 
354*89c4ff92SAndroid Build Coastguard Worker     prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);
355*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
356*89c4ff92SAndroid Build Coastguard Worker 
357*89c4ff92SAndroid Build Coastguard Worker     // Gather array of indices and it's length, replaces node->outputs->data[i]
358*89c4ff92SAndroid Build Coastguard Worker     const int* outputIndices = nullptr;
359*89c4ff92SAndroid Build Coastguard Worker     int numOutputs = 0;
360*89c4ff92SAndroid Build Coastguard Worker 
361*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus status = TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndices, &numOutputs);
362*89c4ff92SAndroid Build Coastguard Worker     if(status != kTfLiteOk)
363*89c4ff92SAndroid Build Coastguard Worker     {
364*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather output information from node.");
365*89c4ff92SAndroid Build Coastguard Worker     }
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker     if (static_cast<unsigned int>(numOutputs) != reshapeLayer->GetNumOutputSlots())
368*89c4ff92SAndroid Build Coastguard Worker     {
369*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unexpected number of outputs (" +
370*89c4ff92SAndroid Build Coastguard Worker                                std::to_string(numOutputs) +
371*89c4ff92SAndroid Build Coastguard Worker                                "!= " +
372*89c4ff92SAndroid Build Coastguard Worker                                std::to_string(reshapeLayer->GetNumOutputSlots()) +
373*89c4ff92SAndroid Build Coastguard Worker                                ") in node.");
374*89c4ff92SAndroid Build Coastguard Worker     }
375*89c4ff92SAndroid Build Coastguard Worker 
376*89c4ff92SAndroid Build Coastguard Worker     // Connect and prepare output slots
377*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex)
378*89c4ff92SAndroid Build Coastguard Worker     {
379*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForNode[static_cast<unsigned long>(
380*89c4ff92SAndroid Build Coastguard Worker                 outputIndices[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0));
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex);
383*89c4ff92SAndroid Build Coastguard Worker         data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndices[outputIndex])] = &outputSlot;
384*89c4ff92SAndroid Build Coastguard Worker     }
385*89c4ff92SAndroid Build Coastguard Worker     return reshapeLayer;
386*89c4ff92SAndroid Build Coastguard Worker }
387*89c4ff92SAndroid Build Coastguard Worker 
GetDataType(const TfLiteOpaqueTensor * tfLiteTensor)388*89c4ff92SAndroid Build Coastguard Worker armnn::DataType GetDataType(const TfLiteOpaqueTensor* tfLiteTensor)
389*89c4ff92SAndroid Build Coastguard Worker {
390*89c4ff92SAndroid Build Coastguard Worker     switch (TfLiteOpaqueTensorType(tfLiteTensor))
391*89c4ff92SAndroid Build Coastguard Worker     {
392*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBool:
393*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Boolean;
394*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteFloat32:
395*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Float32;
396*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteFloat16:
397*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Float16;
398*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteUInt8:
399*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QAsymmU8;
400*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteInt8:
401*89c4ff92SAndroid Build Coastguard Worker         {
402*89c4ff92SAndroid Build Coastguard Worker             auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(tfLiteTensor);
403*89c4ff92SAndroid Build Coastguard Worker             if (quantizationInfo.type == kTfLiteAffineQuantization)
404*89c4ff92SAndroid Build Coastguard Worker             {
405*89c4ff92SAndroid Build Coastguard Worker                 auto* quantization =
406*89c4ff92SAndroid Build Coastguard Worker                         reinterpret_cast<TfLiteAffineQuantization*>(quantizationInfo.params);
407*89c4ff92SAndroid Build Coastguard Worker 
408*89c4ff92SAndroid Build Coastguard Worker                 if (quantization->zero_point != nullptr && quantization->zero_point->size == 1)
409*89c4ff92SAndroid Build Coastguard Worker                 {
410*89c4ff92SAndroid Build Coastguard Worker                     return armnn::DataType::QAsymmS8;
411*89c4ff92SAndroid Build Coastguard Worker                 }
412*89c4ff92SAndroid Build Coastguard Worker                 else
413*89c4ff92SAndroid Build Coastguard Worker                 {
414*89c4ff92SAndroid Build Coastguard Worker                     return armnn::DataType::QSymmS8;
415*89c4ff92SAndroid Build Coastguard Worker                 }
416*89c4ff92SAndroid Build Coastguard Worker             }
417*89c4ff92SAndroid Build Coastguard Worker             else
418*89c4ff92SAndroid Build Coastguard Worker             {
419*89c4ff92SAndroid Build Coastguard Worker                 return armnn::DataType::QAsymmS8;
420*89c4ff92SAndroid Build Coastguard Worker             }
421*89c4ff92SAndroid Build Coastguard Worker         }
422*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteInt16:
423*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::QSymmS16;
424*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteInt32:
425*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed32;
426*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteInt64:
427*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed64;
428*89c4ff92SAndroid Build Coastguard Worker         default:
429*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(
430*89c4ff92SAndroid Build Coastguard Worker                     &"TfLiteArmnnDelegate: Unsupported data type: " [ TfLiteOpaqueTensorType(tfLiteTensor) ]);
431*89c4ff92SAndroid Build Coastguard Worker     }
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker 
GetTensorInfoForTfLiteOpaqueTensor(const TfLiteOpaqueTensor * tfLiteTensor,bool isOutput=false)434*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo GetTensorInfoForTfLiteOpaqueTensor(const TfLiteOpaqueTensor* tfLiteTensor, bool isOutput = false)
435*89c4ff92SAndroid Build Coastguard Worker {
436*89c4ff92SAndroid Build Coastguard Worker     armnn::DataType type = GetDataType(tfLiteTensor);
437*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo ret;
438*89c4ff92SAndroid Build Coastguard Worker 
439*89c4ff92SAndroid Build Coastguard Worker     auto tensorDimensionSize = TfLiteOpaqueTensorNumDims(tfLiteTensor);
440*89c4ff92SAndroid Build Coastguard Worker     if (tensorDimensionSize == 0)
441*89c4ff92SAndroid Build Coastguard Worker     {
442*89c4ff92SAndroid Build Coastguard Worker         // If input tensor does not have a shape
443*89c4ff92SAndroid Build Coastguard Worker         // assuming that it has 1D tensor
444*89c4ff92SAndroid Build Coastguard Worker         if (!isOutput)
445*89c4ff92SAndroid Build Coastguard Worker         {
446*89c4ff92SAndroid Build Coastguard Worker             std::vector<unsigned int> safeShape = { 1 };
447*89c4ff92SAndroid Build Coastguard Worker             bool dimensionsSpecificity[1] = { true };
448*89c4ff92SAndroid Build Coastguard Worker 
449*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
450*89c4ff92SAndroid Build Coastguard Worker                                            safeShape.data(),
451*89c4ff92SAndroid Build Coastguard Worker                                            dimensionsSpecificity);
452*89c4ff92SAndroid Build Coastguard Worker             ret = armnn::TensorInfo(tensorShape, type);
453*89c4ff92SAndroid Build Coastguard Worker 
454*89c4ff92SAndroid Build Coastguard Worker             if(IsConstantTensor(tfLiteTensor))
455*89c4ff92SAndroid Build Coastguard Worker             {
456*89c4ff92SAndroid Build Coastguard Worker                 ret.SetConstant(true);
457*89c4ff92SAndroid Build Coastguard Worker             }
458*89c4ff92SAndroid Build Coastguard Worker         }
459*89c4ff92SAndroid Build Coastguard Worker         else
460*89c4ff92SAndroid Build Coastguard Worker         {
461*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorShape tensorShape(armnn::Dimensionality::NotSpecified);
462*89c4ff92SAndroid Build Coastguard Worker             ret = armnn::TensorInfo(tensorShape, type);
463*89c4ff92SAndroid Build Coastguard Worker         }
464*89c4ff92SAndroid Build Coastguard Worker     }
465*89c4ff92SAndroid Build Coastguard Worker     else
466*89c4ff92SAndroid Build Coastguard Worker     {
467*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> tensorDims(static_cast<unsigned int>(tensorDimensionSize));
468*89c4ff92SAndroid Build Coastguard Worker         bool dimensionsSpecificity[5] = { true, true, true, true, true };
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker         for (int32_t i = 0; i < tensorDimensionSize; ++i)
471*89c4ff92SAndroid Build Coastguard Worker         {
472*89c4ff92SAndroid Build Coastguard Worker             int32_t dim = TfLiteOpaqueTensorDim(tfLiteTensor, i);
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker             if (dim == 0)
475*89c4ff92SAndroid Build Coastguard Worker             {
476*89c4ff92SAndroid Build Coastguard Worker                 dimensionsSpecificity[i] = false;
477*89c4ff92SAndroid Build Coastguard Worker             }
478*89c4ff92SAndroid Build Coastguard Worker             tensorDims[i] = static_cast<unsigned int>(dim);
479*89c4ff92SAndroid Build Coastguard Worker         }
480*89c4ff92SAndroid Build Coastguard Worker 
481*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorShape tensorShape(static_cast<unsigned int>(tensorDimensionSize),
482*89c4ff92SAndroid Build Coastguard Worker                                        tensorDims.data(),
483*89c4ff92SAndroid Build Coastguard Worker                                        dimensionsSpecificity);
484*89c4ff92SAndroid Build Coastguard Worker 
485*89c4ff92SAndroid Build Coastguard Worker         if(IsConstantTensor(tfLiteTensor))
486*89c4ff92SAndroid Build Coastguard Worker         {
487*89c4ff92SAndroid Build Coastguard Worker             ret = armnn::TensorInfo(tensorShape, type);
488*89c4ff92SAndroid Build Coastguard Worker             ret.SetConstant(true);
489*89c4ff92SAndroid Build Coastguard Worker         }
490*89c4ff92SAndroid Build Coastguard Worker         else
491*89c4ff92SAndroid Build Coastguard Worker         {
492*89c4ff92SAndroid Build Coastguard Worker             ret = armnn::TensorInfo(tensorShape, type);
493*89c4ff92SAndroid Build Coastguard Worker         }
494*89c4ff92SAndroid Build Coastguard Worker     }
495*89c4ff92SAndroid Build Coastguard Worker 
496*89c4ff92SAndroid Build Coastguard Worker     auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(tfLiteTensor);
497*89c4ff92SAndroid Build Coastguard Worker     if (quantizationInfo.type == kTfLiteAffineQuantization)
498*89c4ff92SAndroid Build Coastguard Worker     {
499*89c4ff92SAndroid Build Coastguard Worker         // get per-channel quantization parameters
500*89c4ff92SAndroid Build Coastguard Worker         const auto* affineQuantization =
501*89c4ff92SAndroid Build Coastguard Worker                 reinterpret_cast<TfLiteAffineQuantization*>(quantizationInfo.params);
502*89c4ff92SAndroid Build Coastguard Worker         if (affineQuantization->scale->size > 1)
503*89c4ff92SAndroid Build Coastguard Worker         {
504*89c4ff92SAndroid Build Coastguard Worker             std::vector<float> quantizationScales;
505*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int i = 0; i < static_cast<unsigned int>(affineQuantization->scale->size); ++i)
506*89c4ff92SAndroid Build Coastguard Worker             {
507*89c4ff92SAndroid Build Coastguard Worker                 quantizationScales.push_back(affineQuantization->scale->data[i]);
508*89c4ff92SAndroid Build Coastguard Worker             }
509*89c4ff92SAndroid Build Coastguard Worker             ret.SetQuantizationScales(quantizationScales);
510*89c4ff92SAndroid Build Coastguard Worker             ret.SetQuantizationDim(armnn::numeric_cast<unsigned int>(affineQuantization->quantized_dimension));
511*89c4ff92SAndroid Build Coastguard Worker         }
512*89c4ff92SAndroid Build Coastguard Worker         else
513*89c4ff92SAndroid Build Coastguard Worker         {
514*89c4ff92SAndroid Build Coastguard Worker             ret.SetQuantizationScale(affineQuantization->scale->data[0]);
515*89c4ff92SAndroid Build Coastguard Worker             ret.SetQuantizationOffset(affineQuantization->zero_point->data[0]);
516*89c4ff92SAndroid Build Coastguard Worker         }
517*89c4ff92SAndroid Build Coastguard Worker     }
518*89c4ff92SAndroid Build Coastguard Worker     else
519*89c4ff92SAndroid Build Coastguard Worker     {
520*89c4ff92SAndroid Build Coastguard Worker         auto quantizationParameters = TfLiteOpaqueTensorGetQuantizationParams(tfLiteTensor);
521*89c4ff92SAndroid Build Coastguard Worker         ret.SetQuantizationScale(quantizationParameters.scale);
522*89c4ff92SAndroid Build Coastguard Worker         ret.SetQuantizationOffset(quantizationParameters.zero_point);
523*89c4ff92SAndroid Build Coastguard Worker     }
524*89c4ff92SAndroid Build Coastguard Worker 
525*89c4ff92SAndroid Build Coastguard Worker     return ret;
526*89c4ff92SAndroid Build Coastguard Worker }
527*89c4ff92SAndroid Build Coastguard Worker 
// Wraps the data buffer of an opaque TfLite tensor in an Arm NN ConstTensor described by tensorInfo.
//
// Only tensors whose allocation type is kTfLiteMmapRo are accepted. The returned ConstTensor points at
// the TfLite-owned buffer and does not copy it.
//
// @param tfLiteTensor  Opaque TfLite tensor whose data is wrapped.
// @param tensorInfo    Arm NN description (shape, type, quantization) to attach to the data.
// @return              ConstTensor referencing the tensor's data.
// @throws armnn::Exception for any other allocation type; the numeric allocation type is included in the message.
armnn::ConstTensor CreateConstTensor(const TfLiteOpaqueTensor* tfLiteTensor,
                                     const armnn::TensorInfo& tensorInfo)
{
    const auto allocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);
    const bool isReadOnlyMapped = (allocationType == kTfLiteMmapRo);

    if (!isReadOnlyMapped)
    {
        throw armnn::Exception("TfLiteArmnnDelegate: Not constant allocation type: " +
                               std::to_string(allocationType));
    }

    auto* tensorData = TfLiteOpaqueTensorData(tfLiteTensor);
    return armnn::ConstTensor(tensorInfo, tensorData);
}
539*89c4ff92SAndroid Build Coastguard Worker 
GetConstTensorForTfLiteTensor(const TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,int index)540*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor* GetConstTensorForTfLiteTensor(const TfLiteOpaqueContext* tfLiteContext,
541*89c4ff92SAndroid Build Coastguard Worker                                                   TfLiteOpaqueNode* tfLiteNode,
542*89c4ff92SAndroid Build Coastguard Worker                                                   int index)
543*89c4ff92SAndroid Build Coastguard Worker {
544*89c4ff92SAndroid Build Coastguard Worker     const TfLiteOpaqueTensor* tfLiteTensor = TfLiteOpaqueNodeGetInput(tfLiteContext, tfLiteNode, index);
545*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteTensor);
546*89c4ff92SAndroid Build Coastguard Worker 
547*89c4ff92SAndroid Build Coastguard Worker     return new armnn::ConstTensor(tensorInfo, TfLiteOpaqueTensorData(tfLiteTensor));
548*89c4ff92SAndroid Build Coastguard Worker }
549*89c4ff92SAndroid Build Coastguard Worker 
IsOptionalOperandPresent(TfLiteOpaqueNode * tfLiteNode,const int operandIndex)550*89c4ff92SAndroid Build Coastguard Worker bool IsOptionalOperandPresent(TfLiteOpaqueNode* tfLiteNode, const int operandIndex)
551*89c4ff92SAndroid Build Coastguard Worker {
552*89c4ff92SAndroid Build Coastguard Worker     // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
553*89c4ff92SAndroid Build Coastguard Worker     // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
554*89c4ff92SAndroid Build Coastguard Worker     // tensors for each input slot in the node.
555*89c4ff92SAndroid Build Coastguard Worker     const int* inputIndexArray;
556*89c4ff92SAndroid Build Coastguard Worker     int numInputs = 0;
557*89c4ff92SAndroid Build Coastguard Worker 
558*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus status = TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs);
559*89c4ff92SAndroid Build Coastguard Worker     if(status != kTfLiteOk)
560*89c4ff92SAndroid Build Coastguard Worker     {
561*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather input information from node.");
562*89c4ff92SAndroid Build Coastguard Worker     }
563*89c4ff92SAndroid Build Coastguard Worker 
564*89c4ff92SAndroid Build Coastguard Worker     // If the inputs array has fewer than operandIndex entries or if the entry at operandIndex has a value of -1 or
565*89c4ff92SAndroid Build Coastguard Worker     // less then the input is not present.
566*89c4ff92SAndroid Build Coastguard Worker     if (numInputs > operandIndex && inputIndexArray[operandIndex] >= 0)
567*89c4ff92SAndroid Build Coastguard Worker     {
568*89c4ff92SAndroid Build Coastguard Worker         return true;
569*89c4ff92SAndroid Build Coastguard Worker     }
570*89c4ff92SAndroid Build Coastguard Worker     return false;
571*89c4ff92SAndroid Build Coastguard Worker }
572*89c4ff92SAndroid Build Coastguard Worker 
ProcessInputs(armnn::IConnectableLayer * layer,armnnOpaqueDelegate::DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode)573*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ProcessInputs(armnn::IConnectableLayer* layer,
574*89c4ff92SAndroid Build Coastguard Worker                            armnnOpaqueDelegate::DelegateData& delegateData,
575*89c4ff92SAndroid Build Coastguard Worker                            TfLiteOpaqueContext* tfLiteContext,
576*89c4ff92SAndroid Build Coastguard Worker                            TfLiteOpaqueNode* tfLiteNode)
577*89c4ff92SAndroid Build Coastguard Worker {
578*89c4ff92SAndroid Build Coastguard Worker     // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
579*89c4ff92SAndroid Build Coastguard Worker     // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
580*89c4ff92SAndroid Build Coastguard Worker     // tensors for each input slot in the node.
581*89c4ff92SAndroid Build Coastguard Worker     const int* inputIndexArray;
582*89c4ff92SAndroid Build Coastguard Worker     int numInputs = 0;
583*89c4ff92SAndroid Build Coastguard Worker 
584*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus status = TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs);
585*89c4ff92SAndroid Build Coastguard Worker     if(status != kTfLiteOk)
586*89c4ff92SAndroid Build Coastguard Worker     {
587*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather input information from node.");
588*89c4ff92SAndroid Build Coastguard Worker     }
589*89c4ff92SAndroid Build Coastguard Worker 
590*89c4ff92SAndroid Build Coastguard Worker     // Process input tensors
591*89c4ff92SAndroid Build Coastguard Worker     // If input tensor is a Constant tensor create a constant layer and connect it to the network
592*89c4ff92SAndroid Build Coastguard Worker     for (int32_t inputIndex = 0; inputIndex < static_cast<int32_t>(layer->GetNumInputSlots()); ++inputIndex)
593*89c4ff92SAndroid Build Coastguard Worker     {
594*89c4ff92SAndroid Build Coastguard Worker         const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueNodeGetInput(tfLiteContext, tfLiteNode, inputIndex);
595*89c4ff92SAndroid Build Coastguard Worker 
596*89c4ff92SAndroid Build Coastguard Worker         if (IsConstantTensor(tfLiteInputTensor))
597*89c4ff92SAndroid Build Coastguard Worker         {
598*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
599*89c4ff92SAndroid Build Coastguard Worker 
600*89c4ff92SAndroid Build Coastguard Worker             bool isSupported = false;
601*89c4ff92SAndroid Build Coastguard Worker             armnn::BackendId setBackend;
602*89c4ff92SAndroid Build Coastguard Worker             FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("CONSTANT",
603*89c4ff92SAndroid Build Coastguard Worker                                               tfLiteContext,
604*89c4ff92SAndroid Build Coastguard Worker                                               IsConstantSupported,
605*89c4ff92SAndroid Build Coastguard Worker                                               delegateData.m_Backends,
606*89c4ff92SAndroid Build Coastguard Worker                                               isSupported,
607*89c4ff92SAndroid Build Coastguard Worker                                               setBackend,
608*89c4ff92SAndroid Build Coastguard Worker                                               inputTensorInfo);
609*89c4ff92SAndroid Build Coastguard Worker             if (!isSupported)
610*89c4ff92SAndroid Build Coastguard Worker             {
611*89c4ff92SAndroid Build Coastguard Worker                 return kTfLiteError;
612*89c4ff92SAndroid Build Coastguard Worker             }
613*89c4ff92SAndroid Build Coastguard Worker 
614*89c4ff92SAndroid Build Coastguard Worker             auto constantInput = CreateConstTensor(tfLiteInputTensor, inputTensorInfo);
615*89c4ff92SAndroid Build Coastguard Worker 
616*89c4ff92SAndroid Build Coastguard Worker             armnn::IConnectableLayer* constantLayer = delegateData.m_Network->AddConstantLayer(constantInput);
617*89c4ff92SAndroid Build Coastguard Worker             constantLayer->SetBackendId(setBackend);
618*89c4ff92SAndroid Build Coastguard Worker             armnn::IOutputSlot& outputSlot = constantLayer->GetOutputSlot(0);
619*89c4ff92SAndroid Build Coastguard Worker             outputSlot.SetTensorInfo(inputTensorInfo);
620*89c4ff92SAndroid Build Coastguard Worker 
621*89c4ff92SAndroid Build Coastguard Worker             delegateData.m_OutputSlotForNode[inputIndexArray[inputIndex]] = &outputSlot;
622*89c4ff92SAndroid Build Coastguard Worker         }
623*89c4ff92SAndroid Build Coastguard Worker     }
624*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
625*89c4ff92SAndroid Build Coastguard Worker }
626*89c4ff92SAndroid Build Coastguard Worker 
627*89c4ff92SAndroid Build Coastguard Worker } // namespace anonymous
628