1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "Version.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include "Activation.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "ArgMinMax.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "BatchMatMul.hpp"
13*89c4ff92SAndroid Build Coastguard Worker #include "BatchSpace.hpp"
14*89c4ff92SAndroid Build Coastguard Worker #include "Comparison.hpp"
15*89c4ff92SAndroid Build Coastguard Worker #include "Convolution.hpp"
16*89c4ff92SAndroid Build Coastguard Worker #include "Control.hpp"
17*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseBinary.hpp"
18*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseUnary.hpp"
19*89c4ff92SAndroid Build Coastguard Worker #include "Fill.hpp"
20*89c4ff92SAndroid Build Coastguard Worker #include "FullyConnected.hpp"
21*89c4ff92SAndroid Build Coastguard Worker #include "Gather.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #include "GatherNd.hpp"
23*89c4ff92SAndroid Build Coastguard Worker #include "LogicalBinary.hpp"
24*89c4ff92SAndroid Build Coastguard Worker #include "Lstm.hpp"
25*89c4ff92SAndroid Build Coastguard Worker #include "Normalization.hpp"
26*89c4ff92SAndroid Build Coastguard Worker #include "Pack.hpp"
27*89c4ff92SAndroid Build Coastguard Worker #include "Pad.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "Pooling.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "Prelu.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "Quantization.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "Redefine.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "Reduce.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "Resize.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "Round.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "Shape.hpp"
36*89c4ff92SAndroid Build Coastguard Worker #include "Slice.hpp"
37*89c4ff92SAndroid Build Coastguard Worker #include "StridedSlice.hpp"
38*89c4ff92SAndroid Build Coastguard Worker #include "Softmax.hpp"
39*89c4ff92SAndroid Build Coastguard Worker #include "SpaceDepth.hpp"
40*89c4ff92SAndroid Build Coastguard Worker #include "Split.hpp"
41*89c4ff92SAndroid Build Coastguard Worker #include "Transpose.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include "UnidirectionalSequenceLstm.hpp"
43*89c4ff92SAndroid Build Coastguard Worker #include "Unpack.hpp"
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
46*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp>
47*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
48*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/context_util.h>
49*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/schema/schema_generated.h>
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
52*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
53*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnDelegateOptionsDefault()58*89c4ff92SAndroid Build Coastguard Worker DelegateOptions TfLiteArmnnDelegateOptionsDefault()
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker DelegateOptions options(armnn::Compute::CpuRef);
61*89c4ff92SAndroid Build Coastguard Worker return options;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options)64*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker auto* armnnDelegate = new ::armnnDelegate::Delegate(options);
67*89c4ff92SAndroid Build Coastguard Worker return armnnDelegate->GetDelegate();
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnDelegateDelete(TfLiteDelegate * tfLiteDelegate)70*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate)
71*89c4ff92SAndroid Build Coastguard Worker {
72*89c4ff92SAndroid Build Coastguard Worker if (tfLiteDelegate != nullptr)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker delete static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker
DoPrepare(TfLiteContext * tfLiteContext,TfLiteDelegate * tfLiteDelegate)78*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteContext* tfLiteContext, TfLiteDelegate* tfLiteDelegate)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* supportedOperators =
81*89c4ff92SAndroid Build Coastguard Worker static_cast<::armnnDelegate::Delegate*>(tfLiteDelegate->data_)->IdentifyOperatorsToDelegate(tfLiteContext);
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker // ArmNN Delegate Registration
84*89c4ff92SAndroid Build Coastguard Worker static const TfLiteRegistration kArmnnSubgraphRegistration = {
85*89c4ff92SAndroid Build Coastguard Worker // ArmnnSubgraph Init
86*89c4ff92SAndroid Build Coastguard Worker .init = [](TfLiteContext* tfLiteContext, const char* buffer, size_t length) -> void* {
87*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(length);
88*89c4ff92SAndroid Build Coastguard Worker const TfLiteDelegateParams* parameters = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker return static_cast<void*>(ArmnnSubgraph::Create(
91*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, parameters, static_cast<::armnnDelegate::Delegate*>(parameters->delegate->data_)));
92*89c4ff92SAndroid Build Coastguard Worker },
93*89c4ff92SAndroid Build Coastguard Worker // ArmnnSubgraph Free
94*89c4ff92SAndroid Build Coastguard Worker .free = [](TfLiteContext* tfLiteContext, void* buffer) -> void {
95*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(tfLiteContext);
96*89c4ff92SAndroid Build Coastguard Worker if (buffer != nullptr)
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker delete static_cast<ArmnnSubgraph*>(buffer);
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker },
101*89c4ff92SAndroid Build Coastguard Worker // ArmnnSubgraph Prepare
102*89c4ff92SAndroid Build Coastguard Worker .prepare = [](TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -> TfLiteStatus {
103*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNode->user_data == nullptr)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Prepare(tfLiteContext);
108*89c4ff92SAndroid Build Coastguard Worker },
109*89c4ff92SAndroid Build Coastguard Worker // ArmnnSubgraph Invoke
110*89c4ff92SAndroid Build Coastguard Worker .invoke = [](TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode) -> TfLiteStatus {
111*89c4ff92SAndroid Build Coastguard Worker if (tfLiteNode->user_data == nullptr)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker return static_cast<ArmnnSubgraph*>(tfLiteNode->user_data)->Invoke(tfLiteContext, tfLiteNode);
117*89c4ff92SAndroid Build Coastguard Worker },
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker .profiling_string = nullptr,
120*89c4ff92SAndroid Build Coastguard Worker .builtin_code = kTfLiteBuiltinDelegate,
121*89c4ff92SAndroid Build Coastguard Worker .custom_name = "TfLiteArmNnDelegate",
122*89c4ff92SAndroid Build Coastguard Worker .version = 1,
123*89c4ff92SAndroid Build Coastguard Worker .registration_external = nullptr,
124*89c4ff92SAndroid Build Coastguard Worker };
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker const TfLiteStatus status =
127*89c4ff92SAndroid Build Coastguard Worker tfLiteContext->ReplaceNodeSubsetsWithDelegateKernels(
128*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, kArmnnSubgraphRegistration, supportedOperators, tfLiteDelegate);
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArrayFree(supportedOperators);
131*89c4ff92SAndroid Build Coastguard Worker return status;
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker
Delegate(armnnDelegate::DelegateOptions options)135*89c4ff92SAndroid Build Coastguard Worker Delegate::Delegate(armnnDelegate::DelegateOptions options)
136*89c4ff92SAndroid Build Coastguard Worker : m_Options(std::move(options))
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker // Configures logging for ARMNN
139*89c4ff92SAndroid Build Coastguard Worker if (m_Options.IsLoggingEnabled())
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, m_Options.GetLoggingSeverity());
142*89c4ff92SAndroid Build Coastguard Worker }
143*89c4ff92SAndroid Build Coastguard Worker // Create/Get the static ArmNN Runtime. Note that the m_Runtime will be shared by all armnn_delegate
144*89c4ff92SAndroid Build Coastguard Worker // instances so the RuntimeOptions cannot be altered for different armnn_delegate instances.
145*89c4ff92SAndroid Build Coastguard Worker m_Runtime = GetRuntime(m_Options.GetRuntimeOptions());
146*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends;
147*89c4ff92SAndroid Build Coastguard Worker if (m_Runtime)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
150*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : m_Options.GetBackends())
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend())
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
155*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker else
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker backends.push_back(backend);
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker if (backends.empty())
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker // No known backend specified
167*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException("TfLiteArmnnDelegate: No known backend specified.");
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker m_Options.SetBackends(backends);
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnDelegate: Created TfLite ArmNN delegate.");
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker
IdentifyOperatorsToDelegate(TfLiteContext * tfLiteContext)174*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* Delegate::IdentifyOperatorsToDelegate(TfLiteContext* tfLiteContext)
175*89c4ff92SAndroid Build Coastguard Worker {
176*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* executionPlan = nullptr;
177*89c4ff92SAndroid Build Coastguard Worker if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker TF_LITE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnDelegate: Unable to get graph execution plan.");
180*89c4ff92SAndroid Build Coastguard Worker return nullptr;
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker // Delegate data with null network
184*89c4ff92SAndroid Build Coastguard Worker DelegateData delegateData(m_Options.GetBackends());
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
187*89c4ff92SAndroid Build Coastguard Worker nodesToDelegate->size = 0;
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker std::set<int32_t> unsupportedOperators;
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker for (int i = 0; i < executionPlan->size; ++i)
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker const int nodeIndex = executionPlan->data[i];
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker // If TfLite nodes can be delegated to ArmNN
196*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode = nullptr;
197*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistration* tfLiteRegistration = nullptr;
198*89c4ff92SAndroid Build Coastguard Worker if (tfLiteContext->GetNodeAndRegistration(
199*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker TF_LITE_KERNEL_LOG(tfLiteContext,
202*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Unable to get node and registration for node %d.",
203*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
204*89c4ff92SAndroid Build Coastguard Worker continue;
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus visitStatus;
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker try
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker visitStatus = ArmnnSubgraph::VisitNode(
212*89c4ff92SAndroid Build Coastguard Worker delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex);
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker catch(std::exception& ex)
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "ArmNN Failed to visit node with error: " << ex.what();
217*89c4ff92SAndroid Build Coastguard Worker visitStatus = kTfLiteError;
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker if ( visitStatus != kTfLiteOk)
221*89c4ff92SAndroid Build Coastguard Worker {
222*89c4ff92SAndroid Build Coastguard Worker // node is not supported by ArmNN
223*89c4ff92SAndroid Build Coastguard Worker unsupportedOperators.insert(tfLiteRegistration->builtin_code);
224*89c4ff92SAndroid Build Coastguard Worker continue;
225*89c4ff92SAndroid Build Coastguard Worker }
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
228*89c4ff92SAndroid Build Coastguard Worker }
229*89c4ff92SAndroid Build Coastguard Worker
230*89c4ff92SAndroid Build Coastguard Worker for (std::set<int32_t>::iterator it=unsupportedOperators.begin(); it!=unsupportedOperators.end(); ++it)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker TF_LITE_KERNEL_LOG(tfLiteContext,
233*89c4ff92SAndroid Build Coastguard Worker "Operator %s [%d] is not supported by armnn_delegate.",
234*89c4ff92SAndroid Build Coastguard Worker tflite::EnumNameBuiltinOperator(tflite::BuiltinOperator(*it)),
235*89c4ff92SAndroid Build Coastguard Worker *it);
236*89c4ff92SAndroid Build Coastguard Worker }
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker if (!unsupportedOperators.empty() && m_Options.TfLiteRuntimeFallbackDisabled())
239*89c4ff92SAndroid Build Coastguard Worker {
240*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
241*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnDelegate: There are unsupported operators in the model. ";
242*89c4ff92SAndroid Build Coastguard Worker exMessage << "Not falling back to TfLite Runtime as fallback is disabled. ";
243*89c4ff92SAndroid Build Coastguard Worker exMessage << "This should only be disabled under test conditions.";
244*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
245*89c4ff92SAndroid Build Coastguard Worker }
246*89c4ff92SAndroid Build Coastguard Worker if (nodesToDelegate->size == 0)
247*89c4ff92SAndroid Build Coastguard Worker {
248*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "No operators in this model are supported by the Arm NN TfLite delegate." <<
249*89c4ff92SAndroid Build Coastguard Worker " The model will be executed entirely by TfLite runtime.";
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker
252*89c4ff92SAndroid Build Coastguard Worker std::sort(&nodesToDelegate->data[0], &nodesToDelegate->data[nodesToDelegate->size]);
253*89c4ff92SAndroid Build Coastguard Worker return nodesToDelegate;
254*89c4ff92SAndroid Build Coastguard Worker }
255*89c4ff92SAndroid Build Coastguard Worker
GetDelegate()256*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate* Delegate::GetDelegate()
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker return &m_Delegate;
259*89c4ff92SAndroid Build Coastguard Worker }
260*89c4ff92SAndroid Build Coastguard Worker
GetVersion()261*89c4ff92SAndroid Build Coastguard Worker const std::string Delegate::GetVersion()
262*89c4ff92SAndroid Build Coastguard Worker {
263*89c4ff92SAndroid Build Coastguard Worker return DELEGATE_VERSION;
264*89c4ff92SAndroid Build Coastguard Worker }
265*89c4ff92SAndroid Build Coastguard Worker
AddInputLayer(DelegateData & delegateData,TfLiteContext * tfLiteContext,const TfLiteIntArray * inputs,std::vector<armnn::BindingPointInfo> & inputBindings)266*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddInputLayer(DelegateData& delegateData,
267*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
268*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* inputs,
269*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings)
270*89c4ff92SAndroid Build Coastguard Worker {
271*89c4ff92SAndroid Build Coastguard Worker const size_t numInputs = static_cast<size_t>(inputs->size);
272*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; ++i)
273*89c4ff92SAndroid Build Coastguard Worker {
274*89c4ff92SAndroid Build Coastguard Worker const int32_t tensorId = inputs->data[i];
275*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor tensor = tfLiteContext->tensors[tensorId];
276*89c4ff92SAndroid Build Coastguard Worker // Do not create bindings for constant inputs
277*89c4ff92SAndroid Build Coastguard Worker if (tensor.allocation_type == kTfLiteMmapRo)
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker continue;
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker
282*89c4ff92SAndroid Build Coastguard Worker auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
283*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddInputLayer(bindingId);
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker auto tensorInfo = GetTensorInfoForTfLiteTensor(tensor);
286*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
287*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(tensorInfo);
288*89c4ff92SAndroid Build Coastguard Worker
289*89c4ff92SAndroid Build Coastguard Worker // Store for creating connections
290*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] = &outputSlot;
291*89c4ff92SAndroid Build Coastguard Worker
292*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(std::make_pair(bindingId, tensorInfo));
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker
295*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
296*89c4ff92SAndroid Build Coastguard Worker }
297*89c4ff92SAndroid Build Coastguard Worker
AddOutputLayer(DelegateData & delegateData,TfLiteContext * tfLiteContext,const TfLiteIntArray * outputs,std::vector<armnn::BindingPointInfo> & outputBindings)298*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddOutputLayer(DelegateData& delegateData,
299*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
300*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* outputs,
301*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings)
302*89c4ff92SAndroid Build Coastguard Worker {
303*89c4ff92SAndroid Build Coastguard Worker const size_t numOutputs = static_cast<size_t>(outputs->size);
304*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numOutputs; ++i)
305*89c4ff92SAndroid Build Coastguard Worker {
306*89c4ff92SAndroid Build Coastguard Worker const int32_t tensorId = outputs->data[i];
307*89c4ff92SAndroid Build Coastguard Worker const TfLiteTensor tensor = tfLiteContext->tensors[tensorId];
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
310*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddOutputLayer(bindingId);
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker auto tensorInfo = GetTensorInfoForTfLiteTensor(tensor);
313*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] != nullptr);
314*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)]->Connect(layer->GetInputSlot(0));
315*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(std::make_pair(bindingId, tensorInfo));
316*89c4ff92SAndroid Build Coastguard Worker }
317*89c4ff92SAndroid Build Coastguard Worker
318*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
319*89c4ff92SAndroid Build Coastguard Worker }
320*89c4ff92SAndroid Build Coastguard Worker
Create(TfLiteContext * tfLiteContext,const TfLiteDelegateParams * parameters,const Delegate * delegate)321*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteContext* tfLiteContext,
322*89c4ff92SAndroid Build Coastguard Worker const TfLiteDelegateParams* parameters,
323*89c4ff92SAndroid Build Coastguard Worker const Delegate* delegate)
324*89c4ff92SAndroid Build Coastguard Worker {
325*89c4ff92SAndroid Build Coastguard Worker const auto startTime = armnn::GetTimeNow();
326*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "ArmnnSubgraph creation";
327*89c4ff92SAndroid Build Coastguard Worker
328*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* executionPlan;
329*89c4ff92SAndroid Build Coastguard Worker if (tfLiteContext->GetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
330*89c4ff92SAndroid Build Coastguard Worker {
331*89c4ff92SAndroid Build Coastguard Worker return nullptr;
332*89c4ff92SAndroid Build Coastguard Worker }
333*89c4ff92SAndroid Build Coastguard Worker
334*89c4ff92SAndroid Build Coastguard Worker // Initialize DelegateData holds network and output slots information
335*89c4ff92SAndroid Build Coastguard Worker DelegateData delegateData(delegate->m_Options.GetBackends());
336*89c4ff92SAndroid Build Coastguard Worker
337*89c4ff92SAndroid Build Coastguard Worker // Build ArmNN Network
338*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkOptions networkOptions = delegate->m_Options.GetOptimizerOptions().GetModelOptions();
339*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId networkId;
340*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Network = armnn::INetwork::Create(networkOptions);
341*89c4ff92SAndroid Build Coastguard Worker
342*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(tfLiteContext->tensors_size, nullptr);
343*89c4ff92SAndroid Build Coastguard Worker
344*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> inputBindings;
345*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> outputBindings;
346*89c4ff92SAndroid Build Coastguard Worker
347*89c4ff92SAndroid Build Coastguard Worker // Add input layer
348*89c4ff92SAndroid Build Coastguard Worker auto status = AddInputLayer(delegateData, tfLiteContext, parameters->input_tensors, inputBindings);
349*89c4ff92SAndroid Build Coastguard Worker if (status != kTfLiteOk)
350*89c4ff92SAndroid Build Coastguard Worker {
351*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnDelegate: Unable to add Inputs to the network!");
352*89c4ff92SAndroid Build Coastguard Worker }
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker // Parse TfLite delegate nodes to ArmNN
355*89c4ff92SAndroid Build Coastguard Worker const auto parseStartTime = armnn::GetTimeNow();
356*89c4ff92SAndroid Build Coastguard Worker for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
357*89c4ff92SAndroid Build Coastguard Worker {
358*89c4ff92SAndroid Build Coastguard Worker const int nodeIndex = parameters->nodes_to_replace->data[i];
359*89c4ff92SAndroid Build Coastguard Worker
360*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode = nullptr;
361*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistration* tfLiteRegistration = nullptr;
362*89c4ff92SAndroid Build Coastguard Worker if (tfLiteContext->GetNodeAndRegistration(
363*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
364*89c4ff92SAndroid Build Coastguard Worker {
365*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(&"TfLiteArmnnDelegate: Unable to get node registration: " [ nodeIndex]);
366*89c4ff92SAndroid Build Coastguard Worker }
367*89c4ff92SAndroid Build Coastguard Worker
368*89c4ff92SAndroid Build Coastguard Worker if (VisitNode(delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
369*89c4ff92SAndroid Build Coastguard Worker {
370*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(&"TfLiteArmnnDelegate: Unable to parse node: " [ nodeIndex]);
371*89c4ff92SAndroid Build Coastguard Worker }
372*89c4ff92SAndroid Build Coastguard Worker }
373*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Parse nodes to ArmNN time: " << std::setprecision(2)
374*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(parseStartTime).count() << " ms";
375*89c4ff92SAndroid Build Coastguard Worker
376*89c4ff92SAndroid Build Coastguard Worker // Add Output layer
377*89c4ff92SAndroid Build Coastguard Worker status = AddOutputLayer(delegateData, tfLiteContext, parameters->output_tensors, outputBindings);
378*89c4ff92SAndroid Build Coastguard Worker if (status != kTfLiteOk)
379*89c4ff92SAndroid Build Coastguard Worker {
380*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnDelegate: Unable to add Outputs to the network!");
381*89c4ff92SAndroid Build Coastguard Worker }
382*89c4ff92SAndroid Build Coastguard Worker
383*89c4ff92SAndroid Build Coastguard Worker // Optimize ArmNN network
384*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
385*89c4ff92SAndroid Build Coastguard Worker try
386*89c4ff92SAndroid Build Coastguard Worker {
387*89c4ff92SAndroid Build Coastguard Worker const auto optimizeStartTime = armnn::GetTimeNow();
388*89c4ff92SAndroid Build Coastguard Worker optNet = armnn::Optimize(*(delegateData.m_Network.get()),
389*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetBackends(),
390*89c4ff92SAndroid Build Coastguard Worker delegate->m_Runtime->GetDeviceSpec(),
391*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetOptimizerOptions());
392*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Optimize ArmnnSubgraph time: " << std::setprecision(2)
393*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(optimizeStartTime).count() << " ms";
394*89c4ff92SAndroid Build Coastguard Worker }
395*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& ex)
396*89c4ff92SAndroid Build Coastguard Worker {
397*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
398*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnDelegate: Exception (" << ex.what() << ") caught from optimize.";
399*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
400*89c4ff92SAndroid Build Coastguard Worker }
401*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
402*89c4ff92SAndroid Build Coastguard Worker {
403*89c4ff92SAndroid Build Coastguard Worker // Optimize failed
404*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnDelegate: Unable to optimize the network!");
405*89c4ff92SAndroid Build Coastguard Worker }
406*89c4ff92SAndroid Build Coastguard Worker
407*89c4ff92SAndroid Build Coastguard Worker // If set, we will serialize the optimized model into a dot file.
408*89c4ff92SAndroid Build Coastguard Worker const std::string serializeToDotFile = delegate->m_Options.GetSerializeToDot();
409*89c4ff92SAndroid Build Coastguard Worker if (!serializeToDotFile.empty())
410*89c4ff92SAndroid Build Coastguard Worker {
411*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Writing graph to dot file: " << serializeToDotFile;
412*89c4ff92SAndroid Build Coastguard Worker fs::path filename = serializeToDotFile;
413*89c4ff92SAndroid Build Coastguard Worker std::fstream file(filename.c_str(), std::ios_base::out);
414*89c4ff92SAndroid Build Coastguard Worker optNet->SerializeToDot(file);
415*89c4ff92SAndroid Build Coastguard Worker }
416*89c4ff92SAndroid Build Coastguard Worker
417*89c4ff92SAndroid Build Coastguard Worker try
418*89c4ff92SAndroid Build Coastguard Worker {
419*89c4ff92SAndroid Build Coastguard Worker const auto loadStartTime = armnn::GetTimeNow();
420*89c4ff92SAndroid Build Coastguard Worker
421*89c4ff92SAndroid Build Coastguard Worker // Load graph into runtime
422*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
423*89c4ff92SAndroid Build Coastguard Worker armnn::Status loadingStatus;
424*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource inputSource = armnn::MemorySource::Undefined;
425*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource outputSource = armnn::MemorySource::Undefined;
426*89c4ff92SAndroid Build Coastguard Worker // There's a bit of an assumption here that the delegate will only support Malloc memory source.
427*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetOptimizerOptions().GetImportEnabled())
428*89c4ff92SAndroid Build Coastguard Worker {
429*89c4ff92SAndroid Build Coastguard Worker inputSource = armnn::MemorySource::Malloc;
430*89c4ff92SAndroid Build Coastguard Worker }
431*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetOptimizerOptions().GetExportEnabled())
432*89c4ff92SAndroid Build Coastguard Worker {
433*89c4ff92SAndroid Build Coastguard Worker outputSource = armnn::MemorySource::Malloc;
434*89c4ff92SAndroid Build Coastguard Worker }
435*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(false,
436*89c4ff92SAndroid Build Coastguard Worker inputSource,
437*89c4ff92SAndroid Build Coastguard Worker outputSource,
438*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetInternalProfilingState(),
439*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetInternalProfilingDetail());
440*89c4ff92SAndroid Build Coastguard Worker loadingStatus = delegate->m_Runtime->LoadNetwork(networkId,
441*89c4ff92SAndroid Build Coastguard Worker std::move(optNet),
442*89c4ff92SAndroid Build Coastguard Worker errorMessage,
443*89c4ff92SAndroid Build Coastguard Worker networkProperties);
444*89c4ff92SAndroid Build Coastguard Worker if (loadingStatus != armnn::Status::Success)
445*89c4ff92SAndroid Build Coastguard Worker {
446*89c4ff92SAndroid Build Coastguard Worker // Network load failed.
447*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnDelegate: Network could not be loaded: " + errorMessage);
448*89c4ff92SAndroid Build Coastguard Worker }
449*89c4ff92SAndroid Build Coastguard Worker
450*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Load ArmnnSubgraph time: " << std::setprecision(2)
451*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(loadStartTime).count() << " ms";
452*89c4ff92SAndroid Build Coastguard Worker }
453*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& ex)
454*89c4ff92SAndroid Build Coastguard Worker {
455*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
456*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnDelegate: Exception (" << ex.what() << ") caught from LoadNetwork.";
457*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
458*89c4ff92SAndroid Build Coastguard Worker }
459*89c4ff92SAndroid Build Coastguard Worker
460*89c4ff92SAndroid Build Coastguard Worker // Register debug callback function
461*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetDebugCallbackFunction().has_value())
462*89c4ff92SAndroid Build Coastguard Worker {
463*89c4ff92SAndroid Build Coastguard Worker delegate->m_Runtime->RegisterDebugCallback(networkId, delegate->m_Options.GetDebugCallbackFunction().value());
464*89c4ff92SAndroid Build Coastguard Worker }
465*89c4ff92SAndroid Build Coastguard Worker
466*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Overall ArmnnSubgraph creation time: " << std::setprecision(2)
467*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(startTime).count() << " ms\n";
468*89c4ff92SAndroid Build Coastguard Worker
469*89c4ff92SAndroid Build Coastguard Worker // Create a new SubGraph with networkId and runtime
470*89c4ff92SAndroid Build Coastguard Worker return new ArmnnSubgraph(networkId, delegate->m_Runtime, inputBindings, outputBindings);
471*89c4ff92SAndroid Build Coastguard Worker }
472*89c4ff92SAndroid Build Coastguard Worker
Prepare(TfLiteContext * tfLiteContext)473*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Prepare(TfLiteContext* tfLiteContext)
474*89c4ff92SAndroid Build Coastguard Worker {
475*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(tfLiteContext);
476*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
477*89c4ff92SAndroid Build Coastguard Worker }
478*89c4ff92SAndroid Build Coastguard Worker
Invoke(TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode)479*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode)
480*89c4ff92SAndroid Build Coastguard Worker {
481*89c4ff92SAndroid Build Coastguard Worker // Prepare inputs
482*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
483*89c4ff92SAndroid Build Coastguard Worker size_t inputIndex = 0;
484*89c4ff92SAndroid Build Coastguard Worker for (auto inputIdx : tflite::TfLiteIntArrayView(tfLiteNode->inputs))
485*89c4ff92SAndroid Build Coastguard Worker {
486*89c4ff92SAndroid Build Coastguard Worker TfLiteTensor* tensor = &tfLiteContext->tensors[inputIdx];
487*89c4ff92SAndroid Build Coastguard Worker if (tensor->allocation_type != kTfLiteMmapRo)
488*89c4ff92SAndroid Build Coastguard Worker {
489*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& inputBinding = m_InputBindings[inputIndex];
490*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = inputBinding.second;
491*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
492*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor inputTensor(inputTensorInfo, tensor->data.data);
493*89c4ff92SAndroid Build Coastguard Worker inputTensors.emplace_back(inputIdx, inputTensor);
494*89c4ff92SAndroid Build Coastguard Worker
495*89c4ff92SAndroid Build Coastguard Worker ++inputIndex;
496*89c4ff92SAndroid Build Coastguard Worker }
497*89c4ff92SAndroid Build Coastguard Worker }
498*89c4ff92SAndroid Build Coastguard Worker
499*89c4ff92SAndroid Build Coastguard Worker // Prepare outputs
500*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
501*89c4ff92SAndroid Build Coastguard Worker size_t outputIndex = 0;
502*89c4ff92SAndroid Build Coastguard Worker for (auto outputIdx : tflite::TfLiteIntArrayView(tfLiteNode->outputs))
503*89c4ff92SAndroid Build Coastguard Worker {
504*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& outputBinding = m_OutputBindings[outputIndex];
505*89c4ff92SAndroid Build Coastguard Worker TfLiteTensor* tensor = &tfLiteContext->tensors[outputIdx];
506*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor outputTensor(outputBinding.second, tensor->data.data);
507*89c4ff92SAndroid Build Coastguard Worker outputTensors.emplace_back(outputIdx, outputTensor);
508*89c4ff92SAndroid Build Coastguard Worker
509*89c4ff92SAndroid Build Coastguard Worker ++outputIndex;
510*89c4ff92SAndroid Build Coastguard Worker }
511*89c4ff92SAndroid Build Coastguard Worker
512*89c4ff92SAndroid Build Coastguard Worker // Run graph
513*89c4ff92SAndroid Build Coastguard Worker auto status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
514*89c4ff92SAndroid Build Coastguard Worker // The delegate holds its own Arm NN runtime so this is our last chance to print internal profiling data.
515*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
516*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled())
517*89c4ff92SAndroid Build Coastguard Worker {
518*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout);
519*89c4ff92SAndroid Build Coastguard Worker }
520*89c4ff92SAndroid Build Coastguard Worker return (status == armnn::Status::Success) ? kTfLiteOk : kTfLiteError;
521*89c4ff92SAndroid Build Coastguard Worker }
522*89c4ff92SAndroid Build Coastguard Worker
VisitNode(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteRegistration * tfLiteRegistration,TfLiteNode * tfLiteNode,int nodeIndex)523*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
524*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext,
525*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistration* tfLiteRegistration,
526*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode,
527*89c4ff92SAndroid Build Coastguard Worker int nodeIndex)
528*89c4ff92SAndroid Build Coastguard Worker {
529*89c4ff92SAndroid Build Coastguard Worker switch (tfLiteRegistration->builtin_code)
530*89c4ff92SAndroid Build Coastguard Worker {
531*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinCustom:
532*89c4ff92SAndroid Build Coastguard Worker {
533*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_5)
534*89c4ff92SAndroid Build Coastguard Worker // Custom operators are defined by the name rather than the builtin code.
535*89c4ff92SAndroid Build Coastguard Worker // Parse the custom_name param in the registration to point to the correct visitor function.
536*89c4ff92SAndroid Build Coastguard Worker std::string customOperatorName = tfLiteRegistration->custom_name;
537*89c4ff92SAndroid Build Coastguard Worker if ( customOperatorName == "AveragePool3D" )
538*89c4ff92SAndroid Build Coastguard Worker {
539*89c4ff92SAndroid Build Coastguard Worker return VisitPooling3dOperator(delegateData,
540*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
541*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
542*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
543*89c4ff92SAndroid Build Coastguard Worker customOperatorName);
544*89c4ff92SAndroid Build Coastguard Worker }
545*89c4ff92SAndroid Build Coastguard Worker else if (customOperatorName == "MaxPool3D")
546*89c4ff92SAndroid Build Coastguard Worker {
547*89c4ff92SAndroid Build Coastguard Worker return VisitPooling3dOperator(delegateData,
548*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
549*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
550*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
551*89c4ff92SAndroid Build Coastguard Worker customOperatorName);
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker #endif
554*89c4ff92SAndroid Build Coastguard Worker // Invalid or unsupported custom operator
555*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
556*89c4ff92SAndroid Build Coastguard Worker }
557*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAbs:
558*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
559*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
560*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
561*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
562*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Abs);
563*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAdd:
564*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
565*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
566*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
567*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
568*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinAdd);
569*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinArgMax:
570*89c4ff92SAndroid Build Coastguard Worker return VisitArgMinMaxOperator(delegateData,
571*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
572*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
573*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
574*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinArgMax);
575*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinArgMin:
576*89c4ff92SAndroid Build Coastguard Worker return VisitArgMinMaxOperator(delegateData,
577*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
578*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
579*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
580*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinArgMin);
581*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinAveragePool2d:
582*89c4ff92SAndroid Build Coastguard Worker return VisitPooling2dOperator(delegateData,
583*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
584*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
585*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
586*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinAveragePool2d);
587*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinBatchMatmul:
588*89c4ff92SAndroid Build Coastguard Worker return VisitBatchMatMulOperator(delegateData,
589*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
590*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
591*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
592*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinBatchMatmul);
593*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinBatchToSpaceNd:
594*89c4ff92SAndroid Build Coastguard Worker return VisitBatchToSpaceNdOperator(delegateData,
595*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
596*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
597*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
598*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinBatchToSpaceNd);
599*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinCast:
600*89c4ff92SAndroid Build Coastguard Worker return VisitCastOperator(delegateData,
601*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
602*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
603*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
604*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinCast);
605*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinCeil:
606*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
607*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
608*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
609*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
610*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Ceil);
611*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConcatenation:
612*89c4ff92SAndroid Build Coastguard Worker return VisitControlOperator(delegateData,
613*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
614*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
615*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
616*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinConcatenation);
617*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConv2d:
618*89c4ff92SAndroid Build Coastguard Worker return VisitConvolutionOperator(delegateData,
619*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
620*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
621*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
622*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinConv2d);
623*89c4ff92SAndroid Build Coastguard Worker // Conv3d is only correctly supported for external delegates from TF Lite v2.6, as there was a breaking bug in v2.5.
624*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_5)
625*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinConv3d:
626*89c4ff92SAndroid Build Coastguard Worker return VisitConvolutionOperator(delegateData,
627*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
628*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
629*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
630*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinConv3d);
631*89c4ff92SAndroid Build Coastguard Worker #endif
632*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDepthToSpace:
633*89c4ff92SAndroid Build Coastguard Worker return VisitDepthToSpaceOperator(delegateData,
634*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
635*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
636*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
637*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinDepthToSpace);
638*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDepthwiseConv2d:
639*89c4ff92SAndroid Build Coastguard Worker return VisitConvolutionOperator(delegateData,
640*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
641*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
642*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
643*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinDepthwiseConv2d);
644*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDequantize:
645*89c4ff92SAndroid Build Coastguard Worker return VisitDequantizeOperator(delegateData,
646*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
647*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
648*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
649*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinDequantize);
650*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinDiv:
651*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
652*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
653*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
654*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
655*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinDiv);
656*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinElu:
657*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
658*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
659*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
660*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
661*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinElu);
662*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinEqual:
663*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
664*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
665*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
666*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
667*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinEqual);
668*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinExp:
669*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
670*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
671*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
672*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
673*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Exp);
674*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinExpandDims:
675*89c4ff92SAndroid Build Coastguard Worker return VisitExpandDimsOperator(delegateData,
676*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
677*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
678*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
679*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinExpandDims);
680*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFill:
681*89c4ff92SAndroid Build Coastguard Worker return VisitFillOperator(delegateData,
682*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
683*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
684*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
685*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinFill);
686*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFloor:
687*89c4ff92SAndroid Build Coastguard Worker return VisitFloorOperator(delegateData,
688*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
689*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
690*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
691*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinFloor);
692*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFloorDiv:
693*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
694*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
695*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
696*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
697*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinFloorDiv);
698*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinFullyConnected:
699*89c4ff92SAndroid Build Coastguard Worker return VisitFullyConnectedOperator(delegateData,
700*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
701*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
702*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
703*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinFullyConnected);
704*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinGather:
705*89c4ff92SAndroid Build Coastguard Worker return VisitGatherOperator(delegateData,
706*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
707*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
708*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
709*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinGather);
710*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinGatherNd:
711*89c4ff92SAndroid Build Coastguard Worker return VisitGatherNdOperator(delegateData,
712*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
713*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
714*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
715*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinGatherNd);
716*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinGreater:
717*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
718*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
719*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
720*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
721*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinGreater);
722*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinGreaterEqual:
723*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
724*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
725*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
726*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
727*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinGreaterEqual);
728*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinHardSwish:
729*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
730*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
731*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
732*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
733*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinHardSwish);
734*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinL2Normalization:
735*89c4ff92SAndroid Build Coastguard Worker return VisitL2NormalizationOperator(delegateData,
736*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
737*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
738*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
739*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinL2Normalization);
740*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinL2Pool2d:
741*89c4ff92SAndroid Build Coastguard Worker return VisitPooling2dOperator(delegateData,
742*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
743*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
744*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
745*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinL2Pool2d);
746*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLess:
747*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
748*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
749*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
750*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
751*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLess);
752*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLessEqual:
753*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
754*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
755*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
756*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
757*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLessEqual);
758*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLocalResponseNormalization:
759*89c4ff92SAndroid Build Coastguard Worker return VisitLocalResponseNormalizationOperator(delegateData,
760*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
761*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
762*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
763*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLocalResponseNormalization);
764*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLog:
765*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
766*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
767*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
768*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
769*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Log);
770*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLogicalAnd:
771*89c4ff92SAndroid Build Coastguard Worker return VisitLogicalBinaryOperator(delegateData,
772*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
773*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
774*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
775*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLogicalAnd,
776*89c4ff92SAndroid Build Coastguard Worker armnn::LogicalBinaryOperation::LogicalAnd);
777*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLogicalNot:
778*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
779*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
780*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
781*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
782*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::LogicalNot);
783*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLogicalOr:
784*89c4ff92SAndroid Build Coastguard Worker return VisitLogicalBinaryOperator(delegateData,
785*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
786*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
787*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
788*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLogicalOr,
789*89c4ff92SAndroid Build Coastguard Worker armnn::LogicalBinaryOperation::LogicalOr);
790*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLogistic:
791*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
792*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
793*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
794*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
795*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLogistic);
796*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLogSoftmax:
797*89c4ff92SAndroid Build Coastguard Worker return VisitSoftmaxOperator(delegateData,
798*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
799*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
800*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
801*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLogSoftmax);
802*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinLstm:
803*89c4ff92SAndroid Build Coastguard Worker return VisitLstmOperator(delegateData,
804*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
805*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
806*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
807*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinLstm);
808*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMaxPool2d:
809*89c4ff92SAndroid Build Coastguard Worker return VisitPooling2dOperator(delegateData,
810*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
811*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
812*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
813*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMaxPool2d);
814*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMaximum:
815*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
816*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
817*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
818*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
819*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMaximum);
820*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMean:
821*89c4ff92SAndroid Build Coastguard Worker return VisitControlOperator(delegateData,
822*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
823*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
824*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
825*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMean);
826*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMinimum:
827*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
828*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
829*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
830*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
831*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMinimum);
832*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMirrorPad:
833*89c4ff92SAndroid Build Coastguard Worker return VisitPadOperator(delegateData,
834*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
835*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
836*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
837*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMirrorPad);
838*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinMul:
839*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
840*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
841*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
842*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
843*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinMul);
844*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinNeg:
845*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
846*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
847*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
848*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
849*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Neg);
850*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinNotEqual:
851*89c4ff92SAndroid Build Coastguard Worker return VisitComparisonOperator(delegateData,
852*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
853*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
854*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
855*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinNotEqual);
856*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinPack:
857*89c4ff92SAndroid Build Coastguard Worker return VisitPackOperator(delegateData,
858*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
859*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
860*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
861*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinPack);
862*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinPad:
863*89c4ff92SAndroid Build Coastguard Worker return VisitPadOperator(delegateData,
864*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
865*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
866*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
867*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinPad);
868*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinPadv2:
869*89c4ff92SAndroid Build Coastguard Worker return VisitPadOperator(delegateData,
870*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
871*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
872*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
873*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinPadv2);
874*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinPrelu:
875*89c4ff92SAndroid Build Coastguard Worker return VisitPreluOperator(delegateData,
876*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
877*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
878*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
879*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinPrelu);
880*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinQuantize:
881*89c4ff92SAndroid Build Coastguard Worker return VisitQuantizeOperator(delegateData,
882*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
883*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
884*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
885*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinQuantize);
886*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinRank:
887*89c4ff92SAndroid Build Coastguard Worker return VisitControlOperator(delegateData,
888*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
889*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
890*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
891*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinRank);
892*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinReduceMax:
893*89c4ff92SAndroid Build Coastguard Worker return VisitReduceOperator(delegateData,
894*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
895*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
896*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
897*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinReduceMax);
898*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinReduceMin:
899*89c4ff92SAndroid Build Coastguard Worker return VisitReduceOperator(delegateData,
900*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
901*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
902*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
903*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinReduceMin);
904*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinReduceProd:
905*89c4ff92SAndroid Build Coastguard Worker return VisitReduceOperator(delegateData,
906*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
907*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
908*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
909*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinReduceProd);
910*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinRelu:
911*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
912*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
913*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
914*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
915*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinRelu);
916*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinReluN1To1:
917*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
918*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
919*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
920*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
921*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinReluN1To1);
922*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinRelu6:
923*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
924*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
925*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
926*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
927*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinRelu6);
928*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinReshape:
929*89c4ff92SAndroid Build Coastguard Worker return VisitReshapeOperator(delegateData,
930*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
931*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
932*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
933*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinReshape);
934*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinResizeBilinear:
935*89c4ff92SAndroid Build Coastguard Worker return VisitResizeOperator(delegateData,
936*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
937*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
938*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
939*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinResizeBilinear);
940*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinResizeNearestNeighbor:
941*89c4ff92SAndroid Build Coastguard Worker return VisitResizeOperator(delegateData,
942*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
943*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
944*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
945*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinResizeNearestNeighbor);
946*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinRsqrt:
947*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
948*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
949*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
950*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
951*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Rsqrt);
952*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinShape:
953*89c4ff92SAndroid Build Coastguard Worker return VisitShapeOperator(delegateData,
954*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
955*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
956*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
957*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinShape);
958*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSin:
959*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
960*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
961*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
962*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
963*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Sin);
964*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSplit:
965*89c4ff92SAndroid Build Coastguard Worker return VisitSplitOperator(delegateData,
966*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
967*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
968*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
969*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSplit);
970*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSplitV:
971*89c4ff92SAndroid Build Coastguard Worker return VisitSplitVOperator(delegateData,
972*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
973*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
974*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
975*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSplitV);
976*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSqrt:
977*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseUnaryOperator(delegateData,
978*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
979*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
980*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
981*89c4ff92SAndroid Build Coastguard Worker armnn::UnaryOperation::Sqrt);
982*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSqueeze:
983*89c4ff92SAndroid Build Coastguard Worker return VisitSqueezeOperator(delegateData,
984*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
985*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
986*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
987*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSqueeze);
988*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSlice:
989*89c4ff92SAndroid Build Coastguard Worker return VisitSliceOperator(delegateData,
990*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
991*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
992*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
993*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSlice);
994*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinStridedSlice:
995*89c4ff92SAndroid Build Coastguard Worker return VisitStridedSliceOperator(delegateData,
996*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
997*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
998*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
999*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinStridedSlice);
1000*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSum:
1001*89c4ff92SAndroid Build Coastguard Worker return VisitReduceOperator(delegateData,
1002*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1003*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1004*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1005*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSum);
1006*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinTranspose:
1007*89c4ff92SAndroid Build Coastguard Worker return VisitTransposeOperator(delegateData,
1008*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1009*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1010*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1011*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinTranspose);
1012*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinTransposeConv:
1013*89c4ff92SAndroid Build Coastguard Worker return VisitConvolutionOperator(delegateData,
1014*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1015*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1016*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1017*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinTransposeConv);
1018*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSoftmax:
1019*89c4ff92SAndroid Build Coastguard Worker return VisitSoftmaxOperator(delegateData,
1020*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1021*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1022*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1023*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSoftmax);
1024*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSpaceToBatchNd:
1025*89c4ff92SAndroid Build Coastguard Worker return VisitSpaceToBatchNdOperator(delegateData,
1026*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1027*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1028*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1029*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSpaceToBatchNd);
1030*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSpaceToDepth:
1031*89c4ff92SAndroid Build Coastguard Worker return VisitSpaceToDepthOperator(delegateData,
1032*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1033*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1034*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1035*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSpaceToDepth);
1036*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinSub:
1037*89c4ff92SAndroid Build Coastguard Worker return VisitElementwiseBinaryOperator(delegateData,
1038*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1039*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1040*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1041*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinSub);
1042*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinTanh:
1043*89c4ff92SAndroid Build Coastguard Worker return VisitActivationOperator(delegateData,
1044*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1045*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1046*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1047*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinTanh);
1048*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinUnidirectionalSequenceLstm:
1049*89c4ff92SAndroid Build Coastguard Worker return VisitUnidirectionalSequenceLstmOperator(delegateData,
1050*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1051*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1052*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1053*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinUnidirectionalSequenceLstm);
1054*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinUnpack:
1055*89c4ff92SAndroid Build Coastguard Worker return VisitUnpackOperator(delegateData,
1056*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
1057*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
1058*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
1059*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinUnpack);
1060*89c4ff92SAndroid Build Coastguard Worker default:
1061*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
1062*89c4ff92SAndroid Build Coastguard Worker }
1063*89c4ff92SAndroid Build Coastguard Worker }
1064*89c4ff92SAndroid Build Coastguard Worker
1065*89c4ff92SAndroid Build Coastguard Worker } // armnnDelegate namespace