1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <OpaqueDelegateUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <Version.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include "Activation.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ArgMinMax.hpp"
13*89c4ff92SAndroid Build Coastguard Worker #include "BatchMatMul.hpp"
14*89c4ff92SAndroid Build Coastguard Worker #include "BatchSpace.hpp"
15*89c4ff92SAndroid Build Coastguard Worker #include "Comparison.hpp"
16*89c4ff92SAndroid Build Coastguard Worker #include "Convolution.hpp"
17*89c4ff92SAndroid Build Coastguard Worker #include "Control.hpp"
18*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseBinary.hpp"
19*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseUnary.hpp"
20*89c4ff92SAndroid Build Coastguard Worker #include "Fill.hpp"
21*89c4ff92SAndroid Build Coastguard Worker #include "FullyConnected.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #include "Gather.hpp"
23*89c4ff92SAndroid Build Coastguard Worker #include "GatherNd.hpp"
24*89c4ff92SAndroid Build Coastguard Worker #include "LogicalBinary.hpp"
25*89c4ff92SAndroid Build Coastguard Worker #include "Lstm.hpp"
26*89c4ff92SAndroid Build Coastguard Worker #include "Normalization.hpp"
27*89c4ff92SAndroid Build Coastguard Worker #include "Pack.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "Pad.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "Pooling.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "Prelu.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "Quantization.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "Redefine.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "Reduce.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "Resize.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "Round.hpp"
36*89c4ff92SAndroid Build Coastguard Worker #include "Shape.hpp"
37*89c4ff92SAndroid Build Coastguard Worker #include "Slice.hpp"
38*89c4ff92SAndroid Build Coastguard Worker #include "StridedSlice.hpp"
39*89c4ff92SAndroid Build Coastguard Worker #include "Softmax.hpp"
40*89c4ff92SAndroid Build Coastguard Worker #include "SpaceDepth.hpp"
41*89c4ff92SAndroid Build Coastguard Worker #include "Split.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include "Transpose.hpp"
43*89c4ff92SAndroid Build Coastguard Worker #include "UnidirectionalSequenceLstm.hpp"
44*89c4ff92SAndroid Build Coastguard Worker #include "Unpack.hpp"
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
47*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
48*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp>
49*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
50*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/context_util.h>
51*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/schema/schema_generated.h>
52*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
53*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/logger.h>
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
56*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
57*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker namespace armnnOpaqueDelegate
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker const TfLiteStableDelegate TFL_TheStableDelegate =
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker /*delegate_abi_version=*/ TFL_STABLE_DELEGATE_ABI_VERSION,
65*89c4ff92SAndroid Build Coastguard Worker /*delegate_name=*/ "ArmnnDelegatePlugin",
66*89c4ff92SAndroid Build Coastguard Worker /*delegate_version=*/ "1.0.0",
67*89c4ff92SAndroid Build Coastguard Worker /*delegate_plugin=*/ GetArmnnDelegatePluginApi()
68*89c4ff92SAndroid Build Coastguard Worker };
69*89c4ff92SAndroid Build Coastguard Worker
ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options)70*89c4ff92SAndroid Build Coastguard Worker ArmnnOpaqueDelegate::ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options)
71*89c4ff92SAndroid Build Coastguard Worker : m_Options(std::move(options))
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker // Configures logging for ARMNN
74*89c4ff92SAndroid Build Coastguard Worker if (m_Options.IsLoggingEnabled())
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker armnn::ConfigureLogging(true, true, m_Options.GetLoggingSeverity());
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker // Create/Get the static ArmNN Runtime. Note that the m_Runtime will be shared by all armnn_delegate
79*89c4ff92SAndroid Build Coastguard Worker // instances so the RuntimeOptions cannot be altered for different armnn_delegate instances.
80*89c4ff92SAndroid Build Coastguard Worker m_Runtime = GetRuntime(m_Options.GetRuntimeOptions());
81*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends;
82*89c4ff92SAndroid Build Coastguard Worker if (m_Runtime)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
85*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : m_Options.GetBackends())
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend())
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
90*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
91*89c4ff92SAndroid Build Coastguard Worker }
92*89c4ff92SAndroid Build Coastguard Worker else
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker backends.push_back(backend);
95*89c4ff92SAndroid Build Coastguard Worker }
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker if (backends.empty())
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker // No known backend specified
102*89c4ff92SAndroid Build Coastguard Worker throw armnn::InvalidArgumentException("TfLiteArmnnOpaqueDelegate: No known backend specified.");
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker m_Options.SetBackends(backends);
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnOpaqueDelegate: Created TfLite ArmNN delegate.");
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker
DoPrepare(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueDelegate * tfLiteDelegate,void * data)109*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueDelegate* tfLiteDelegate, void* data)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker // We are required to have the void* data parameter in the function signature, but we don't actually use it.
112*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(data);
113*89c4ff92SAndroid Build Coastguard Worker
114*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* supportedOperators =
115*89c4ff92SAndroid Build Coastguard Worker static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>
116*89c4ff92SAndroid Build Coastguard Worker (TfLiteOpaqueDelegateGetData(tfLiteDelegate))->IdentifyOperatorsToDelegate(tfLiteContext);
117*89c4ff92SAndroid Build Coastguard Worker if(supportedOperators == nullptr)
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker // ArmNN Opaque Delegate Registration
123*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternal* kernelRegistration =
124*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternalCreate(kTfLiteBuiltinDelegate, "TfLiteArmNNOpaqueDelegate", /*version=*/1);
125*89c4ff92SAndroid Build Coastguard Worker if(kernelRegistration == nullptr)
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternalSetInit(
131*89c4ff92SAndroid Build Coastguard Worker kernelRegistration,
132*89c4ff92SAndroid Build Coastguard Worker [](TfLiteOpaqueContext* tfLiteContext, const char* buffer, size_t length) -> void*
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(length);
135*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegateParams* parameters =
136*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<const TfLiteOpaqueDelegateParams*>(buffer);
137*89c4ff92SAndroid Build Coastguard Worker if(parameters == nullptr)
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
140*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnOpaqueDelegate: Unable to get parameters.");
141*89c4ff92SAndroid Build Coastguard Worker return nullptr;
142*89c4ff92SAndroid Build Coastguard Worker }
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker return static_cast<void*>(
145*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph::Create(tfLiteContext,
146*89c4ff92SAndroid Build Coastguard Worker parameters,
147*89c4ff92SAndroid Build Coastguard Worker static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>(
148*89c4ff92SAndroid Build Coastguard Worker parameters->delegate->opaque_delegate_builder->data)));
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker );
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternalSetFree(
153*89c4ff92SAndroid Build Coastguard Worker kernelRegistration,
154*89c4ff92SAndroid Build Coastguard Worker [](TfLiteOpaqueContext* tfLiteContext, void* buffer) -> void
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(tfLiteContext);
157*89c4ff92SAndroid Build Coastguard Worker if (buffer != nullptr)
158*89c4ff92SAndroid Build Coastguard Worker {
159*89c4ff92SAndroid Build Coastguard Worker delete static_cast<ArmnnSubgraph*>(buffer);
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker );
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternalSetPrepare(
165*89c4ff92SAndroid Build Coastguard Worker kernelRegistration,
166*89c4ff92SAndroid Build Coastguard Worker [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
169*89c4ff92SAndroid Build Coastguard Worker if (userData == nullptr)
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker return static_cast<ArmnnSubgraph*>(userData)->Prepare(tfLiteContext);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker );
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternalSetInvoke(
178*89c4ff92SAndroid Build Coastguard Worker kernelRegistration,
179*89c4ff92SAndroid Build Coastguard Worker [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
180*89c4ff92SAndroid Build Coastguard Worker {
181*89c4ff92SAndroid Build Coastguard Worker void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
182*89c4ff92SAndroid Build Coastguard Worker if (userData == nullptr)
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker return static_cast<ArmnnSubgraph*>(userData)->Invoke(tfLiteContext, tfLiteNode);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker );
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker const TfLiteStatus status =
192*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels(
193*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, kernelRegistration, supportedOperators, tfLiteDelegate);
194*89c4ff92SAndroid Build Coastguard Worker
195*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArrayFree(supportedOperators);
196*89c4ff92SAndroid Build Coastguard Worker return status;
197*89c4ff92SAndroid Build Coastguard Worker }
198*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnOpaqueDelegateCreate(const void * settings)199*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings)
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker // This method will always create Opaque Delegate with default settings until
202*89c4ff92SAndroid Build Coastguard Worker // we have a DelegateOptions Constructor which can parse the void* settings
203*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(settings);
204*89c4ff92SAndroid Build Coastguard Worker auto options = TfLiteArmnnDelegateOptionsDefault();
205*89c4ff92SAndroid Build Coastguard Worker auto* armnnDelegate = new ::armnnOpaqueDelegate::ArmnnOpaqueDelegate(options);
206*89c4ff92SAndroid Build Coastguard Worker return TfLiteOpaqueDelegateCreate(armnnDelegate->GetDelegateBuilder());
207*89c4ff92SAndroid Build Coastguard Worker }
208*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnDelegateOptionsDefault()209*89c4ff92SAndroid Build Coastguard Worker ::armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault()
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker ::armnnDelegate::DelegateOptions options(armnn::Compute::CpuRef);
212*89c4ff92SAndroid Build Coastguard Worker return options;
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate * tfLiteDelegate)215*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate* tfLiteDelegate)
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker if (tfLiteDelegate != nullptr)
218*89c4ff92SAndroid Build Coastguard Worker {
219*89c4ff92SAndroid Build Coastguard Worker delete static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>(TfLiteOpaqueDelegateGetData(tfLiteDelegate));
220*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegateDelete(tfLiteDelegate);
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
GetArmnnDelegatePluginApi()224*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegatePlugin* GetArmnnDelegatePluginApi()
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker static constexpr TfLiteOpaqueDelegatePlugin armnnPlugin{
227*89c4ff92SAndroid Build Coastguard Worker TfLiteArmnnOpaqueDelegateCreate, TfLiteArmnnOpaqueDelegateDelete, TfLiteArmnnOpaqueDelegateErrno};
228*89c4ff92SAndroid Build Coastguard Worker return &armnnPlugin;
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker
GetVersion()231*89c4ff92SAndroid Build Coastguard Worker const std::string ArmnnOpaqueDelegate::GetVersion() {
232*89c4ff92SAndroid Build Coastguard Worker return OPAQUE_DELEGATE_VERSION;
233*89c4ff92SAndroid Build Coastguard Worker }
234*89c4ff92SAndroid Build Coastguard Worker
IdentifyOperatorsToDelegate(TfLiteOpaqueContext * tfLiteContext)235*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* ArmnnOpaqueDelegate::IdentifyOperatorsToDelegate(TfLiteOpaqueContext* tfLiteContext)
236*89c4ff92SAndroid Build Coastguard Worker {
237*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* executionPlan = nullptr;
238*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
239*89c4ff92SAndroid Build Coastguard Worker {
240*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unable to get graph execution plan.");
241*89c4ff92SAndroid Build Coastguard Worker return nullptr;
242*89c4ff92SAndroid Build Coastguard Worker }
243*89c4ff92SAndroid Build Coastguard Worker
244*89c4ff92SAndroid Build Coastguard Worker // Delegate data with null network
245*89c4ff92SAndroid Build Coastguard Worker DelegateData delegateData(m_Options.GetBackends());
246*89c4ff92SAndroid Build Coastguard Worker
247*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
248*89c4ff92SAndroid Build Coastguard Worker if (nodesToDelegate == nullptr)
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
251*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnOpaqueDelegate: Unable to create int array from execution plan.");
252*89c4ff92SAndroid Build Coastguard Worker return nullptr;
253*89c4ff92SAndroid Build Coastguard Worker }
254*89c4ff92SAndroid Build Coastguard Worker nodesToDelegate->size = 0;
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker std::set<int32_t> unsupportedOperators;
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Worker for (int i = 0; i < executionPlan->size; ++i)
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker const int nodeIndex = executionPlan->data[i];
261*89c4ff92SAndroid Build Coastguard Worker
262*89c4ff92SAndroid Build Coastguard Worker // If TfLiteOpaqueNodes can be delegated to ArmNN
263*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode = nullptr;
264*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternal* tfLiteRegistration = nullptr;
265*89c4ff92SAndroid Build Coastguard Worker
266*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueContextGetNodeAndRegistration(
267*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
270*89c4ff92SAndroid Build Coastguard Worker "TfLiteArmnnOpaqueDelegate: Unable to get node and registration for node %d.",
271*89c4ff92SAndroid Build Coastguard Worker nodeIndex);
272*89c4ff92SAndroid Build Coastguard Worker continue;
273*89c4ff92SAndroid Build Coastguard Worker }
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus visitStatus;
276*89c4ff92SAndroid Build Coastguard Worker try
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker visitStatus = ArmnnSubgraph::VisitNode(
279*89c4ff92SAndroid Build Coastguard Worker delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex);
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker catch(std::exception& ex)
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "ArmNN Failed to visit node with error: " << ex.what();
284*89c4ff92SAndroid Build Coastguard Worker visitStatus = kTfLiteError;
285*89c4ff92SAndroid Build Coastguard Worker }
286*89c4ff92SAndroid Build Coastguard Worker
287*89c4ff92SAndroid Build Coastguard Worker if (visitStatus != kTfLiteOk)
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker // node is not supported by ArmNN
290*89c4ff92SAndroid Build Coastguard Worker unsupportedOperators.insert(TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration));
291*89c4ff92SAndroid Build Coastguard Worker continue;
292*89c4ff92SAndroid Build Coastguard Worker }
293*89c4ff92SAndroid Build Coastguard Worker
294*89c4ff92SAndroid Build Coastguard Worker nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
295*89c4ff92SAndroid Build Coastguard Worker }
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker for (std::set<int32_t>::iterator it=unsupportedOperators.begin(); it!=unsupportedOperators.end(); ++it)
298*89c4ff92SAndroid Build Coastguard Worker {
299*89c4ff92SAndroid Build Coastguard Worker TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
300*89c4ff92SAndroid Build Coastguard Worker "Operator %s [%d] is not supported by armnn_opaque_delegate.",
301*89c4ff92SAndroid Build Coastguard Worker tflite::EnumNameBuiltinOperator(tflite::BuiltinOperator(*it)),
302*89c4ff92SAndroid Build Coastguard Worker *it);
303*89c4ff92SAndroid Build Coastguard Worker }
304*89c4ff92SAndroid Build Coastguard Worker
305*89c4ff92SAndroid Build Coastguard Worker if (!unsupportedOperators.empty() && m_Options.TfLiteRuntimeFallbackDisabled())
306*89c4ff92SAndroid Build Coastguard Worker {
307*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
308*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnOpaqueDelegate: There are unsupported operators in the model. ";
309*89c4ff92SAndroid Build Coastguard Worker exMessage << "Not falling back to TfLite Runtime as fallback is disabled. ";
310*89c4ff92SAndroid Build Coastguard Worker exMessage << "This should only be disabled under test conditions.";
311*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
312*89c4ff92SAndroid Build Coastguard Worker }
313*89c4ff92SAndroid Build Coastguard Worker if (nodesToDelegate->size == 0)
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "No operators in this model are supported by the Arm NN TfLite delegate." <<
316*89c4ff92SAndroid Build Coastguard Worker " The model will be executed entirely by TfLite runtime.";
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker std::sort(&nodesToDelegate->data[0], &nodesToDelegate->data[nodesToDelegate->size]);
320*89c4ff92SAndroid Build Coastguard Worker return nodesToDelegate;
321*89c4ff92SAndroid Build Coastguard Worker }
322*89c4ff92SAndroid Build Coastguard Worker
AddInputLayer(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,const TfLiteIntArray * inputs,std::vector<armnn::BindingPointInfo> & inputBindings)323*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddInputLayer(DelegateData& delegateData,
324*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
325*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* inputs,
326*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings)
327*89c4ff92SAndroid Build Coastguard Worker {
328*89c4ff92SAndroid Build Coastguard Worker const size_t numInputs = static_cast<size_t>(inputs->size);
329*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; ++i)
330*89c4ff92SAndroid Build Coastguard Worker {
331*89c4ff92SAndroid Build Coastguard Worker const int32_t tensorId = inputs->data[i];
332*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, tensorId);
333*89c4ff92SAndroid Build Coastguard Worker
334*89c4ff92SAndroid Build Coastguard Worker if(!tensor)
335*89c4ff92SAndroid Build Coastguard Worker {
336*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
337*89c4ff92SAndroid Build Coastguard Worker }
338*89c4ff92SAndroid Build Coastguard Worker
339*89c4ff92SAndroid Build Coastguard Worker // Do not create bindings for constant inputs
340*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueTensorGetAllocationType(tensor) == kTfLiteMmapRo)
341*89c4ff92SAndroid Build Coastguard Worker {
342*89c4ff92SAndroid Build Coastguard Worker continue;
343*89c4ff92SAndroid Build Coastguard Worker }
344*89c4ff92SAndroid Build Coastguard Worker
345*89c4ff92SAndroid Build Coastguard Worker auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
346*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddInputLayer(bindingId);
347*89c4ff92SAndroid Build Coastguard Worker
348*89c4ff92SAndroid Build Coastguard Worker auto tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tensor);
349*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
350*89c4ff92SAndroid Build Coastguard Worker outputSlot.SetTensorInfo(tensorInfo);
351*89c4ff92SAndroid Build Coastguard Worker
352*89c4ff92SAndroid Build Coastguard Worker // Store for creating connections
353*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] = &outputSlot;
354*89c4ff92SAndroid Build Coastguard Worker
355*89c4ff92SAndroid Build Coastguard Worker inputBindings.push_back(std::make_pair(bindingId, tensorInfo));
356*89c4ff92SAndroid Build Coastguard Worker }
357*89c4ff92SAndroid Build Coastguard Worker
358*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
359*89c4ff92SAndroid Build Coastguard Worker }
360*89c4ff92SAndroid Build Coastguard Worker
AddOutputLayer(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,const TfLiteIntArray * outputs,std::vector<armnn::BindingPointInfo> & outputBindings)361*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddOutputLayer(DelegateData& delegateData,
362*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
363*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* outputs,
364*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings)
365*89c4ff92SAndroid Build Coastguard Worker {
366*89c4ff92SAndroid Build Coastguard Worker const size_t numOutputs = static_cast<size_t>(outputs->size);
367*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numOutputs; ++i)
368*89c4ff92SAndroid Build Coastguard Worker {
369*89c4ff92SAndroid Build Coastguard Worker const int32_t tensorId = outputs->data[i];
370*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, tensorId);
371*89c4ff92SAndroid Build Coastguard Worker
372*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(tensor))
373*89c4ff92SAndroid Build Coastguard Worker {
374*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
375*89c4ff92SAndroid Build Coastguard Worker }
376*89c4ff92SAndroid Build Coastguard Worker
377*89c4ff92SAndroid Build Coastguard Worker auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
378*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* layer = delegateData.m_Network->AddOutputLayer(bindingId);
379*89c4ff92SAndroid Build Coastguard Worker
380*89c4ff92SAndroid Build Coastguard Worker auto tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tensor);
381*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] != nullptr);
382*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)]->Connect(layer->GetInputSlot(0));
383*89c4ff92SAndroid Build Coastguard Worker outputBindings.push_back(std::make_pair(bindingId, tensorInfo));
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker
386*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
387*89c4ff92SAndroid Build Coastguard Worker }
388*89c4ff92SAndroid Build Coastguard Worker
Create(TfLiteOpaqueContext * tfLiteContext,const TfLiteOpaqueDelegateParams * parameters,const ArmnnOpaqueDelegate * delegate)389*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteOpaqueContext* tfLiteContext,
390*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegateParams* parameters,
391*89c4ff92SAndroid Build Coastguard Worker const ArmnnOpaqueDelegate* delegate)
392*89c4ff92SAndroid Build Coastguard Worker {
393*89c4ff92SAndroid Build Coastguard Worker const auto startTime = armnn::GetTimeNow();
394*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "ArmnnSubgraph creation";
395*89c4ff92SAndroid Build Coastguard Worker
396*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* executionPlan;
397*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
398*89c4ff92SAndroid Build Coastguard Worker {
399*89c4ff92SAndroid Build Coastguard Worker return nullptr;
400*89c4ff92SAndroid Build Coastguard Worker }
401*89c4ff92SAndroid Build Coastguard Worker
402*89c4ff92SAndroid Build Coastguard Worker // Initialize DelegateData holds network and output slots information
403*89c4ff92SAndroid Build Coastguard Worker DelegateData delegateData(delegate->m_Options.GetBackends());
404*89c4ff92SAndroid Build Coastguard Worker
405*89c4ff92SAndroid Build Coastguard Worker // Build ArmNN Network
406*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkOptions networkOptions = delegate->m_Options.GetOptimizerOptions().GetModelOptions();
407*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId networkId;
408*89c4ff92SAndroid Build Coastguard Worker delegateData.m_Network = armnn::INetwork::Create(networkOptions);
409*89c4ff92SAndroid Build Coastguard Worker
410*89c4ff92SAndroid Build Coastguard Worker delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(
411*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContextGetNumTensors(tfLiteContext), nullptr);
412*89c4ff92SAndroid Build Coastguard Worker
413*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> inputBindings;
414*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> outputBindings;
415*89c4ff92SAndroid Build Coastguard Worker
416*89c4ff92SAndroid Build Coastguard Worker // Add input layer
417*89c4ff92SAndroid Build Coastguard Worker if (AddInputLayer(delegateData, tfLiteContext, parameters->input_tensors, inputBindings) != kTfLiteOk)
418*89c4ff92SAndroid Build Coastguard Worker {
419*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to add Inputs to the network!");
420*89c4ff92SAndroid Build Coastguard Worker }
421*89c4ff92SAndroid Build Coastguard Worker
422*89c4ff92SAndroid Build Coastguard Worker // Parse TfLite delegate nodes to ArmNN
423*89c4ff92SAndroid Build Coastguard Worker const auto parseStartTime = armnn::GetTimeNow();
424*89c4ff92SAndroid Build Coastguard Worker for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
425*89c4ff92SAndroid Build Coastguard Worker {
426*89c4ff92SAndroid Build Coastguard Worker const int nodeIndex = parameters->nodes_to_replace->data[i];
427*89c4ff92SAndroid Build Coastguard Worker
428*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode = nullptr;
429*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternal* tfLiteRegistration = nullptr;
430*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueContextGetNodeAndRegistration(
431*89c4ff92SAndroid Build Coastguard Worker tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
432*89c4ff92SAndroid Build Coastguard Worker {
433*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(&"TfLiteArmnnOpaqueDelegate: Unable to get node registration: " [ nodeIndex]);
434*89c4ff92SAndroid Build Coastguard Worker }
435*89c4ff92SAndroid Build Coastguard Worker
436*89c4ff92SAndroid Build Coastguard Worker if (VisitNode(delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
437*89c4ff92SAndroid Build Coastguard Worker {
438*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(&"TfLiteArmnnOpaqueDelegate: Unable to parse node: " [ nodeIndex]);
439*89c4ff92SAndroid Build Coastguard Worker }
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Parse nodes to ArmNN time: " << std::setprecision(2)
442*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(parseStartTime).count() << " ms";
443*89c4ff92SAndroid Build Coastguard Worker
444*89c4ff92SAndroid Build Coastguard Worker // Add Output layer
445*89c4ff92SAndroid Build Coastguard Worker if (AddOutputLayer(delegateData, tfLiteContext, parameters->output_tensors, outputBindings) != kTfLiteOk)
446*89c4ff92SAndroid Build Coastguard Worker {
447*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to add Outputs to the network!");
448*89c4ff92SAndroid Build Coastguard Worker }
449*89c4ff92SAndroid Build Coastguard Worker
450*89c4ff92SAndroid Build Coastguard Worker // Optimize ArmNN network
451*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
452*89c4ff92SAndroid Build Coastguard Worker try
453*89c4ff92SAndroid Build Coastguard Worker {
454*89c4ff92SAndroid Build Coastguard Worker const auto optimizeStartTime = armnn::GetTimeNow();
455*89c4ff92SAndroid Build Coastguard Worker optNet = armnn::Optimize(*(delegateData.m_Network.get()),
456*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetBackends(),
457*89c4ff92SAndroid Build Coastguard Worker delegate->m_Runtime->GetDeviceSpec(),
458*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetOptimizerOptions());
459*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Optimize ArmnnSubgraph time: " << std::setprecision(2)
460*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(optimizeStartTime).count() << " ms";
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& ex)
463*89c4ff92SAndroid Build Coastguard Worker {
464*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
465*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnOpaqueDelegate: Exception (" << ex.what() << ") caught from optimize.";
466*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
469*89c4ff92SAndroid Build Coastguard Worker {
470*89c4ff92SAndroid Build Coastguard Worker // Optimize failed
471*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to optimize the network!");
472*89c4ff92SAndroid Build Coastguard Worker }
473*89c4ff92SAndroid Build Coastguard Worker
474*89c4ff92SAndroid Build Coastguard Worker // If set, we will serialize the optimized model into a dot file.
475*89c4ff92SAndroid Build Coastguard Worker const std::string serializeToDotFile = delegate->m_Options.GetSerializeToDot();
476*89c4ff92SAndroid Build Coastguard Worker if (!serializeToDotFile.empty())
477*89c4ff92SAndroid Build Coastguard Worker {
478*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Writing graph to dot file: " << serializeToDotFile;
479*89c4ff92SAndroid Build Coastguard Worker fs::path filename = serializeToDotFile;
480*89c4ff92SAndroid Build Coastguard Worker std::fstream file(filename.c_str(), std::ios_base::out);
481*89c4ff92SAndroid Build Coastguard Worker optNet->SerializeToDot(file);
482*89c4ff92SAndroid Build Coastguard Worker }
483*89c4ff92SAndroid Build Coastguard Worker
484*89c4ff92SAndroid Build Coastguard Worker try
485*89c4ff92SAndroid Build Coastguard Worker {
486*89c4ff92SAndroid Build Coastguard Worker const auto loadStartTime = armnn::GetTimeNow();
487*89c4ff92SAndroid Build Coastguard Worker
488*89c4ff92SAndroid Build Coastguard Worker // Load graph into runtime
489*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
490*89c4ff92SAndroid Build Coastguard Worker armnn::Status loadingStatus;
491*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource inputSource = armnn::MemorySource::Undefined;
492*89c4ff92SAndroid Build Coastguard Worker armnn::MemorySource outputSource = armnn::MemorySource::Undefined;
493*89c4ff92SAndroid Build Coastguard Worker // There's a bit of an assumption here that the delegate will only support Malloc memory source.
494*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetOptimizerOptions().GetImportEnabled())
495*89c4ff92SAndroid Build Coastguard Worker {
496*89c4ff92SAndroid Build Coastguard Worker inputSource = armnn::MemorySource::Malloc;
497*89c4ff92SAndroid Build Coastguard Worker }
498*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetOptimizerOptions().GetExportEnabled())
499*89c4ff92SAndroid Build Coastguard Worker {
500*89c4ff92SAndroid Build Coastguard Worker outputSource = armnn::MemorySource::Malloc;
501*89c4ff92SAndroid Build Coastguard Worker }
502*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(false,
503*89c4ff92SAndroid Build Coastguard Worker inputSource,
504*89c4ff92SAndroid Build Coastguard Worker outputSource,
505*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetInternalProfilingState(),
506*89c4ff92SAndroid Build Coastguard Worker delegate->m_Options.GetInternalProfilingDetail());
507*89c4ff92SAndroid Build Coastguard Worker loadingStatus = delegate->m_Runtime->LoadNetwork(networkId,
508*89c4ff92SAndroid Build Coastguard Worker std::move(optNet),
509*89c4ff92SAndroid Build Coastguard Worker errorMessage,
510*89c4ff92SAndroid Build Coastguard Worker networkProperties);
511*89c4ff92SAndroid Build Coastguard Worker if (loadingStatus != armnn::Status::Success)
512*89c4ff92SAndroid Build Coastguard Worker {
513*89c4ff92SAndroid Build Coastguard Worker // Network load failed.
514*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Network could not be loaded: " + errorMessage);
515*89c4ff92SAndroid Build Coastguard Worker }
516*89c4ff92SAndroid Build Coastguard Worker
517*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Load ArmnnSubgraph time: " << std::setprecision(2)
518*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(loadStartTime).count() << " ms";
519*89c4ff92SAndroid Build Coastguard Worker }
520*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& ex)
521*89c4ff92SAndroid Build Coastguard Worker {
522*89c4ff92SAndroid Build Coastguard Worker std::stringstream exMessage;
523*89c4ff92SAndroid Build Coastguard Worker exMessage << "TfLiteArmnnOpaqueDelegate: Exception (" << ex.what() << ") caught from LoadNetwork.";
524*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(exMessage.str());
525*89c4ff92SAndroid Build Coastguard Worker }
526*89c4ff92SAndroid Build Coastguard Worker
527*89c4ff92SAndroid Build Coastguard Worker // Register debug callback function
528*89c4ff92SAndroid Build Coastguard Worker if (delegate->m_Options.GetDebugCallbackFunction().has_value())
529*89c4ff92SAndroid Build Coastguard Worker {
530*89c4ff92SAndroid Build Coastguard Worker delegate->m_Runtime->RegisterDebugCallback(networkId, delegate->m_Options.GetDebugCallbackFunction().value());
531*89c4ff92SAndroid Build Coastguard Worker }
532*89c4ff92SAndroid Build Coastguard Worker
533*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(info) << "Overall ArmnnSubgraph creation time: " << std::setprecision(2)
534*89c4ff92SAndroid Build Coastguard Worker << std::fixed << armnn::GetTimeDuration(startTime).count() << " ms\n";
535*89c4ff92SAndroid Build Coastguard Worker
536*89c4ff92SAndroid Build Coastguard Worker // Create a new SubGraph with networkId and runtime
537*89c4ff92SAndroid Build Coastguard Worker return new ArmnnSubgraph(networkId, delegate->m_Runtime, inputBindings, outputBindings);
538*89c4ff92SAndroid Build Coastguard Worker }
539*89c4ff92SAndroid Build Coastguard Worker
Prepare(TfLiteOpaqueContext * tfLiteContext)540*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Prepare(TfLiteOpaqueContext* tfLiteContext)
541*89c4ff92SAndroid Build Coastguard Worker {
542*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(tfLiteContext);
543*89c4ff92SAndroid Build Coastguard Worker return kTfLiteOk;
544*89c4ff92SAndroid Build Coastguard Worker }
545*89c4ff92SAndroid Build Coastguard Worker
Invoke(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode)546*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Invoke(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode)
547*89c4ff92SAndroid Build Coastguard Worker {
548*89c4ff92SAndroid Build Coastguard Worker // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
549*89c4ff92SAndroid Build Coastguard Worker // This function turns inputIndexArray into an int array of indices. These indices point to the tensors for
550*89c4ff92SAndroid Build Coastguard Worker // each input slot in the node.
551*89c4ff92SAndroid Build Coastguard Worker const int* inputIndexArray;
552*89c4ff92SAndroid Build Coastguard Worker int numInputs;
553*89c4ff92SAndroid Build Coastguard Worker if(TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs) != kTfLiteOk)
554*89c4ff92SAndroid Build Coastguard Worker {
555*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to load subgraph inputs!");
556*89c4ff92SAndroid Build Coastguard Worker }
557*89c4ff92SAndroid Build Coastguard Worker // Prepare inputs
558*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
559*89c4ff92SAndroid Build Coastguard Worker size_t inputIndex = 0;
560*89c4ff92SAndroid Build Coastguard Worker for (int inputIdx = 0; inputIdx < numInputs; inputIdx++)
561*89c4ff92SAndroid Build Coastguard Worker {
562*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputIndexArray[inputIdx]);
563*89c4ff92SAndroid Build Coastguard Worker
564*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(tensor))
565*89c4ff92SAndroid Build Coastguard Worker {
566*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
567*89c4ff92SAndroid Build Coastguard Worker }
568*89c4ff92SAndroid Build Coastguard Worker // If tensor is not read only
569*89c4ff92SAndroid Build Coastguard Worker if (TfLiteOpaqueTensorGetAllocationType(tensor) != kTfLiteMmapRo)
570*89c4ff92SAndroid Build Coastguard Worker {
571*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& inputBinding = m_InputBindings[inputIndex];
572*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = inputBinding.second;
573*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
574*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor inputTensor(inputTensorInfo, TfLiteOpaqueTensorData(tensor));
575*89c4ff92SAndroid Build Coastguard Worker inputTensors.emplace_back(inputIdx, inputTensor);
576*89c4ff92SAndroid Build Coastguard Worker
577*89c4ff92SAndroid Build Coastguard Worker ++inputIndex;
578*89c4ff92SAndroid Build Coastguard Worker }
579*89c4ff92SAndroid Build Coastguard Worker }
580*89c4ff92SAndroid Build Coastguard Worker
581*89c4ff92SAndroid Build Coastguard Worker // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
582*89c4ff92SAndroid Build Coastguard Worker // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
583*89c4ff92SAndroid Build Coastguard Worker // each output slot in the node.
584*89c4ff92SAndroid Build Coastguard Worker const int* outputIndexArray;
585*89c4ff92SAndroid Build Coastguard Worker int numOutputs;
586*89c4ff92SAndroid Build Coastguard Worker if(TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs) != kTfLiteOk)
587*89c4ff92SAndroid Build Coastguard Worker {
588*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to load subgraph outputs!");
589*89c4ff92SAndroid Build Coastguard Worker }
590*89c4ff92SAndroid Build Coastguard Worker // Assign the tensors from the outputIndexArray to the armnn BindingPointInfo
591*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
592*89c4ff92SAndroid Build Coastguard Worker for (int outputIdx = 0; outputIdx < numOutputs; outputIdx++)
593*89c4ff92SAndroid Build Coastguard Worker {
594*89c4ff92SAndroid Build Coastguard Worker const armnn::BindingPointInfo& outputBinding = m_OutputBindings[outputIdx];
595*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputIndexArray[outputIdx]);
596*89c4ff92SAndroid Build Coastguard Worker if(!IsValid(tensor))
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
599*89c4ff92SAndroid Build Coastguard Worker }
600*89c4ff92SAndroid Build Coastguard Worker
601*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor outputTensor(outputBinding.second, reinterpret_cast<TfLiteTensor*>(tensor)->data
602*89c4ff92SAndroid Build Coastguard Worker .data);
603*89c4ff92SAndroid Build Coastguard Worker outputTensors.emplace_back(outputIndexArray[outputIdx], outputTensor);
604*89c4ff92SAndroid Build Coastguard Worker }
605*89c4ff92SAndroid Build Coastguard Worker
606*89c4ff92SAndroid Build Coastguard Worker // Run graph
607*89c4ff92SAndroid Build Coastguard Worker auto status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
608*89c4ff92SAndroid Build Coastguard Worker // The delegate holds its own Arm NN runtime so this is our last chance to print internal profiling data.
609*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
610*89c4ff92SAndroid Build Coastguard Worker if (profiler && profiler->IsProfilingEnabled())
611*89c4ff92SAndroid Build Coastguard Worker {
612*89c4ff92SAndroid Build Coastguard Worker profiler->Print(std::cout);
613*89c4ff92SAndroid Build Coastguard Worker }
614*89c4ff92SAndroid Build Coastguard Worker return (status == armnn::Status::Success) ? kTfLiteOk : kTfLiteError;
615*89c4ff92SAndroid Build Coastguard Worker }
616*89c4ff92SAndroid Build Coastguard Worker
VisitNode(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,TfLiteRegistrationExternal * tfLiteRegistration,TfLiteOpaqueNode * tfLiteNode,int nodeIndex)617*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
618*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
619*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternal* tfLiteRegistration,
620*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
621*89c4ff92SAndroid Build Coastguard Worker int nodeIndex)
622*89c4ff92SAndroid Build Coastguard Worker {
623*89c4ff92SAndroid Build Coastguard Worker switch (TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration))
624*89c4ff92SAndroid Build Coastguard Worker {
625*89c4ff92SAndroid Build Coastguard Worker case kTfLiteBuiltinCast:
626*89c4ff92SAndroid Build Coastguard Worker return VisitCastOperator(delegateData,
627*89c4ff92SAndroid Build Coastguard Worker tfLiteContext,
628*89c4ff92SAndroid Build Coastguard Worker tfLiteNode,
629*89c4ff92SAndroid Build Coastguard Worker nodeIndex,
630*89c4ff92SAndroid Build Coastguard Worker kTfLiteBuiltinCast);
631*89c4ff92SAndroid Build Coastguard Worker default:
632*89c4ff92SAndroid Build Coastguard Worker return kTfLiteError;
633*89c4ff92SAndroid Build Coastguard Worker }
634*89c4ff92SAndroid Build Coastguard Worker }
635*89c4ff92SAndroid Build Coastguard Worker } // armnnOpaqueDelegate namespace