xref: /aosp_15_r20/external/armnn/delegate/opaque/src/armnn_delegate.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <OpaqueDelegateUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <Version.hpp>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include "Activation.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ArgMinMax.hpp"
13*89c4ff92SAndroid Build Coastguard Worker #include "BatchMatMul.hpp"
14*89c4ff92SAndroid Build Coastguard Worker #include "BatchSpace.hpp"
15*89c4ff92SAndroid Build Coastguard Worker #include "Comparison.hpp"
16*89c4ff92SAndroid Build Coastguard Worker #include "Convolution.hpp"
17*89c4ff92SAndroid Build Coastguard Worker #include "Control.hpp"
18*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseBinary.hpp"
19*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseUnary.hpp"
20*89c4ff92SAndroid Build Coastguard Worker #include "Fill.hpp"
21*89c4ff92SAndroid Build Coastguard Worker #include "FullyConnected.hpp"
22*89c4ff92SAndroid Build Coastguard Worker #include "Gather.hpp"
23*89c4ff92SAndroid Build Coastguard Worker #include "GatherNd.hpp"
24*89c4ff92SAndroid Build Coastguard Worker #include "LogicalBinary.hpp"
25*89c4ff92SAndroid Build Coastguard Worker #include "Lstm.hpp"
26*89c4ff92SAndroid Build Coastguard Worker #include "Normalization.hpp"
27*89c4ff92SAndroid Build Coastguard Worker #include "Pack.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "Pad.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "Pooling.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "Prelu.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "Quantization.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "Redefine.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "Reduce.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "Resize.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "Round.hpp"
36*89c4ff92SAndroid Build Coastguard Worker #include "Shape.hpp"
37*89c4ff92SAndroid Build Coastguard Worker #include "Slice.hpp"
38*89c4ff92SAndroid Build Coastguard Worker #include "StridedSlice.hpp"
39*89c4ff92SAndroid Build Coastguard Worker #include "Softmax.hpp"
40*89c4ff92SAndroid Build Coastguard Worker #include "SpaceDepth.hpp"
41*89c4ff92SAndroid Build Coastguard Worker #include "Split.hpp"
42*89c4ff92SAndroid Build Coastguard Worker #include "Transpose.hpp"
43*89c4ff92SAndroid Build Coastguard Worker #include "UnidirectionalSequenceLstm.hpp"
44*89c4ff92SAndroid Build Coastguard Worker #include "Unpack.hpp"
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
47*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
48*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Timer.hpp>
49*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
50*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/context_util.h>
51*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/schema/schema_generated.h>
52*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
53*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/logger.h>
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
56*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
57*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker namespace armnnOpaqueDelegate
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker const TfLiteStableDelegate TFL_TheStableDelegate =
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker     /*delegate_abi_version=*/ TFL_STABLE_DELEGATE_ABI_VERSION,
65*89c4ff92SAndroid Build Coastguard Worker     /*delegate_name=*/        "ArmnnDelegatePlugin",
66*89c4ff92SAndroid Build Coastguard Worker     /*delegate_version=*/     "1.0.0",
67*89c4ff92SAndroid Build Coastguard Worker     /*delegate_plugin=*/      GetArmnnDelegatePluginApi()
68*89c4ff92SAndroid Build Coastguard Worker };
69*89c4ff92SAndroid Build Coastguard Worker 
ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options)70*89c4ff92SAndroid Build Coastguard Worker ArmnnOpaqueDelegate::ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options)
71*89c4ff92SAndroid Build Coastguard Worker     : m_Options(std::move(options))
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     // Configures logging for ARMNN
74*89c4ff92SAndroid Build Coastguard Worker     if (m_Options.IsLoggingEnabled())
75*89c4ff92SAndroid Build Coastguard Worker     {
76*89c4ff92SAndroid Build Coastguard Worker         armnn::ConfigureLogging(true, true, m_Options.GetLoggingSeverity());
77*89c4ff92SAndroid Build Coastguard Worker     }
78*89c4ff92SAndroid Build Coastguard Worker     // Create/Get the static ArmNN Runtime. Note that the m_Runtime will be shared by all armnn_delegate
79*89c4ff92SAndroid Build Coastguard Worker     // instances so the RuntimeOptions cannot be altered for different armnn_delegate instances.
80*89c4ff92SAndroid Build Coastguard Worker     m_Runtime = GetRuntime(m_Options.GetRuntimeOptions());
81*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends;
82*89c4ff92SAndroid Build Coastguard Worker     if (m_Runtime)
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         const armnn::BackendIdSet supportedDevices = m_Runtime->GetDeviceSpec().GetSupportedBackends();
85*89c4ff92SAndroid Build Coastguard Worker         for (auto& backend : m_Options.GetBackends())
86*89c4ff92SAndroid Build Coastguard Worker         {
87*89c4ff92SAndroid Build Coastguard Worker             if (std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) == supportedDevices.cend())
88*89c4ff92SAndroid Build Coastguard Worker             {
89*89c4ff92SAndroid Build Coastguard Worker                 TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
90*89c4ff92SAndroid Build Coastguard Worker                                 "TfLiteArmnnDelegate: Requested unknown backend %s", backend.Get().c_str());
91*89c4ff92SAndroid Build Coastguard Worker             }
92*89c4ff92SAndroid Build Coastguard Worker             else
93*89c4ff92SAndroid Build Coastguard Worker             {
94*89c4ff92SAndroid Build Coastguard Worker                 backends.push_back(backend);
95*89c4ff92SAndroid Build Coastguard Worker             }
96*89c4ff92SAndroid Build Coastguard Worker         }
97*89c4ff92SAndroid Build Coastguard Worker     }
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     if (backends.empty())
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         // No known backend specified
102*89c4ff92SAndroid Build Coastguard Worker         throw armnn::InvalidArgumentException("TfLiteArmnnOpaqueDelegate: No known backend specified.");
103*89c4ff92SAndroid Build Coastguard Worker     }
104*89c4ff92SAndroid Build Coastguard Worker     m_Options.SetBackends(backends);
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "TfLiteArmnnOpaqueDelegate: Created TfLite ArmNN delegate.");
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker 
DoPrepare(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueDelegate * tfLiteDelegate,void * data)109*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueDelegate* tfLiteDelegate, void* data)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker     // We are required to have the void* data parameter in the function signature, but we don't actually use it.
112*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(data);
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArray* supportedOperators =
115*89c4ff92SAndroid Build Coastguard Worker             static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>
116*89c4ff92SAndroid Build Coastguard Worker                     (TfLiteOpaqueDelegateGetData(tfLiteDelegate))->IdentifyOperatorsToDelegate(tfLiteContext);
117*89c4ff92SAndroid Build Coastguard Worker     if(supportedOperators == nullptr)
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     // ArmNN Opaque Delegate Registration
123*89c4ff92SAndroid Build Coastguard Worker     TfLiteRegistrationExternal* kernelRegistration =
124*89c4ff92SAndroid Build Coastguard Worker             TfLiteRegistrationExternalCreate(kTfLiteBuiltinDelegate, "TfLiteArmNNOpaqueDelegate", /*version=*/1);
125*89c4ff92SAndroid Build Coastguard Worker     if(kernelRegistration == nullptr)
126*89c4ff92SAndroid Build Coastguard Worker     {
127*89c4ff92SAndroid Build Coastguard Worker         return kTfLiteError;
128*89c4ff92SAndroid Build Coastguard Worker     }
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     TfLiteRegistrationExternalSetInit(
131*89c4ff92SAndroid Build Coastguard Worker             kernelRegistration,
132*89c4ff92SAndroid Build Coastguard Worker             [](TfLiteOpaqueContext* tfLiteContext, const char* buffer, size_t length) -> void*
133*89c4ff92SAndroid Build Coastguard Worker             {
134*89c4ff92SAndroid Build Coastguard Worker                 armnn::IgnoreUnused(length);
135*89c4ff92SAndroid Build Coastguard Worker                 const TfLiteOpaqueDelegateParams* parameters =
136*89c4ff92SAndroid Build Coastguard Worker                         reinterpret_cast<const TfLiteOpaqueDelegateParams*>(buffer);
137*89c4ff92SAndroid Build Coastguard Worker                 if(parameters == nullptr)
138*89c4ff92SAndroid Build Coastguard Worker                 {
139*89c4ff92SAndroid Build Coastguard Worker                     TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
140*89c4ff92SAndroid Build Coastguard Worker                                               "TfLiteArmnnOpaqueDelegate: Unable to get parameters.");
141*89c4ff92SAndroid Build Coastguard Worker                     return nullptr;
142*89c4ff92SAndroid Build Coastguard Worker                 }
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker                 return static_cast<void*>(
145*89c4ff92SAndroid Build Coastguard Worker                         ArmnnSubgraph::Create(tfLiteContext,
146*89c4ff92SAndroid Build Coastguard Worker                                               parameters,
147*89c4ff92SAndroid Build Coastguard Worker                                               static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>(
148*89c4ff92SAndroid Build Coastguard Worker                                                       parameters->delegate->opaque_delegate_builder->data)));
149*89c4ff92SAndroid Build Coastguard Worker             }
150*89c4ff92SAndroid Build Coastguard Worker     );
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     TfLiteRegistrationExternalSetFree(
153*89c4ff92SAndroid Build Coastguard Worker             kernelRegistration,
154*89c4ff92SAndroid Build Coastguard Worker             [](TfLiteOpaqueContext* tfLiteContext, void* buffer) -> void
155*89c4ff92SAndroid Build Coastguard Worker             {
156*89c4ff92SAndroid Build Coastguard Worker                 armnn::IgnoreUnused(tfLiteContext);
157*89c4ff92SAndroid Build Coastguard Worker                 if (buffer != nullptr)
158*89c4ff92SAndroid Build Coastguard Worker                 {
159*89c4ff92SAndroid Build Coastguard Worker                     delete static_cast<ArmnnSubgraph*>(buffer);
160*89c4ff92SAndroid Build Coastguard Worker                 }
161*89c4ff92SAndroid Build Coastguard Worker             }
162*89c4ff92SAndroid Build Coastguard Worker     );
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker     TfLiteRegistrationExternalSetPrepare(
165*89c4ff92SAndroid Build Coastguard Worker             kernelRegistration,
166*89c4ff92SAndroid Build Coastguard Worker             [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
167*89c4ff92SAndroid Build Coastguard Worker             {
168*89c4ff92SAndroid Build Coastguard Worker                 void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
169*89c4ff92SAndroid Build Coastguard Worker                 if (userData == nullptr)
170*89c4ff92SAndroid Build Coastguard Worker                 {
171*89c4ff92SAndroid Build Coastguard Worker                     return kTfLiteError;
172*89c4ff92SAndroid Build Coastguard Worker                 }
173*89c4ff92SAndroid Build Coastguard Worker                 return static_cast<ArmnnSubgraph*>(userData)->Prepare(tfLiteContext);
174*89c4ff92SAndroid Build Coastguard Worker             }
175*89c4ff92SAndroid Build Coastguard Worker     );
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     TfLiteRegistrationExternalSetInvoke(
178*89c4ff92SAndroid Build Coastguard Worker             kernelRegistration,
179*89c4ff92SAndroid Build Coastguard Worker             [](TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode) -> TfLiteStatus
180*89c4ff92SAndroid Build Coastguard Worker             {
181*89c4ff92SAndroid Build Coastguard Worker                 void* userData = TfLiteOpaqueNodeGetUserData(tfLiteNode);
182*89c4ff92SAndroid Build Coastguard Worker                 if (userData == nullptr)
183*89c4ff92SAndroid Build Coastguard Worker                 {
184*89c4ff92SAndroid Build Coastguard Worker                     return kTfLiteError;
185*89c4ff92SAndroid Build Coastguard Worker                 }
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker                 return static_cast<ArmnnSubgraph*>(userData)->Invoke(tfLiteContext, tfLiteNode);
188*89c4ff92SAndroid Build Coastguard Worker             }
189*89c4ff92SAndroid Build Coastguard Worker     );
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     const TfLiteStatus status =
192*89c4ff92SAndroid Build Coastguard Worker             TfLiteOpaqueContextReplaceNodeSubsetsWithDelegateKernels(
193*89c4ff92SAndroid Build Coastguard Worker                     tfLiteContext, kernelRegistration, supportedOperators, tfLiteDelegate);
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArrayFree(supportedOperators);
196*89c4ff92SAndroid Build Coastguard Worker     return status;
197*89c4ff92SAndroid Build Coastguard Worker }
198*89c4ff92SAndroid Build Coastguard Worker 
TfLiteArmnnOpaqueDelegateCreate(const void * settings)199*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings)
200*89c4ff92SAndroid Build Coastguard Worker {
201*89c4ff92SAndroid Build Coastguard Worker     // This method will always create Opaque Delegate with default settings until
202*89c4ff92SAndroid Build Coastguard Worker     // we have a DelegateOptions Constructor which can parse the void* settings
203*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(settings);
204*89c4ff92SAndroid Build Coastguard Worker     auto options = TfLiteArmnnDelegateOptionsDefault();
205*89c4ff92SAndroid Build Coastguard Worker     auto* armnnDelegate = new ::armnnOpaqueDelegate::ArmnnOpaqueDelegate(options);
206*89c4ff92SAndroid Build Coastguard Worker     return TfLiteOpaqueDelegateCreate(armnnDelegate->GetDelegateBuilder());
207*89c4ff92SAndroid Build Coastguard Worker }
208*89c4ff92SAndroid Build Coastguard Worker 
TfLiteArmnnDelegateOptionsDefault()209*89c4ff92SAndroid Build Coastguard Worker ::armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault()
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker     ::armnnDelegate::DelegateOptions options(armnn::Compute::CpuRef);
212*89c4ff92SAndroid Build Coastguard Worker     return options;
213*89c4ff92SAndroid Build Coastguard Worker }
214*89c4ff92SAndroid Build Coastguard Worker 
TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate * tfLiteDelegate)215*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate* tfLiteDelegate)
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker     if (tfLiteDelegate != nullptr)
218*89c4ff92SAndroid Build Coastguard Worker     {
219*89c4ff92SAndroid Build Coastguard Worker         delete static_cast<::armnnOpaqueDelegate::ArmnnOpaqueDelegate*>(TfLiteOpaqueDelegateGetData(tfLiteDelegate));
220*89c4ff92SAndroid Build Coastguard Worker         TfLiteOpaqueDelegateDelete(tfLiteDelegate);
221*89c4ff92SAndroid Build Coastguard Worker     }
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker 
GetArmnnDelegatePluginApi()224*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegatePlugin* GetArmnnDelegatePluginApi()
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker     static constexpr TfLiteOpaqueDelegatePlugin armnnPlugin{
227*89c4ff92SAndroid Build Coastguard Worker             TfLiteArmnnOpaqueDelegateCreate, TfLiteArmnnOpaqueDelegateDelete, TfLiteArmnnOpaqueDelegateErrno};
228*89c4ff92SAndroid Build Coastguard Worker     return &armnnPlugin;
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker 
GetVersion()231*89c4ff92SAndroid Build Coastguard Worker const std::string ArmnnOpaqueDelegate::GetVersion() {
232*89c4ff92SAndroid Build Coastguard Worker     return OPAQUE_DELEGATE_VERSION;
233*89c4ff92SAndroid Build Coastguard Worker }
234*89c4ff92SAndroid Build Coastguard Worker 
IdentifyOperatorsToDelegate(TfLiteOpaqueContext * tfLiteContext)235*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* ArmnnOpaqueDelegate::IdentifyOperatorsToDelegate(TfLiteOpaqueContext* tfLiteContext)
236*89c4ff92SAndroid Build Coastguard Worker {
237*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArray* executionPlan = nullptr;
238*89c4ff92SAndroid Build Coastguard Worker     if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
239*89c4ff92SAndroid Build Coastguard Worker     {
240*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext, "TfLiteArmnnOpaqueDelegate: Unable to get graph execution plan.");
241*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
242*89c4ff92SAndroid Build Coastguard Worker     }
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker     // Delegate data with null network
245*89c4ff92SAndroid Build Coastguard Worker     DelegateData delegateData(m_Options.GetBackends());
246*89c4ff92SAndroid Build Coastguard Worker 
247*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArray* nodesToDelegate = TfLiteIntArrayCreate(executionPlan->size);
248*89c4ff92SAndroid Build Coastguard Worker     if (nodesToDelegate == nullptr)
249*89c4ff92SAndroid Build Coastguard Worker     {
250*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
251*89c4ff92SAndroid Build Coastguard Worker                                   "TfLiteArmnnOpaqueDelegate: Unable to create int array from execution plan.");
252*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
253*89c4ff92SAndroid Build Coastguard Worker     }
254*89c4ff92SAndroid Build Coastguard Worker     nodesToDelegate->size = 0;
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker     std::set<int32_t> unsupportedOperators;
257*89c4ff92SAndroid Build Coastguard Worker 
258*89c4ff92SAndroid Build Coastguard Worker     for (int i = 0; i < executionPlan->size; ++i)
259*89c4ff92SAndroid Build Coastguard Worker     {
260*89c4ff92SAndroid Build Coastguard Worker         const int nodeIndex = executionPlan->data[i];
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker         // If TfLiteOpaqueNodes can be delegated to ArmNN
263*89c4ff92SAndroid Build Coastguard Worker         TfLiteOpaqueNode* tfLiteNode = nullptr;
264*89c4ff92SAndroid Build Coastguard Worker         TfLiteRegistrationExternal* tfLiteRegistration = nullptr;
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker         if (TfLiteOpaqueContextGetNodeAndRegistration(
267*89c4ff92SAndroid Build Coastguard Worker                 tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
268*89c4ff92SAndroid Build Coastguard Worker         {
269*89c4ff92SAndroid Build Coastguard Worker             TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
270*89c4ff92SAndroid Build Coastguard Worker                                       "TfLiteArmnnOpaqueDelegate: Unable to get node and registration for node %d.",
271*89c4ff92SAndroid Build Coastguard Worker                                       nodeIndex);
272*89c4ff92SAndroid Build Coastguard Worker             continue;
273*89c4ff92SAndroid Build Coastguard Worker         }
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker         TfLiteStatus visitStatus;
276*89c4ff92SAndroid Build Coastguard Worker         try
277*89c4ff92SAndroid Build Coastguard Worker         {
278*89c4ff92SAndroid Build Coastguard Worker             visitStatus = ArmnnSubgraph::VisitNode(
279*89c4ff92SAndroid Build Coastguard Worker                     delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex);
280*89c4ff92SAndroid Build Coastguard Worker         }
281*89c4ff92SAndroid Build Coastguard Worker         catch(std::exception& ex)
282*89c4ff92SAndroid Build Coastguard Worker         {
283*89c4ff92SAndroid Build Coastguard Worker             ARMNN_LOG(error) << "ArmNN Failed to visit node with error: " << ex.what();
284*89c4ff92SAndroid Build Coastguard Worker             visitStatus = kTfLiteError;
285*89c4ff92SAndroid Build Coastguard Worker         }
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker         if (visitStatus != kTfLiteOk)
288*89c4ff92SAndroid Build Coastguard Worker         {
289*89c4ff92SAndroid Build Coastguard Worker             // node is not supported by ArmNN
290*89c4ff92SAndroid Build Coastguard Worker             unsupportedOperators.insert(TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration));
291*89c4ff92SAndroid Build Coastguard Worker             continue;
292*89c4ff92SAndroid Build Coastguard Worker         }
293*89c4ff92SAndroid Build Coastguard Worker 
294*89c4ff92SAndroid Build Coastguard Worker         nodesToDelegate->data[nodesToDelegate->size++] = nodeIndex;
295*89c4ff92SAndroid Build Coastguard Worker     }
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker     for (std::set<int32_t>::iterator it=unsupportedOperators.begin(); it!=unsupportedOperators.end(); ++it)
298*89c4ff92SAndroid Build Coastguard Worker     {
299*89c4ff92SAndroid Build Coastguard Worker         TF_LITE_OPAQUE_KERNEL_LOG(tfLiteContext,
300*89c4ff92SAndroid Build Coastguard Worker                                   "Operator %s [%d] is not supported by armnn_opaque_delegate.",
301*89c4ff92SAndroid Build Coastguard Worker                                   tflite::EnumNameBuiltinOperator(tflite::BuiltinOperator(*it)),
302*89c4ff92SAndroid Build Coastguard Worker                                   *it);
303*89c4ff92SAndroid Build Coastguard Worker     }
304*89c4ff92SAndroid Build Coastguard Worker 
305*89c4ff92SAndroid Build Coastguard Worker     if (!unsupportedOperators.empty() && m_Options.TfLiteRuntimeFallbackDisabled())
306*89c4ff92SAndroid Build Coastguard Worker     {
307*89c4ff92SAndroid Build Coastguard Worker         std::stringstream exMessage;
308*89c4ff92SAndroid Build Coastguard Worker         exMessage << "TfLiteArmnnOpaqueDelegate: There are unsupported operators in the model. ";
309*89c4ff92SAndroid Build Coastguard Worker         exMessage << "Not falling back to TfLite Runtime as fallback is disabled. ";
310*89c4ff92SAndroid Build Coastguard Worker         exMessage << "This should only be disabled under test conditions.";
311*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(exMessage.str());
312*89c4ff92SAndroid Build Coastguard Worker     }
313*89c4ff92SAndroid Build Coastguard Worker     if (nodesToDelegate->size == 0)
314*89c4ff92SAndroid Build Coastguard Worker     {
315*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "No operators in this model are supported by the Arm NN TfLite delegate." <<
316*89c4ff92SAndroid Build Coastguard Worker                         " The model will be executed entirely by TfLite runtime.";
317*89c4ff92SAndroid Build Coastguard Worker     }
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker     std::sort(&nodesToDelegate->data[0], &nodesToDelegate->data[nodesToDelegate->size]);
320*89c4ff92SAndroid Build Coastguard Worker     return nodesToDelegate;
321*89c4ff92SAndroid Build Coastguard Worker }
322*89c4ff92SAndroid Build Coastguard Worker 
AddInputLayer(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,const TfLiteIntArray * inputs,std::vector<armnn::BindingPointInfo> & inputBindings)323*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddInputLayer(DelegateData& delegateData,
324*89c4ff92SAndroid Build Coastguard Worker                                           TfLiteOpaqueContext* tfLiteContext,
325*89c4ff92SAndroid Build Coastguard Worker                                           const TfLiteIntArray* inputs,
326*89c4ff92SAndroid Build Coastguard Worker                                           std::vector<armnn::BindingPointInfo>& inputBindings)
327*89c4ff92SAndroid Build Coastguard Worker {
328*89c4ff92SAndroid Build Coastguard Worker     const size_t numInputs = static_cast<size_t>(inputs->size);
329*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numInputs; ++i)
330*89c4ff92SAndroid Build Coastguard Worker     {
331*89c4ff92SAndroid Build Coastguard Worker         const int32_t tensorId = inputs->data[i];
332*89c4ff92SAndroid Build Coastguard Worker         const TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, tensorId);
333*89c4ff92SAndroid Build Coastguard Worker 
334*89c4ff92SAndroid Build Coastguard Worker         if(!tensor)
335*89c4ff92SAndroid Build Coastguard Worker         {
336*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
337*89c4ff92SAndroid Build Coastguard Worker         }
338*89c4ff92SAndroid Build Coastguard Worker 
339*89c4ff92SAndroid Build Coastguard Worker         // Do not create bindings for constant inputs
340*89c4ff92SAndroid Build Coastguard Worker         if (TfLiteOpaqueTensorGetAllocationType(tensor) == kTfLiteMmapRo)
341*89c4ff92SAndroid Build Coastguard Worker         {
342*89c4ff92SAndroid Build Coastguard Worker             continue;
343*89c4ff92SAndroid Build Coastguard Worker         }
344*89c4ff92SAndroid Build Coastguard Worker 
345*89c4ff92SAndroid Build Coastguard Worker         auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
346*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* layer = delegateData.m_Network->AddInputLayer(bindingId);
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker         auto tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tensor);
349*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
350*89c4ff92SAndroid Build Coastguard Worker         outputSlot.SetTensorInfo(tensorInfo);
351*89c4ff92SAndroid Build Coastguard Worker 
352*89c4ff92SAndroid Build Coastguard Worker         // Store for creating connections
353*89c4ff92SAndroid Build Coastguard Worker         delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] = &outputSlot;
354*89c4ff92SAndroid Build Coastguard Worker 
355*89c4ff92SAndroid Build Coastguard Worker         inputBindings.push_back(std::make_pair(bindingId, tensorInfo));
356*89c4ff92SAndroid Build Coastguard Worker     }
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
359*89c4ff92SAndroid Build Coastguard Worker }
360*89c4ff92SAndroid Build Coastguard Worker 
AddOutputLayer(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,const TfLiteIntArray * outputs,std::vector<armnn::BindingPointInfo> & outputBindings)361*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::AddOutputLayer(DelegateData& delegateData,
362*89c4ff92SAndroid Build Coastguard Worker                                            TfLiteOpaqueContext* tfLiteContext,
363*89c4ff92SAndroid Build Coastguard Worker                                            const TfLiteIntArray* outputs,
364*89c4ff92SAndroid Build Coastguard Worker                                            std::vector<armnn::BindingPointInfo>& outputBindings)
365*89c4ff92SAndroid Build Coastguard Worker {
366*89c4ff92SAndroid Build Coastguard Worker     const size_t numOutputs = static_cast<size_t>(outputs->size);
367*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numOutputs; ++i)
368*89c4ff92SAndroid Build Coastguard Worker     {
369*89c4ff92SAndroid Build Coastguard Worker         const int32_t tensorId = outputs->data[i];
370*89c4ff92SAndroid Build Coastguard Worker         const TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, tensorId);
371*89c4ff92SAndroid Build Coastguard Worker 
372*89c4ff92SAndroid Build Coastguard Worker         if(!IsValid(tensor))
373*89c4ff92SAndroid Build Coastguard Worker         {
374*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
375*89c4ff92SAndroid Build Coastguard Worker         }
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker         auto bindingId = static_cast<armnn::LayerBindingId>((tensorId));
378*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* layer = delegateData.m_Network->AddOutputLayer(bindingId);
379*89c4ff92SAndroid Build Coastguard Worker 
380*89c4ff92SAndroid Build Coastguard Worker         auto tensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tensor);
381*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)] != nullptr);
382*89c4ff92SAndroid Build Coastguard Worker         delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tensorId)]->Connect(layer->GetInputSlot(0));
383*89c4ff92SAndroid Build Coastguard Worker         outputBindings.push_back(std::make_pair(bindingId, tensorInfo));
384*89c4ff92SAndroid Build Coastguard Worker     }
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
387*89c4ff92SAndroid Build Coastguard Worker }
388*89c4ff92SAndroid Build Coastguard Worker 
Create(TfLiteOpaqueContext * tfLiteContext,const TfLiteOpaqueDelegateParams * parameters,const ArmnnOpaqueDelegate * delegate)389*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph* ArmnnSubgraph::Create(TfLiteOpaqueContext* tfLiteContext,
390*89c4ff92SAndroid Build Coastguard Worker                                      const TfLiteOpaqueDelegateParams* parameters,
391*89c4ff92SAndroid Build Coastguard Worker                                      const ArmnnOpaqueDelegate* delegate)
392*89c4ff92SAndroid Build Coastguard Worker {
393*89c4ff92SAndroid Build Coastguard Worker     const auto startTime = armnn::GetTimeNow();
394*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "ArmnnSubgraph creation";
395*89c4ff92SAndroid Build Coastguard Worker 
396*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArray* executionPlan;
397*89c4ff92SAndroid Build Coastguard Worker     if (TfLiteOpaqueContextGetExecutionPlan(tfLiteContext, &executionPlan) != kTfLiteOk)
398*89c4ff92SAndroid Build Coastguard Worker     {
399*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
400*89c4ff92SAndroid Build Coastguard Worker     }
401*89c4ff92SAndroid Build Coastguard Worker 
402*89c4ff92SAndroid Build Coastguard Worker     // Initialize DelegateData holds network and output slots information
403*89c4ff92SAndroid Build Coastguard Worker     DelegateData delegateData(delegate->m_Options.GetBackends());
404*89c4ff92SAndroid Build Coastguard Worker 
405*89c4ff92SAndroid Build Coastguard Worker     // Build ArmNN Network
406*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkOptions networkOptions = delegate->m_Options.GetOptimizerOptions().GetModelOptions();
407*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId networkId;
408*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_Network = armnn::INetwork::Create(networkOptions);
409*89c4ff92SAndroid Build Coastguard Worker 
410*89c4ff92SAndroid Build Coastguard Worker     delegateData.m_OutputSlotForNode = std::vector<armnn::IOutputSlot*>(
411*89c4ff92SAndroid Build Coastguard Worker                                                             TfLiteOpaqueContextGetNumTensors(tfLiteContext), nullptr);
412*89c4ff92SAndroid Build Coastguard Worker 
413*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> inputBindings;
414*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> outputBindings;
415*89c4ff92SAndroid Build Coastguard Worker 
416*89c4ff92SAndroid Build Coastguard Worker     // Add input layer
417*89c4ff92SAndroid Build Coastguard Worker     if (AddInputLayer(delegateData, tfLiteContext, parameters->input_tensors, inputBindings) != kTfLiteOk)
418*89c4ff92SAndroid Build Coastguard Worker     {
419*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to add Inputs to the network!");
420*89c4ff92SAndroid Build Coastguard Worker     }
421*89c4ff92SAndroid Build Coastguard Worker 
422*89c4ff92SAndroid Build Coastguard Worker     // Parse TfLite delegate nodes to ArmNN
423*89c4ff92SAndroid Build Coastguard Worker     const auto parseStartTime = armnn::GetTimeNow();
424*89c4ff92SAndroid Build Coastguard Worker     for (int i = 0; i < parameters->nodes_to_replace->size; ++i)
425*89c4ff92SAndroid Build Coastguard Worker     {
426*89c4ff92SAndroid Build Coastguard Worker         const int nodeIndex = parameters->nodes_to_replace->data[i];
427*89c4ff92SAndroid Build Coastguard Worker 
428*89c4ff92SAndroid Build Coastguard Worker         TfLiteOpaqueNode* tfLiteNode = nullptr;
429*89c4ff92SAndroid Build Coastguard Worker         TfLiteRegistrationExternal* tfLiteRegistration = nullptr;
430*89c4ff92SAndroid Build Coastguard Worker         if (TfLiteOpaqueContextGetNodeAndRegistration(
431*89c4ff92SAndroid Build Coastguard Worker             tfLiteContext, nodeIndex, &tfLiteNode, &tfLiteRegistration) != kTfLiteOk)
432*89c4ff92SAndroid Build Coastguard Worker         {
433*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(&"TfLiteArmnnOpaqueDelegate: Unable to get node registration: " [ nodeIndex]);
434*89c4ff92SAndroid Build Coastguard Worker         }
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker         if (VisitNode(delegateData, tfLiteContext, tfLiteRegistration, tfLiteNode, nodeIndex) != kTfLiteOk)
437*89c4ff92SAndroid Build Coastguard Worker         {
438*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(&"TfLiteArmnnOpaqueDelegate: Unable to parse node: " [ nodeIndex]);
439*89c4ff92SAndroid Build Coastguard Worker         }
440*89c4ff92SAndroid Build Coastguard Worker     }
441*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Parse nodes to ArmNN time: " << std::setprecision(2)
442*89c4ff92SAndroid Build Coastguard Worker                     << std::fixed << armnn::GetTimeDuration(parseStartTime).count() << " ms";
443*89c4ff92SAndroid Build Coastguard Worker 
444*89c4ff92SAndroid Build Coastguard Worker     // Add Output layer
445*89c4ff92SAndroid Build Coastguard Worker     if (AddOutputLayer(delegateData, tfLiteContext, parameters->output_tensors, outputBindings) != kTfLiteOk)
446*89c4ff92SAndroid Build Coastguard Worker     {
447*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to add Outputs to the network!");
448*89c4ff92SAndroid Build Coastguard Worker     }
449*89c4ff92SAndroid Build Coastguard Worker 
450*89c4ff92SAndroid Build Coastguard Worker     // Optimize ArmNN network
451*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
452*89c4ff92SAndroid Build Coastguard Worker     try
453*89c4ff92SAndroid Build Coastguard Worker     {
454*89c4ff92SAndroid Build Coastguard Worker         const auto optimizeStartTime = armnn::GetTimeNow();
455*89c4ff92SAndroid Build Coastguard Worker         optNet = armnn::Optimize(*(delegateData.m_Network.get()),
456*89c4ff92SAndroid Build Coastguard Worker                                  delegate->m_Options.GetBackends(),
457*89c4ff92SAndroid Build Coastguard Worker                                  delegate->m_Runtime->GetDeviceSpec(),
458*89c4ff92SAndroid Build Coastguard Worker                                  delegate->m_Options.GetOptimizerOptions());
459*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Optimize ArmnnSubgraph time: " << std::setprecision(2)
460*89c4ff92SAndroid Build Coastguard Worker                         << std::fixed << armnn::GetTimeDuration(optimizeStartTime).count() << " ms";
461*89c4ff92SAndroid Build Coastguard Worker     }
462*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& ex)
463*89c4ff92SAndroid Build Coastguard Worker     {
464*89c4ff92SAndroid Build Coastguard Worker         std::stringstream exMessage;
465*89c4ff92SAndroid Build Coastguard Worker         exMessage << "TfLiteArmnnOpaqueDelegate: Exception (" << ex.what() << ") caught from optimize.";
466*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(exMessage.str());
467*89c4ff92SAndroid Build Coastguard Worker     }
468*89c4ff92SAndroid Build Coastguard Worker     if (!optNet)
469*89c4ff92SAndroid Build Coastguard Worker     {
470*89c4ff92SAndroid Build Coastguard Worker         // Optimize failed
471*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to optimize the network!");
472*89c4ff92SAndroid Build Coastguard Worker     }
473*89c4ff92SAndroid Build Coastguard Worker 
474*89c4ff92SAndroid Build Coastguard Worker     // If set, we will serialize the optimized model into a dot file.
475*89c4ff92SAndroid Build Coastguard Worker     const std::string serializeToDotFile = delegate->m_Options.GetSerializeToDot();
476*89c4ff92SAndroid Build Coastguard Worker     if (!serializeToDotFile.empty())
477*89c4ff92SAndroid Build Coastguard Worker     {
478*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Writing graph to dot file: " << serializeToDotFile;
479*89c4ff92SAndroid Build Coastguard Worker         fs::path filename = serializeToDotFile;
480*89c4ff92SAndroid Build Coastguard Worker         std::fstream file(filename.c_str(), std::ios_base::out);
481*89c4ff92SAndroid Build Coastguard Worker         optNet->SerializeToDot(file);
482*89c4ff92SAndroid Build Coastguard Worker     }
483*89c4ff92SAndroid Build Coastguard Worker 
484*89c4ff92SAndroid Build Coastguard Worker     try
485*89c4ff92SAndroid Build Coastguard Worker     {
486*89c4ff92SAndroid Build Coastguard Worker         const auto loadStartTime = armnn::GetTimeNow();
487*89c4ff92SAndroid Build Coastguard Worker 
488*89c4ff92SAndroid Build Coastguard Worker         // Load graph into runtime
489*89c4ff92SAndroid Build Coastguard Worker         std::string errorMessage;
490*89c4ff92SAndroid Build Coastguard Worker         armnn::Status loadingStatus;
491*89c4ff92SAndroid Build Coastguard Worker         armnn::MemorySource inputSource = armnn::MemorySource::Undefined;
492*89c4ff92SAndroid Build Coastguard Worker         armnn::MemorySource outputSource = armnn::MemorySource::Undefined;
493*89c4ff92SAndroid Build Coastguard Worker         // There's a bit of an assumption here that the delegate will only support Malloc memory source.
494*89c4ff92SAndroid Build Coastguard Worker         if (delegate->m_Options.GetOptimizerOptions().GetImportEnabled())
495*89c4ff92SAndroid Build Coastguard Worker         {
496*89c4ff92SAndroid Build Coastguard Worker             inputSource = armnn::MemorySource::Malloc;
497*89c4ff92SAndroid Build Coastguard Worker         }
498*89c4ff92SAndroid Build Coastguard Worker         if (delegate->m_Options.GetOptimizerOptions().GetExportEnabled())
499*89c4ff92SAndroid Build Coastguard Worker         {
500*89c4ff92SAndroid Build Coastguard Worker             outputSource = armnn::MemorySource::Malloc;
501*89c4ff92SAndroid Build Coastguard Worker         }
502*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkProperties networkProperties(false,
503*89c4ff92SAndroid Build Coastguard Worker                                                     inputSource,
504*89c4ff92SAndroid Build Coastguard Worker                                                     outputSource,
505*89c4ff92SAndroid Build Coastguard Worker                                                     delegate->m_Options.GetInternalProfilingState(),
506*89c4ff92SAndroid Build Coastguard Worker                                                     delegate->m_Options.GetInternalProfilingDetail());
507*89c4ff92SAndroid Build Coastguard Worker         loadingStatus = delegate->m_Runtime->LoadNetwork(networkId,
508*89c4ff92SAndroid Build Coastguard Worker                                                          std::move(optNet),
509*89c4ff92SAndroid Build Coastguard Worker                                                          errorMessage,
510*89c4ff92SAndroid Build Coastguard Worker                                                          networkProperties);
511*89c4ff92SAndroid Build Coastguard Worker         if (loadingStatus != armnn::Status::Success)
512*89c4ff92SAndroid Build Coastguard Worker         {
513*89c4ff92SAndroid Build Coastguard Worker             // Network load failed.
514*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Network could not be loaded: " + errorMessage);
515*89c4ff92SAndroid Build Coastguard Worker         }
516*89c4ff92SAndroid Build Coastguard Worker 
517*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(info) << "Load ArmnnSubgraph time: " << std::setprecision(2)
518*89c4ff92SAndroid Build Coastguard Worker                         << std::fixed << armnn::GetTimeDuration(loadStartTime).count() << " ms";
519*89c4ff92SAndroid Build Coastguard Worker     }
520*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& ex)
521*89c4ff92SAndroid Build Coastguard Worker     {
522*89c4ff92SAndroid Build Coastguard Worker         std::stringstream exMessage;
523*89c4ff92SAndroid Build Coastguard Worker         exMessage << "TfLiteArmnnOpaqueDelegate: Exception (" << ex.what() << ") caught from LoadNetwork.";
524*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception(exMessage.str());
525*89c4ff92SAndroid Build Coastguard Worker     }
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     // Register debug callback function
528*89c4ff92SAndroid Build Coastguard Worker     if (delegate->m_Options.GetDebugCallbackFunction().has_value())
529*89c4ff92SAndroid Build Coastguard Worker     {
530*89c4ff92SAndroid Build Coastguard Worker         delegate->m_Runtime->RegisterDebugCallback(networkId, delegate->m_Options.GetDebugCallbackFunction().value());
531*89c4ff92SAndroid Build Coastguard Worker     }
532*89c4ff92SAndroid Build Coastguard Worker 
533*89c4ff92SAndroid Build Coastguard Worker     ARMNN_LOG(info) << "Overall ArmnnSubgraph creation time: " << std::setprecision(2)
534*89c4ff92SAndroid Build Coastguard Worker                     << std::fixed << armnn::GetTimeDuration(startTime).count() << " ms\n";
535*89c4ff92SAndroid Build Coastguard Worker 
536*89c4ff92SAndroid Build Coastguard Worker     // Create a new SubGraph with networkId and runtime
537*89c4ff92SAndroid Build Coastguard Worker     return new ArmnnSubgraph(networkId, delegate->m_Runtime, inputBindings, outputBindings);
538*89c4ff92SAndroid Build Coastguard Worker }
539*89c4ff92SAndroid Build Coastguard Worker 
Prepare(TfLiteOpaqueContext * tfLiteContext)540*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Prepare(TfLiteOpaqueContext* tfLiteContext)
541*89c4ff92SAndroid Build Coastguard Worker {
542*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(tfLiteContext);
543*89c4ff92SAndroid Build Coastguard Worker     return kTfLiteOk;
544*89c4ff92SAndroid Build Coastguard Worker }
545*89c4ff92SAndroid Build Coastguard Worker 
Invoke(TfLiteOpaqueContext * tfLiteContext,TfLiteOpaqueNode * tfLiteNode)546*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::Invoke(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode)
547*89c4ff92SAndroid Build Coastguard Worker {
548*89c4ff92SAndroid Build Coastguard Worker     // Get array of input indices, inputIndexArray is set from the TfLiteOpaqueNodeInputs function
549*89c4ff92SAndroid Build Coastguard Worker     // This function turns inputIndexArray into an int array of indices. These indices point to the tensors for
550*89c4ff92SAndroid Build Coastguard Worker     // each input slot in the node.
551*89c4ff92SAndroid Build Coastguard Worker     const int* inputIndexArray;
552*89c4ff92SAndroid Build Coastguard Worker     int numInputs;
553*89c4ff92SAndroid Build Coastguard Worker     if(TfLiteOpaqueNodeInputs(tfLiteNode, &inputIndexArray, &numInputs) != kTfLiteOk)
554*89c4ff92SAndroid Build Coastguard Worker     {
555*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to load subgraph inputs!");
556*89c4ff92SAndroid Build Coastguard Worker     }
557*89c4ff92SAndroid Build Coastguard Worker     // Prepare inputs
558*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
559*89c4ff92SAndroid Build Coastguard Worker     size_t inputIndex = 0;
560*89c4ff92SAndroid Build Coastguard Worker     for (int inputIdx = 0; inputIdx < numInputs; inputIdx++)
561*89c4ff92SAndroid Build Coastguard Worker     {
562*89c4ff92SAndroid Build Coastguard Worker         TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, inputIndexArray[inputIdx]);
563*89c4ff92SAndroid Build Coastguard Worker 
564*89c4ff92SAndroid Build Coastguard Worker         if(!IsValid(tensor))
565*89c4ff92SAndroid Build Coastguard Worker         {
566*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
567*89c4ff92SAndroid Build Coastguard Worker         }
568*89c4ff92SAndroid Build Coastguard Worker         // If tensor is not read only
569*89c4ff92SAndroid Build Coastguard Worker         if (TfLiteOpaqueTensorGetAllocationType(tensor) != kTfLiteMmapRo)
570*89c4ff92SAndroid Build Coastguard Worker         {
571*89c4ff92SAndroid Build Coastguard Worker             const armnn::BindingPointInfo& inputBinding = m_InputBindings[inputIndex];
572*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo inputTensorInfo = inputBinding.second;
573*89c4ff92SAndroid Build Coastguard Worker             inputTensorInfo.SetConstant(true);
574*89c4ff92SAndroid Build Coastguard Worker             const armnn::ConstTensor inputTensor(inputTensorInfo, TfLiteOpaqueTensorData(tensor));
575*89c4ff92SAndroid Build Coastguard Worker             inputTensors.emplace_back(inputIdx, inputTensor);
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker             ++inputIndex;
578*89c4ff92SAndroid Build Coastguard Worker         }
579*89c4ff92SAndroid Build Coastguard Worker     }
580*89c4ff92SAndroid Build Coastguard Worker 
581*89c4ff92SAndroid Build Coastguard Worker     // Get array of output indices, outputIndexArray is set from the TfLiteOpaqueNodeOutputs function
582*89c4ff92SAndroid Build Coastguard Worker     // This function turns outputIndexArray into an int array of indices. These indices point to the tensors for
583*89c4ff92SAndroid Build Coastguard Worker     // each output slot in the node.
584*89c4ff92SAndroid Build Coastguard Worker     const int* outputIndexArray;
585*89c4ff92SAndroid Build Coastguard Worker     int numOutputs;
586*89c4ff92SAndroid Build Coastguard Worker     if(TfLiteOpaqueNodeOutputs(tfLiteNode, &outputIndexArray, &numOutputs) != kTfLiteOk)
587*89c4ff92SAndroid Build Coastguard Worker     {
588*89c4ff92SAndroid Build Coastguard Worker         throw armnn::Exception("TfLiteArmnnOpaqueDelegate: Unable to load subgraph outputs!");
589*89c4ff92SAndroid Build Coastguard Worker     }
590*89c4ff92SAndroid Build Coastguard Worker     // Assign the tensors from the outputIndexArray to the armnn BindingPointInfo
591*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
592*89c4ff92SAndroid Build Coastguard Worker     for (int outputIdx = 0; outputIdx < numOutputs; outputIdx++)
593*89c4ff92SAndroid Build Coastguard Worker     {
594*89c4ff92SAndroid Build Coastguard Worker         const armnn::BindingPointInfo& outputBinding = m_OutputBindings[outputIdx];
595*89c4ff92SAndroid Build Coastguard Worker         TfLiteOpaqueTensor* tensor = TfLiteOpaqueContextGetOpaqueTensor(tfLiteContext, outputIndexArray[outputIdx]);
596*89c4ff92SAndroid Build Coastguard Worker         if(!IsValid(tensor))
597*89c4ff92SAndroid Build Coastguard Worker         {
598*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
599*89c4ff92SAndroid Build Coastguard Worker         }
600*89c4ff92SAndroid Build Coastguard Worker 
601*89c4ff92SAndroid Build Coastguard Worker         const armnn::Tensor outputTensor(outputBinding.second, reinterpret_cast<TfLiteTensor*>(tensor)->data
602*89c4ff92SAndroid Build Coastguard Worker         .data);
603*89c4ff92SAndroid Build Coastguard Worker         outputTensors.emplace_back(outputIndexArray[outputIdx], outputTensor);
604*89c4ff92SAndroid Build Coastguard Worker     }
605*89c4ff92SAndroid Build Coastguard Worker 
606*89c4ff92SAndroid Build Coastguard Worker     // Run graph
607*89c4ff92SAndroid Build Coastguard Worker     auto status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
608*89c4ff92SAndroid Build Coastguard Worker     // The delegate holds its own Arm NN runtime so this is our last chance to print internal profiling data.
609*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
610*89c4ff92SAndroid Build Coastguard Worker     if (profiler && profiler->IsProfilingEnabled())
611*89c4ff92SAndroid Build Coastguard Worker     {
612*89c4ff92SAndroid Build Coastguard Worker         profiler->Print(std::cout);
613*89c4ff92SAndroid Build Coastguard Worker     }
614*89c4ff92SAndroid Build Coastguard Worker     return (status == armnn::Status::Success) ? kTfLiteOk : kTfLiteError;
615*89c4ff92SAndroid Build Coastguard Worker }
616*89c4ff92SAndroid Build Coastguard Worker 
VisitNode(DelegateData & delegateData,TfLiteOpaqueContext * tfLiteContext,TfLiteRegistrationExternal * tfLiteRegistration,TfLiteOpaqueNode * tfLiteNode,int nodeIndex)617*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
618*89c4ff92SAndroid Build Coastguard Worker                                       TfLiteOpaqueContext* tfLiteContext,
619*89c4ff92SAndroid Build Coastguard Worker                                       TfLiteRegistrationExternal* tfLiteRegistration,
620*89c4ff92SAndroid Build Coastguard Worker                                       TfLiteOpaqueNode* tfLiteNode,
621*89c4ff92SAndroid Build Coastguard Worker                                       int nodeIndex)
622*89c4ff92SAndroid Build Coastguard Worker {
623*89c4ff92SAndroid Build Coastguard Worker     switch (TfLiteRegistrationExternalGetBuiltInCode(tfLiteRegistration))
624*89c4ff92SAndroid Build Coastguard Worker     {
625*89c4ff92SAndroid Build Coastguard Worker         case kTfLiteBuiltinCast:
626*89c4ff92SAndroid Build Coastguard Worker             return VisitCastOperator(delegateData,
627*89c4ff92SAndroid Build Coastguard Worker                                      tfLiteContext,
628*89c4ff92SAndroid Build Coastguard Worker                                      tfLiteNode,
629*89c4ff92SAndroid Build Coastguard Worker                                      nodeIndex,
630*89c4ff92SAndroid Build Coastguard Worker                                      kTfLiteBuiltinCast);
631*89c4ff92SAndroid Build Coastguard Worker         default:
632*89c4ff92SAndroid Build Coastguard Worker             return kTfLiteError;
633*89c4ff92SAndroid Build Coastguard Worker     }
634*89c4ff92SAndroid Build Coastguard Worker }
635*89c4ff92SAndroid Build Coastguard Worker } // armnnOpaqueDelegate namespace