xref: /aosp_15_r20/external/armnn/delegate/classic/include/armnn_delegate.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <DelegateOptions.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h>
15*89c4ff92SAndroid Build Coastguard Worker 
// TfLite version gates. Each ARMNN_POST_TFLITE_2_<N> macro is defined when the
// TensorFlow Lite version macros (TF_MAJOR_VERSION / TF_MINOR_VERSION, expected
// from <tensorflow/lite/version.h>) describe a release strictly newer than 2.<N>.
// Every newer gate implies the older ones, so the checks are nested; each one
// still spells out its complete condition so it can be read on its own.
// If the version macros are missing, the preprocessor treats them as 0 and
// none of the gates is defined.
#if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3)
#  define ARMNN_POST_TFLITE_2_3
#  if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 4)
#    define ARMNN_POST_TFLITE_2_4
#    if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 5)
#      define ARMNN_POST_TFLITE_2_5
#    endif
#  endif
#endif
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker struct DelegateData
32*89c4ff92SAndroid Build Coastguard Worker {
DelegateDataarmnnDelegate::DelegateData33*89c4ff92SAndroid Build Coastguard Worker     DelegateData(const std::vector<armnn::BackendId>& backends)
34*89c4ff92SAndroid Build Coastguard Worker             : m_Backends(backends)
35*89c4ff92SAndroid Build Coastguard Worker             , m_Network(nullptr, nullptr)
36*89c4ff92SAndroid Build Coastguard Worker     {}
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::BackendId>       m_Backends;
39*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr                        m_Network;
40*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::IOutputSlot*>          m_OutputSlotForNode;
41*89c4ff92SAndroid Build Coastguard Worker };
42*89c4ff92SAndroid Build Coastguard Worker 
// Forward declarations of the functions used to set up (and tear down) the ArmNN Delegate
44*89c4ff92SAndroid Build Coastguard Worker DelegateOptions TfLiteArmnnDelegateOptionsDefault();
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Delegate
53*89c4ff92SAndroid Build Coastguard Worker class Delegate
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker     friend class ArmnnSubgraph;
56*89c4ff92SAndroid Build Coastguard Worker public:
57*89c4ff92SAndroid Build Coastguard Worker     explicit Delegate(armnnDelegate::DelegateOptions options);
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     TfLiteDelegate* GetDelegate();
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     /// Retrieve version in X.Y.Z form
64*89c4ff92SAndroid Build Coastguard Worker     static const std::string GetVersion();
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker private:
67*89c4ff92SAndroid Build Coastguard Worker     /**
68*89c4ff92SAndroid Build Coastguard Worker      * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates.
69*89c4ff92SAndroid Build Coastguard Worker      */
GetRuntime(const armnn::IRuntime::CreationOptions & options)70*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
71*89c4ff92SAndroid Build Coastguard Worker     {
72*89c4ff92SAndroid Build Coastguard Worker         static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options);
73*89c4ff92SAndroid Build Coastguard Worker         // Instantiated on first use.
74*89c4ff92SAndroid Build Coastguard Worker         return instance.get();
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     TfLiteDelegate m_Delegate = {
78*89c4ff92SAndroid Build Coastguard Worker             reinterpret_cast<void*>(this),  // .data_
79*89c4ff92SAndroid Build Coastguard Worker             DoPrepare,                      // .Prepare
80*89c4ff92SAndroid Build Coastguard Worker             nullptr,                        // .CopyFromBufferHandle
81*89c4ff92SAndroid Build Coastguard Worker             nullptr,                        // .CopyToBufferHandle
82*89c4ff92SAndroid Build Coastguard Worker             nullptr,                        // .FreeBufferHandle
83*89c4ff92SAndroid Build Coastguard Worker             kTfLiteDelegateFlagsNone,       // .flags
84*89c4ff92SAndroid Build Coastguard Worker             nullptr,                        // .opaque_delegate_builder
85*89c4ff92SAndroid Build Coastguard Worker     };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     /// ArmNN Runtime pointer
88*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime* m_Runtime;
89*89c4ff92SAndroid Build Coastguard Worker     /// ArmNN Delegate Options
90*89c4ff92SAndroid Build Coastguard Worker     armnnDelegate::DelegateOptions m_Options;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
94*89c4ff92SAndroid Build Coastguard Worker class ArmnnSubgraph
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker public:
97*89c4ff92SAndroid Build Coastguard Worker     static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
98*89c4ff92SAndroid Build Coastguard Worker                                  const TfLiteDelegateParams* parameters,
99*89c4ff92SAndroid Build Coastguard Worker                                  const Delegate* delegate);
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker     static TfLiteStatus VisitNode(DelegateData& delegateData,
106*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteContext* tfLiteContext,
107*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteRegistration* tfLiteRegistration,
108*89c4ff92SAndroid Build Coastguard Worker                                   TfLiteNode* tfLiteNode,
109*89c4ff92SAndroid Build Coastguard Worker                                   int nodeIndex);
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker private:
ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)112*89c4ff92SAndroid Build Coastguard Worker     ArmnnSubgraph(armnn::NetworkId networkId,
113*89c4ff92SAndroid Build Coastguard Worker                   armnn::IRuntime* runtime,
114*89c4ff92SAndroid Build Coastguard Worker                   std::vector<armnn::BindingPointInfo>& inputBindings,
115*89c4ff92SAndroid Build Coastguard Worker                   std::vector<armnn::BindingPointInfo>& outputBindings)
116*89c4ff92SAndroid Build Coastguard Worker         : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
117*89c4ff92SAndroid Build Coastguard Worker     {}
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     static TfLiteStatus AddInputLayer(DelegateData& delegateData,
120*89c4ff92SAndroid Build Coastguard Worker                                       TfLiteContext* tfLiteContext,
121*89c4ff92SAndroid Build Coastguard Worker                                       const TfLiteIntArray* inputs,
122*89c4ff92SAndroid Build Coastguard Worker                                       std::vector<armnn::BindingPointInfo>& inputBindings);
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
125*89c4ff92SAndroid Build Coastguard Worker                                        TfLiteContext* tfLiteContext,
126*89c4ff92SAndroid Build Coastguard Worker                                        const TfLiteIntArray* outputs,
127*89c4ff92SAndroid Build Coastguard Worker                                        std::vector<armnn::BindingPointInfo>& outputBindings);
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     /// The Network Id
131*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId m_NetworkId;
132*89c4ff92SAndroid Build Coastguard Worker     /// ArmNN Runtime
133*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime* m_Runtime;
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     // Binding information for inputs and outputs
136*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> m_InputBindings;
137*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BindingPointInfo> m_OutputBindings;
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker };
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker } // armnnDelegate namespace