1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <DelegateUtils.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
20*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
21*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
22*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/c_api_opaque.h>
23*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
24*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/kernel_util.h>
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker namespace
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker
// Macro to call an Is<layer_name>Supported function and log caller name together with reason for lack of support.
//
// Walks `backends` in order. For every backend that is registered, it calls
// `layerSupportObject.func(__VA_ARGS__, reasonIfUnsupported)` and stores the result in `supported`.
// - On the first backend that reports support, `setBackend` receives that backend id and the loop stops.
// - A backend that rejects the layer logs a TFLITE_LOG_WARNING, including its reason when one was given.
// - An unregistered backend is reported through TF_LITE_OPAQUE_KERNEL_LOG.
// If `supported` is false after the loop, a "not supported by any specified backend" message is logged.
//
// Parameters:
//   opName        - C string naming the operator; used only in log messages.
//   tfLiteContext - TfLiteOpaqueContext* used for kernel logging.
//   func          - name of the ILayerSupport member to call (e.g. IsActivationSupported).
//   backends      - iterable of armnn::BackendId, tried in order.
//   supported     - bool lvalue receiving the result. It is only assigned for registered backends, so callers
//                   must initialise it themselves (e.g. to false). With an empty `backends`, or when no backend
//                   is registered, it keeps its prior value.
//   setBackend    - armnn::BackendId lvalue; only assigned when a backend reports support.
//   ...           - arguments forwarded to `func`, ahead of the trailing reasonIfUnsupported argument.
//
// An armnn::InvalidArgumentException thrown during the check is rethrown as a new InvalidArgumentException
// wrapping the original with the message "Failed to check layer support".
// The expansion is a single try/catch statement. Do not put // comments inside the body below: a trailing
// backslash would splice the next line into the comment.
#define FORWARD_LAYER_OPAQUE_SUPPORT_FUNC(opName, tfLiteContext, func, backends, supported, setBackend, ...) \
try \
{ \
    for (auto&& backendId : backends) \
    { \
        auto layerSupportObject = armnn::GetILayerSupportByBackendId(backendId); \
        if (layerSupportObject.IsBackendRegistered()) \
        { \
            std::string reasonIfUnsupported; \
            supported = \
                layerSupportObject.func(__VA_ARGS__, armnn::Optional<std::string&>(reasonIfUnsupported)); \
            if (supported) \
            { \
                setBackend = backendId; \
                break; \
            } \
            else \
            { \
                if (reasonIfUnsupported.size() > 0) \
                { \
                    TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING, \
                                    "%s: not supported by armnn: %s", opName, reasonIfUnsupported.c_str()); \
                } \
                else \
                { \
                    TFLITE_LOG_PROD(tflite::TFLITE_LOG_WARNING, \
                                    "%s: not supported by armnn", opName); \
                } \
            } \
        } \
        else \
        { \
            TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "%s: backend not registered: %s", \
                                      opName, backendId.Get().c_str()); \
        } \
    } \
    if (!supported) \
    { \
        TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "%s: not supported by any specified backend", opName); \
    } \
} \
catch (const armnn::InvalidArgumentException &e) \
{ \
    throw armnn::InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \
}
75*89c4ff92SAndroid Build Coastguard Worker
ValidateNumInputs(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,const unsigned int expectedSize,int nodeIndex)76*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateNumInputs(TfLiteOpaqueContext* tfLiteContext,
77*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
78*89c4ff92SAndroid Build Coastguard Worker const unsigned int expectedSize,
79*89c4ff92SAndroid Build Coastguard Worker int nodeIndex)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker int numInputs = TfLiteOpaqueNodeNumberOfInputs(tfLiteNode);
82*89c4ff92SAndroid Build Coastguard Worker if (static_cast<unsigned int>(numInputs) != expectedSize)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
85*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unexpected number of inputs (%d != %d) in node #%d",
86*89c4ff92SAndroid Build Coastguard Worker numInputs, expectedSize, nodeIndex);
87*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
ValidateNumOutputs(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,const unsigned int expectedSize,int nodeIndex)92*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ValidateNumOutputs(TfLiteOpaqueContext* tfLiteContext,
93*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
94*89c4ff92SAndroid Build Coastguard Worker const unsigned int expectedSize,
95*89c4ff92SAndroid Build Coastguard Worker int nodeIndex)
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker auto numOutputs = TfLiteOpaqueNodeNumberOfOutputs(tfLiteNode);
98*89c4ff92SAndroid Build Coastguard Worker if (static_cast<unsigned int>(numOutputs) != expectedSize)
99*89c4ff92SAndroid Build Coastguard Worker {
100*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
101*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unexpected number of outputs (%d != %d) in node #%d",
102*89c4ff92SAndroid Build Coastguard Worker numOutputs, expectedSize, nodeIndex);
103*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker
IsConstantTensor(const TfLiteOpaqueTensor * tfLiteTensor)108*89c4ff92SAndroid Build Coastguard Worker bool IsConstantTensor(const TfLiteOpaqueTensor* tfLiteTensor)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker auto tensorAllocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);
111*89c4ff92SAndroid Build Coastguard Worker if (tensorAllocationType == kTfLiteMmapRo)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker return true;
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker return false;
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker
IsDynamicTensor(const TfLiteOpaqueTensor * tfLiteTensor)118*89c4ff92SAndroid Build Coastguard Worker bool IsDynamicTensor(const TfLiteOpaqueTensor* tfLiteTensor)
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker auto tensorAllocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);
121*89c4ff92SAndroid Build Coastguard Worker if (tensorAllocationType == kTfLiteDynamic)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker return true;
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker return false;
126*89c4ff92SAndroid Build Coastguard Worker }
127*89c4ff92SAndroid Build Coastguard Worker
// A tensor handle is considered valid when it is non-null.
bool IsValid(const TfLiteOpaqueTensor* tfLiteTensor)
{
    return tfLiteTensor != nullptr;
}
132*89c4ff92SAndroid Build Coastguard Worker
IsValid(TfLiteOpaqueContext * tfLiteContext,const TfLiteOpaqueTensor * tfLiteTensor,int32_t operatorCode,int32_t nodeIndex)133*89c4ff92SAndroid Build Coastguard Worker bool IsValid(TfLiteOpaqueContext* tfLiteContext,
134*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueTensor* tfLiteTensor,
135*89c4ff92SAndroid Build Coastguard Worker int32_t operatorCode,
136*89c4ff92SAndroid Build Coastguard Worker int32_t nodeIndex)
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(tfLiteTensor))
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
141*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
142*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Invalid TfLite tensor in operator #%d node #%d: ",
143*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
144*89c4ff92SAndroid Build Coastguard Worker return false;
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker if (IsDynamicTensor(tfLiteTensor))
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
149*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
150*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Dynamic tensors are not supported in operator #%d node #%d: ",
151*89c4ff92SAndroid Build Coastguard Worker operatorCode, nodeIndex);
152*89c4ff92SAndroid Build Coastguard Worker return false;
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker return true;
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker
IsAffineQuantization(const TfLiteOpaqueTensor & tfLiteTensor)157*89c4ff92SAndroid Build Coastguard Worker bool IsAffineQuantization(const TfLiteOpaqueTensor& tfLiteTensor)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(&tfLiteTensor);
160*89c4ff92SAndroid Build Coastguard Worker if (quantizationInfo.type == kTfLiteAffineQuantization)
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker return true;
163*89c4ff92SAndroid Build Coastguard Worker }
164*89c4ff92SAndroid Build Coastguard Worker return false;
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker // Connects the layer to the graph
Connect(armnn::IConnectableLayer * layer,TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,armnnOpaqueDelegate::DelegateData & data)168*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Connect(armnn::IConnectableLayer* layer,
169*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
170*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
171*89c4ff92SAndroid Build Coastguard Worker armnnOpaqueDelegate::DelegateData& data)
172*89c4ff92SAndroid Build Coastguard Worker {
173*89c4ff92SAndroid Build Coastguard Worker // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
174*89c4ff92SAndroid Build Coastguard Worker // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
175*89c4ff92SAndroid Build Coastguard Worker // tensors for each input slot in the node.
176*89c4ff92SAndroid Build Coastguard Worker const int* inputIndexArray;
177*89c4ff92SAndroid Build Coastguard Worker int numInputs;
178*89c4ff92SAndroid Build Coastguard Worker if(TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs) != kTfLiteOk)
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker // numInputs is set from TfLiteOpaqueNodeInputs.
183*89c4ff92SAndroid Build Coastguard Worker if(numInputs != static_cast<int>(layer->GetNumInputSlots()))
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Layer: " << layer->GetName() << ": Expected number of input slots does not match actual "
186*89c4ff92SAndroid Build Coastguard Worker "number of input slots.";
187*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker // Connect the input slots.
190*89c4ff92SAndroid Build Coastguard Worker // For each input slot, get the index of the opaque tensor that was allocated for it.
191*89c4ff92SAndroid Build Coastguard Worker for (unsigned int inputIndex = 0; inputIndex < layer->GetNumInputSlots(); ++inputIndex)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker if (data.m_OutputSlotForNode[inputIndexArray[inputIndex]] != nullptr)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[inputIndexArray[inputIndex]]->Connect(layer->GetInputSlot(inputIndex));
196*89c4ff92SAndroid Build Coastguard Worker }
197*89c4ff92SAndroid Build Coastguard Worker }
198*89c4ff92SAndroid Build Coastguard Worker
199*89c4ff92SAndroid Build Coastguard Worker // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
200*89c4ff92SAndroid Build Coastguard Worker // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
201*89c4ff92SAndroid Build Coastguard Worker // each output slot in the node.
202*89c4ff92SAndroid Build Coastguard Worker const int* outputIndexArray;
203*89c4ff92SAndroid Build Coastguard Worker int numOutputs;
204*89c4ff92SAndroid Build Coastguard Worker if(TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs) != kTfLiteOk)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
207*89c4ff92SAndroid Build Coastguard Worker }
208*89c4ff92SAndroid Build Coastguard Worker // numOutputs is set from TfLiteOpaqueNodeOutputs.
209*89c4ff92SAndroid Build Coastguard Worker if(numOutputs != static_cast<int>(layer->GetNumOutputSlots()))
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "Layer: " << layer->GetName() << ": Expected number of output slots does not match actual "
212*89c4ff92SAndroid Build Coastguard Worker "number of output slots.";
213*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker
216*89c4ff92SAndroid Build Coastguard Worker // Prepare output slots
217*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < layer->GetNumOutputSlots(); ++outputIndex)
218*89c4ff92SAndroid Build Coastguard Worker {
219*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
220*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndexArray[outputIndex])] = &outputSlot;
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker
FusedActivation(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,TfLiteFusedActivation activationType,armnn::IConnectableLayer * prevLayer,unsigned int outputSlotIndex,armnnOpaqueDelegate::DelegateData & data)226*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus FusedActivation(TfLiteOpaqueContext* tfLiteContext,
227*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
228*89c4ff92SAndroid Build Coastguard Worker TfLiteFusedActivation activationType,
229*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* prevLayer,
230*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSlotIndex,
231*89c4ff92SAndroid Build Coastguard Worker armnnOpaqueDelegate::DelegateData& data)
232*89c4ff92SAndroid Build Coastguard Worker {
233*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& activationOutputInfo = prevLayer->GetOutputSlot(outputSlotIndex).GetTensorInfo();
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationDescriptor activationDesc;
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker switch (activationType)
238*89c4ff92SAndroid Build Coastguard Worker {
239*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActNone:
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker // No Activation
242*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
243*89c4ff92SAndroid Build Coastguard Worker }
244*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActRelu:
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_Function = armnn::ActivationFunction::ReLu;
247*89c4ff92SAndroid Build Coastguard Worker break;
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActReluN1To1:
250*89c4ff92SAndroid Build Coastguard Worker {
251*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
252*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_A = 1.0f;
253*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_B = -1.0f;
254*89c4ff92SAndroid Build Coastguard Worker break;
255*89c4ff92SAndroid Build Coastguard Worker }
256*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActRelu6:
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
259*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_A = 6.0f;
260*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_B = 0.0f;
261*89c4ff92SAndroid Build Coastguard Worker break;
262*89c4ff92SAndroid Build Coastguard Worker }
263*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActSigmoid:
264*89c4ff92SAndroid Build Coastguard Worker {
265*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
266*89c4ff92SAndroid Build Coastguard Worker break;
267*89c4ff92SAndroid Build Coastguard Worker }
268*89c4ff92SAndroid Build Coastguard Worker case kTfLiteActTanh:
269*89c4ff92SAndroid Build Coastguard Worker {
270*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_Function = armnn::ActivationFunction::TanH;
271*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_A = 1.0f;
272*89c4ff92SAndroid Build Coastguard Worker activationDesc.m_B = 1.0f;
273*89c4ff92SAndroid Build Coastguard Worker break;
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker default:
276*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
277*89c4ff92SAndroid Build Coastguard Worker }
278*89c4ff92SAndroid Build Coastguard Worker
279*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
280*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
281*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("ACTIVATION",
282*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
283*89c4ff92SAndroid Build Coastguard Worker IsActivationSupported,
284*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
285*89c4ff92SAndroid Build Coastguard Worker isSupported,
286*89c4ff92SAndroid Build Coastguard Worker setBackend,
287*89c4ff92SAndroid Build Coastguard Worker activationOutputInfo,
288*89c4ff92SAndroid Build Coastguard Worker activationOutputInfo,
289*89c4ff92SAndroid Build Coastguard Worker activationDesc);
290*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
291*89c4ff92SAndroid Build Coastguard Worker {
292*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* activationLayer = data.m_Network->AddActivationLayer(activationDesc);
295*89c4ff92SAndroid Build Coastguard Worker activationLayer->SetBackendId(setBackend);
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(activationLayer != nullptr);
298*89c4ff92SAndroid Build Coastguard Worker activationLayer->GetOutputSlot(0).SetTensorInfo(activationOutputInfo);
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
301*89c4ff92SAndroid Build Coastguard Worker // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
302*89c4ff92SAndroid Build Coastguard Worker // each output slot in the node.
303*89c4ff92SAndroid Build Coastguard Worker const int* outputIndexArray;
304*89c4ff92SAndroid Build Coastguard Worker int numOutputs;
305*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus outputStatus = TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs);
306*89c4ff92SAndroid Build Coastguard Worker if(outputStatus != kTfLiteOk)
307*89c4ff92SAndroid Build Coastguard Worker {
308*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
309*89c4ff92SAndroid Build Coastguard Worker }
310*89c4ff92SAndroid Build Coastguard Worker
311*89c4ff92SAndroid Build Coastguard Worker // Connect and prepare output slots
312*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < activationLayer->GetNumOutputSlots(); ++outputIndex)
313*89c4ff92SAndroid Build Coastguard Worker {
314*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[static_cast<unsigned long>(
315*89c4ff92SAndroid Build Coastguard Worker outputIndexArray[outputIndex])]->Connect(activationLayer->GetInputSlot(0));
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = activationLayer->GetOutputSlot(outputIndex);
318*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndexArray[outputIndex])] = &outputSlot;
319*89c4ff92SAndroid Build Coastguard Worker }
320*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
321*89c4ff92SAndroid Build Coastguard Worker }
322*89c4ff92SAndroid Build Coastguard Worker
AddReshapeLayer(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,armnn::IConnectableLayer * prevLayer,armnn::TensorInfo reshapedOutputTensorInfo,armnn::TensorInfo outputTensorInfo,armnnOpaqueDelegate::DelegateData & data)323*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* AddReshapeLayer(TfLiteOpaqueContext* tfLiteContext,
324*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
325*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* prevLayer,
326*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo reshapedOutputTensorInfo,
327*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo,
328*89c4ff92SAndroid Build Coastguard Worker armnnOpaqueDelegate::DelegateData& data)
329*89c4ff92SAndroid Build Coastguard Worker {
330*89c4ff92SAndroid Build Coastguard Worker armnn::ReshapeDescriptor desc;
331*89c4ff92SAndroid Build Coastguard Worker desc.m_TargetShape = outputTensorInfo.GetShape();
332*89c4ff92SAndroid Build Coastguard Worker
333*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
334*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
335*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("RESHAPE",
336*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
337*89c4ff92SAndroid Build Coastguard Worker IsReshapeSupported,
338*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
339*89c4ff92SAndroid Build Coastguard Worker isSupported,
340*89c4ff92SAndroid Build Coastguard Worker setBackend,
341*89c4ff92SAndroid Build Coastguard Worker reshapedOutputTensorInfo,
342*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo,
343*89c4ff92SAndroid Build Coastguard Worker desc);
344*89c4ff92SAndroid Build Coastguard Worker
345*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker return nullptr;
348*89c4ff92SAndroid Build Coastguard Worker }
349*89c4ff92SAndroid Build Coastguard Worker
350*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* reshapeLayer = data.m_Network->AddReshapeLayer(desc);
351*89c4ff92SAndroid Build Coastguard Worker reshapeLayer->SetBackendId(setBackend);
352*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(reshapeLayer != nullptr);
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker prevLayer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);
355*89c4ff92SAndroid Build Coastguard Worker reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
356*89c4ff92SAndroid Build Coastguard Worker
357*89c4ff92SAndroid Build Coastguard Worker // Gather array of indices and it's length, replaces node->outputs->data[i]
358*89c4ff92SAndroid Build Coastguard Worker const int* outputIndices = nullptr;
359*89c4ff92SAndroid Build Coastguard Worker int numOutputs = 0;
360*89c4ff92SAndroid Build Coastguard Worker
361*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus status = TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndices, &numOutputs);
362*89c4ff92SAndroid Build Coastguard Worker if(status != kTfLiteOk)
363*89c4ff92SAndroid Build Coastguard Worker {
364*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather output information from node.");
365*89c4ff92SAndroid Build Coastguard Worker }
366*89c4ff92SAndroid Build Coastguard Worker
367*89c4ff92SAndroid Build Coastguard Worker if (static_cast<unsigned int>(numOutputs) != reshapeLayer->GetNumOutputSlots())
368*89c4ff92SAndroid Build Coastguard Worker {
369*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unexpected number of outputs (" +
370*89c4ff92SAndroid Build Coastguard Worker std::to_string(numOutputs) +
371*89c4ff92SAndroid Build Coastguard Worker "!= " +
372*89c4ff92SAndroid Build Coastguard Worker std::to_string(reshapeLayer->GetNumOutputSlots()) +
373*89c4ff92SAndroid Build Coastguard Worker ") in node.");
374*89c4ff92SAndroid Build Coastguard Worker }
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker // Connect and prepare output slots
377*89c4ff92SAndroid Build Coastguard Worker for (unsigned int outputIndex = 0; outputIndex < reshapeLayer->GetNumOutputSlots(); ++outputIndex)
378*89c4ff92SAndroid Build Coastguard Worker {
379*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[static_cast<unsigned long>(
380*89c4ff92SAndroid Build Coastguard Worker outputIndices[outputIndex])]->Connect(reshapeLayer->GetInputSlot(0));
381*89c4ff92SAndroid Build Coastguard Worker
382*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = reshapeLayer->GetOutputSlot(outputIndex);
383*89c4ff92SAndroid Build Coastguard Worker data.m_OutputSlotForNode[static_cast<unsigned long>(outputIndices[outputIndex])] = &outputSlot;
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker return reshapeLayer;
386*89c4ff92SAndroid Build Coastguard Worker }
387*89c4ff92SAndroid Build Coastguard Worker
GetDataType(const TfLiteOpaqueTensor * tfLiteTensor)388*89c4ff92SAndroid Build Coastguard Worker armnn::DataType GetDataType(const TfLiteOpaqueTensor* tfLiteTensor)
389*89c4ff92SAndroid Build Coastguard Worker {
390*89c4ff92SAndroid Build Coastguard Worker switch (TfLiteOpaqueTensorType(tfLiteTensor))
391*89c4ff92SAndroid Build Coastguard Worker {
392*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBool:
393*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Boolean;
394*89c4ff92SAndroid Build Coastguard Worker case kTfLiteFloat32:
395*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Float32;
396*89c4ff92SAndroid Build Coastguard Worker case kTfLiteFloat16:
397*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Float16;
398*89c4ff92SAndroid Build Coastguard Worker case kTfLiteUInt8:
399*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QAsymmU8;
400*89c4ff92SAndroid Build Coastguard Worker case kTfLiteInt8:
401*89c4ff92SAndroid Build Coastguard Worker {
402*89c4ff92SAndroid Build Coastguard Worker auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(tfLiteTensor);
403*89c4ff92SAndroid Build Coastguard Worker if (quantizationInfo.type == kTfLiteAffineQuantization)
404*89c4ff92SAndroid Build Coastguard Worker {
405*89c4ff92SAndroid Build Coastguard Worker auto* quantization =
406*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<TfLiteAffineQuantization*>(quantizationInfo.params);
407*89c4ff92SAndroid Build Coastguard Worker
408*89c4ff92SAndroid Build Coastguard Worker if (quantization->zero_point != nullptr && quantization->zero_point->size == 1)
409*89c4ff92SAndroid Build Coastguard Worker {
410*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QAsymmS8;
411*89c4ff92SAndroid Build Coastguard Worker }
412*89c4ff92SAndroid Build Coastguard Worker else
413*89c4ff92SAndroid Build Coastguard Worker {
414*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QSymmS8;
415*89c4ff92SAndroid Build Coastguard Worker }
416*89c4ff92SAndroid Build Coastguard Worker }
417*89c4ff92SAndroid Build Coastguard Worker else
418*89c4ff92SAndroid Build Coastguard Worker {
419*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QAsymmS8;
420*89c4ff92SAndroid Build Coastguard Worker }
421*89c4ff92SAndroid Build Coastguard Worker }
422*89c4ff92SAndroid Build Coastguard Worker case kTfLiteInt16:
423*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::QSymmS16;
424*89c4ff92SAndroid Build Coastguard Worker case kTfLiteInt32:
425*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed32;
426*89c4ff92SAndroid Build Coastguard Worker case kTfLiteInt64:
427*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed64;
428*89c4ff92SAndroid Build Coastguard Worker default:
429*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(
430*89c4ff92SAndroid Build Coastguard Worker &"TfLiteArmnnDelegate: Unsupported data type: " [ TfLiteOpaqueTensorType(tfLiteTensor) ]);
431*89c4ff92SAndroid Build Coastguard Worker }
432*89c4ff92SAndroid Build Coastguard Worker }
433*89c4ff92SAndroid Build Coastguard Worker
GetTensorInfoForTfLiteOpaqueTensor(const TfLiteOpaqueTensor * tfLiteTensor,bool isOutput=false)434*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo GetTensorInfoForTfLiteOpaqueTensor(const TfLiteOpaqueTensor* tfLiteTensor, bool isOutput = false)
435*89c4ff92SAndroid Build Coastguard Worker {
436*89c4ff92SAndroid Build Coastguard Worker armnn::DataType type = GetDataType(tfLiteTensor);
437*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo ret;
438*89c4ff92SAndroid Build Coastguard Worker
439*89c4ff92SAndroid Build Coastguard Worker auto tensorDimensionSize = TfLiteOpaqueTensorNumDims(tfLiteTensor);
440*89c4ff92SAndroid Build Coastguard Worker if (tensorDimensionSize == 0)
441*89c4ff92SAndroid Build Coastguard Worker {
442*89c4ff92SAndroid Build Coastguard Worker // If input tensor does not have a shape
443*89c4ff92SAndroid Build Coastguard Worker // assuming that it has 1D tensor
444*89c4ff92SAndroid Build Coastguard Worker if (!isOutput)
445*89c4ff92SAndroid Build Coastguard Worker {
446*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> safeShape = { 1 };
447*89c4ff92SAndroid Build Coastguard Worker bool dimensionsSpecificity[1] = { true };
448*89c4ff92SAndroid Build Coastguard Worker
449*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
450*89c4ff92SAndroid Build Coastguard Worker safeShape.data(),
451*89c4ff92SAndroid Build Coastguard Worker dimensionsSpecificity);
452*89c4ff92SAndroid Build Coastguard Worker ret = armnn::TensorInfo(tensorShape, type);
453*89c4ff92SAndroid Build Coastguard Worker
454*89c4ff92SAndroid Build Coastguard Worker if(IsConstantTensor(tfLiteTensor))
455*89c4ff92SAndroid Build Coastguard Worker {
456*89c4ff92SAndroid Build Coastguard Worker ret.SetConstant(true);
457*89c4ff92SAndroid Build Coastguard Worker }
458*89c4ff92SAndroid Build Coastguard Worker }
459*89c4ff92SAndroid Build Coastguard Worker else
460*89c4ff92SAndroid Build Coastguard Worker {
461*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape tensorShape(armnn::Dimensionality::NotSpecified);
462*89c4ff92SAndroid Build Coastguard Worker ret = armnn::TensorInfo(tensorShape, type);
463*89c4ff92SAndroid Build Coastguard Worker }
464*89c4ff92SAndroid Build Coastguard Worker }
465*89c4ff92SAndroid Build Coastguard Worker else
466*89c4ff92SAndroid Build Coastguard Worker {
467*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> tensorDims(static_cast<unsigned int>(tensorDimensionSize));
468*89c4ff92SAndroid Build Coastguard Worker bool dimensionsSpecificity[5] = { true, true, true, true, true };
469*89c4ff92SAndroid Build Coastguard Worker
470*89c4ff92SAndroid Build Coastguard Worker for (int32_t i = 0; i < tensorDimensionSize; ++i)
471*89c4ff92SAndroid Build Coastguard Worker {
472*89c4ff92SAndroid Build Coastguard Worker int32_t dim = TfLiteOpaqueTensorDim(tfLiteTensor, i);
473*89c4ff92SAndroid Build Coastguard Worker
474*89c4ff92SAndroid Build Coastguard Worker if (dim == 0)
475*89c4ff92SAndroid Build Coastguard Worker {
476*89c4ff92SAndroid Build Coastguard Worker dimensionsSpecificity[i] = false;
477*89c4ff92SAndroid Build Coastguard Worker }
478*89c4ff92SAndroid Build Coastguard Worker tensorDims[i] = static_cast<unsigned int>(dim);
479*89c4ff92SAndroid Build Coastguard Worker }
480*89c4ff92SAndroid Build Coastguard Worker
481*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape tensorShape(static_cast<unsigned int>(tensorDimensionSize),
482*89c4ff92SAndroid Build Coastguard Worker tensorDims.data(),
483*89c4ff92SAndroid Build Coastguard Worker dimensionsSpecificity);
484*89c4ff92SAndroid Build Coastguard Worker
485*89c4ff92SAndroid Build Coastguard Worker if(IsConstantTensor(tfLiteTensor))
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker ret = armnn::TensorInfo(tensorShape, type);
488*89c4ff92SAndroid Build Coastguard Worker ret.SetConstant(true);
489*89c4ff92SAndroid Build Coastguard Worker }
490*89c4ff92SAndroid Build Coastguard Worker else
491*89c4ff92SAndroid Build Coastguard Worker {
492*89c4ff92SAndroid Build Coastguard Worker ret = armnn::TensorInfo(tensorShape, type);
493*89c4ff92SAndroid Build Coastguard Worker }
494*89c4ff92SAndroid Build Coastguard Worker }
495*89c4ff92SAndroid Build Coastguard Worker
496*89c4ff92SAndroid Build Coastguard Worker auto quantizationInfo = TfLiteOpaqueTensorGetQuantization(tfLiteTensor);
497*89c4ff92SAndroid Build Coastguard Worker if (quantizationInfo.type == kTfLiteAffineQuantization)
498*89c4ff92SAndroid Build Coastguard Worker {
499*89c4ff92SAndroid Build Coastguard Worker // get per-channel quantization parameters
500*89c4ff92SAndroid Build Coastguard Worker const auto* affineQuantization =
501*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<TfLiteAffineQuantization*>(quantizationInfo.params);
502*89c4ff92SAndroid Build Coastguard Worker if (affineQuantization->scale->size > 1)
503*89c4ff92SAndroid Build Coastguard Worker {
504*89c4ff92SAndroid Build Coastguard Worker std::vector<float> quantizationScales;
505*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < static_cast<unsigned int>(affineQuantization->scale->size); ++i)
506*89c4ff92SAndroid Build Coastguard Worker {
507*89c4ff92SAndroid Build Coastguard Worker quantizationScales.push_back(affineQuantization->scale->data[i]);
508*89c4ff92SAndroid Build Coastguard Worker }
509*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationScales(quantizationScales);
510*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationDim(armnn::numeric_cast<unsigned int>(affineQuantization->quantized_dimension));
511*89c4ff92SAndroid Build Coastguard Worker }
512*89c4ff92SAndroid Build Coastguard Worker else
513*89c4ff92SAndroid Build Coastguard Worker {
514*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationScale(affineQuantization->scale->data[0]);
515*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationOffset(affineQuantization->zero_point->data[0]);
516*89c4ff92SAndroid Build Coastguard Worker }
517*89c4ff92SAndroid Build Coastguard Worker }
518*89c4ff92SAndroid Build Coastguard Worker else
519*89c4ff92SAndroid Build Coastguard Worker {
520*89c4ff92SAndroid Build Coastguard Worker auto quantizationParameters = TfLiteOpaqueTensorGetQuantizationParams(tfLiteTensor);
521*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationScale(quantizationParameters.scale);
522*89c4ff92SAndroid Build Coastguard Worker ret.SetQuantizationOffset(quantizationParameters.zero_point);
523*89c4ff92SAndroid Build Coastguard Worker }
524*89c4ff92SAndroid Build Coastguard Worker
525*89c4ff92SAndroid Build Coastguard Worker return ret;
526*89c4ff92SAndroid Build Coastguard Worker }
527*89c4ff92SAndroid Build Coastguard Worker
// Wraps the data of a read-only TfLite tensor (allocation type kTfLiteMmapRo) in an armnn::ConstTensor
// described by tensorInfo.
//
// The ConstTensor is built from the pointer returned by TfLiteOpaqueTensorData; armnn::ConstTensor stores that
// pointer rather than copying the buffer, so the result should not outlive the TfLite tensor's data.
// Throws armnn::Exception (including the numeric allocation type in the message) for any other allocation type.
armnn::ConstTensor CreateConstTensor(const TfLiteOpaqueTensor* tfLiteTensor,
                                     const armnn::TensorInfo& tensorInfo)
{
    const auto allocationType = TfLiteOpaqueTensorGetAllocationType(tfLiteTensor);

    if (allocationType == kTfLiteMmapRo)
    {
        return armnn::ConstTensor(tensorInfo, TfLiteOpaqueTensorData(tfLiteTensor));
    }

    throw armnn::Exception("TfLiteArmnnDelegate: Not constant allocation type: " + std::to_string(allocationType));
}
539*89c4ff92SAndroid Build Coastguard Worker
GetConstTensorForTfLiteTensor(const TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode,int index)540*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor* GetConstTensorForTfLiteTensor(const TfLiteOpaqueContext* tfLiteContext,
541*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
542*89c4ff92SAndroid Build Coastguard Worker int index)
543*89c4ff92SAndroid Build Coastguard Worker {
544*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueTensor* tfLiteTensor = TfLiteOpaqueNodeGetInput(tfLiteContext, tfLiteNode, index);
545*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteTensor);
546*89c4ff92SAndroid Build Coastguard Worker
547*89c4ff92SAndroid Build Coastguard Worker return new armnn::ConstTensor(tensorInfo, TfLiteOpaqueTensorData(tfLiteTensor));
548*89c4ff92SAndroid Build Coastguard Worker }
549*89c4ff92SAndroid Build Coastguard Worker
IsOptionalOperandPresent(TfLiteOpaqueNode * tfLiteNode,const int operandIndex)550*89c4ff92SAndroid Build Coastguard Worker bool IsOptionalOperandPresent(TfLiteOpaqueNode* tfLiteNode, const int operandIndex)
551*89c4ff92SAndroid Build Coastguard Worker {
552*89c4ff92SAndroid Build Coastguard Worker // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
553*89c4ff92SAndroid Build Coastguard Worker // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
554*89c4ff92SAndroid Build Coastguard Worker // tensors for each input slot in the node.
555*89c4ff92SAndroid Build Coastguard Worker const int* inputIndexArray;
556*89c4ff92SAndroid Build Coastguard Worker int numInputs = 0;
557*89c4ff92SAndroid Build Coastguard Worker
558*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus status = TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs);
559*89c4ff92SAndroid Build Coastguard Worker if(status != kTfLiteOk)
560*89c4ff92SAndroid Build Coastguard Worker {
561*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather input information from node.");
562*89c4ff92SAndroid Build Coastguard Worker }
563*89c4ff92SAndroid Build Coastguard Worker
564*89c4ff92SAndroid Build Coastguard Worker // If the inputs array has fewer than operandIndex entries or if the entry at operandIndex has a value of -1 or
565*89c4ff92SAndroid Build Coastguard Worker // less then the input is not present.
566*89c4ff92SAndroid Build Coastguard Worker if (numInputs > operandIndex && inputIndexArray[operandIndex] >= 0)
567*89c4ff92SAndroid Build Coastguard Worker {
568*89c4ff92SAndroid Build Coastguard Worker return true;
569*89c4ff92SAndroid Build Coastguard Worker }
570*89c4ff92SAndroid Build Coastguard Worker return false;
571*89c4ff92SAndroid Build Coastguard Worker }
572*89c4ff92SAndroid Build Coastguard Worker
ProcessInputs(armnn::IConnectableLayer * layer,armnnOpaqueDelegate::DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode)573*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ProcessInputs(armnn::IConnectableLayer* layer,
574*89c4ff92SAndroid Build Coastguard Worker armnnOpaqueDelegate::DelegateData& delegateData,
575*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
576*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode)
577*89c4ff92SAndroid Build Coastguard Worker {
578*89c4ff92SAndroid Build Coastguard Worker // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
579*89c4ff92SAndroid Build Coastguard Worker // This function turns inputIndexArray into an int array of indices. These indices point to the index of the
580*89c4ff92SAndroid Build Coastguard Worker // tensors for each input slot in the node.
581*89c4ff92SAndroid Build Coastguard Worker const int* inputIndexArray;
582*89c4ff92SAndroid Build Coastguard Worker int numInputs = 0;
583*89c4ff92SAndroid Build Coastguard Worker
584*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus status = TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs);
585*89c4ff92SAndroid Build Coastguard Worker if(status != kTfLiteOk)
586*89c4ff92SAndroid Build Coastguard Worker {
587*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to gather input information from node.");
588*89c4ff92SAndroid Build Coastguard Worker }
589*89c4ff92SAndroid Build Coastguard Worker
590*89c4ff92SAndroid Build Coastguard Worker // Process input tensors
591*89c4ff92SAndroid Build Coastguard Worker // If input tensor is a Constant tensor create a constant layer and connect it to the network
592*89c4ff92SAndroid Build Coastguard Worker for (int32_t inputIndex = 0; inputIndex < static_cast<int32_t>(layer->GetNumInputSlots()); ++inputIndex)
593*89c4ff92SAndroid Build Coastguard Worker {
594*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueTensor* tfLiteInputTensor = TfLiteOpaqueNodeGetInput(tfLiteContext, tfLiteNode, inputIndex);
595*89c4ff92SAndroid Build Coastguard Worker
596*89c4ff92SAndroid Build Coastguard Worker if (IsConstantTensor(tfLiteInputTensor))
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
599*89c4ff92SAndroid Build Coastguard Worker
600*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
601*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
602*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_OPAQUE_SUPPORT_FUNC("CONSTANT",
603*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
604*89c4ff92SAndroid Build Coastguard Worker IsConstantSupported,
605*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Backends,
606*89c4ff92SAndroid Build Coastguard Worker isSupported,
607*89c4ff92SAndroid Build Coastguard Worker setBackend,
608*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo);
609*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
610*89c4ff92SAndroid Build Coastguard Worker {
611*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
612*89c4ff92SAndroid Build Coastguard Worker }
613*89c4ff92SAndroid Build Coastguard Worker
614*89c4ff92SAndroid Build Coastguard Worker auto constantInput = CreateConstTensor(tfLiteInputTensor, inputTensorInfo);
615*89c4ff92SAndroid Build Coastguard Worker
616*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* constantLayer = delegateData.m_Network->AddConstantLayer(constantInput);
617*89c4ff92SAndroid Build Coastguard Worker constantLayer->SetBackendId(setBackend);
618*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = constantLayer->GetOutputSlot(0);
619*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(inputTensorInfo);
620*89c4ff92SAndroid Build Coastguard Worker
621*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[inputIndexArray[inputIndex]] = &outputSlot;
622*89c4ff92SAndroid Build Coastguard Worker }
623*89c4ff92SAndroid Build Coastguard Worker }
624*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
625*89c4ff92SAndroid Build Coastguard Worker }
626*89c4ff92SAndroid Build Coastguard Worker
627*89c4ff92SAndroid Build Coastguard Worker } // namespace anonymous
628