1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <DelegateOptions.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/c_api_opaque.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/core/experimental/acceleration/configuration/c/stable_delegate.h>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace armnnOpaqueDelegate
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker struct DelegateData
17*89c4ff92SAndroid Build Coastguard Worker {
DelegateDataarmnnOpaqueDelegate::DelegateData18*89c4ff92SAndroid Build Coastguard Worker DelegateData(const std::vector<armnn::BackendId>& backends)
19*89c4ff92SAndroid Build Coastguard Worker : m_Backends(backends)
20*89c4ff92SAndroid Build Coastguard Worker , m_Network(nullptr, nullptr)
21*89c4ff92SAndroid Build Coastguard Worker {}
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId> m_Backends;
24*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_Network;
25*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::IOutputSlot*> m_OutputSlotForNode;
26*89c4ff92SAndroid Build Coastguard Worker };
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker /// Forward declaration for functions initializing the ArmNN Delegate
29*89c4ff92SAndroid Build Coastguard Worker ::armnnDelegate::DelegateOptions TfLiteArmnnDelegateOptionsDefault();
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegate* TfLiteArmnnOpaqueDelegateCreate(const void* settings);
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnOpaqueDelegateDelete(TfLiteOpaqueDelegate* tfLiteDelegate);
34*89c4ff92SAndroid Build Coastguard Worker
35*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteOpaqueContext* context, TfLiteOpaqueDelegate* delegate, void* data);
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Opaque Delegate
38*89c4ff92SAndroid Build Coastguard Worker class ArmnnOpaqueDelegate
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker friend class ArmnnSubgraph;
41*89c4ff92SAndroid Build Coastguard Worker public:
42*89c4ff92SAndroid Build Coastguard Worker explicit ArmnnOpaqueDelegate(armnnDelegate::DelegateOptions options);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteOpaqueContext* context);
45*89c4ff92SAndroid Build Coastguard Worker
GetDelegateBuilder()46*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegateBuilder* GetDelegateBuilder() { return &m_Builder; }
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker /// Retrieve version in X.Y.Z form
49*89c4ff92SAndroid Build Coastguard Worker static const std::string GetVersion();
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker private:
52*89c4ff92SAndroid Build Coastguard Worker /**
53*89c4ff92SAndroid Build Coastguard Worker * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
54*89c4ff92SAndroid Build Coastguard Worker */
GetRuntime(const armnn::IRuntime::CreationOptions & options)55*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
58*89c4ff92SAndroid Build Coastguard Worker /// Instantiated on first use.
59*89c4ff92SAndroid Build Coastguard Worker return instance.get();
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
62*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueDelegateBuilder m_Builder =
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<void*>(this), // .data_
65*89c4ff92SAndroid Build Coastguard Worker DoPrepare, // .Prepare
66*89c4ff92SAndroid Build Coastguard Worker nullptr, // .CopyFromBufferHandle
67*89c4ff92SAndroid Build Coastguard Worker nullptr, // .CopyToBufferHandle
68*89c4ff92SAndroid Build Coastguard Worker nullptr, // .FreeBufferHandle
69*89c4ff92SAndroid Build Coastguard Worker kTfLiteDelegateFlagsNone, // .flags
70*89c4ff92SAndroid Build Coastguard Worker };
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Runtime pointer
73*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime;
74*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Delegate Options
75*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::DelegateOptions m_Options;
76*89c4ff92SAndroid Build Coastguard Worker };
77*89c4ff92SAndroid Build Coastguard Worker
TfLiteArmnnOpaqueDelegateErrno(TfLiteOpaqueDelegate * delegate)78*89c4ff92SAndroid Build Coastguard Worker static int TfLiteArmnnOpaqueDelegateErrno(TfLiteOpaqueDelegate* delegate) { return 0; }
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker /// In order for the delegate to be loaded by TfLite
81*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegatePlugin* GetArmnnDelegatePluginApi();
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker extern const TfLiteStableDelegate TFL_TheStableDelegate;
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
86*89c4ff92SAndroid Build Coastguard Worker class ArmnnSubgraph
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker public:
89*89c4ff92SAndroid Build Coastguard Worker static ArmnnSubgraph* Create(TfLiteOpaqueContext* tfLiteContext,
90*89c4ff92SAndroid Build Coastguard Worker const TfLiteOpaqueDelegateParams* parameters,
91*89c4ff92SAndroid Build Coastguard Worker const ArmnnOpaqueDelegate* delegate);
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteOpaqueContext* tfLiteContext);
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Invoke(TfLiteOpaqueContext* tfLiteContext, TfLiteOpaqueNode* tfLiteNode);
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus VisitNode(DelegateData& delegateData,
98*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
99*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistrationExternal* tfLiteRegistration,
100*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueNode* tfLiteNode,
101*89c4ff92SAndroid Build Coastguard Worker int nodeIndex);
102*89c4ff92SAndroid Build Coastguard Worker private:
ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)103*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph(armnn::NetworkId networkId,
104*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime,
105*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings,
106*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings)
107*89c4ff92SAndroid Build Coastguard Worker : m_NetworkId(networkId)
108*89c4ff92SAndroid Build Coastguard Worker , m_Runtime(runtime)
109*89c4ff92SAndroid Build Coastguard Worker , m_InputBindings(inputBindings)
110*89c4ff92SAndroid Build Coastguard Worker , m_OutputBindings(outputBindings)
111*89c4ff92SAndroid Build Coastguard Worker {}
112*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus AddInputLayer(DelegateData& delegateData,
113*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
114*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* inputs,
115*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings);
116*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
117*89c4ff92SAndroid Build Coastguard Worker TfLiteOpaqueContext* tfLiteContext,
118*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* outputs,
119*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings);
120*89c4ff92SAndroid Build Coastguard Worker /// The Network Id
121*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkId;
122*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Runtime
123*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime;
124*89c4ff92SAndroid Build Coastguard Worker /// Binding information for inputs and outputs
125*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_InputBindings;
126*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_OutputBindings;
127*89c4ff92SAndroid Build Coastguard Worker };
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker } // armnnOpaqueDelegate namespace