1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <DelegateOptions.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/builtin_ops.h> 11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/builtin_op_data.h> 12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/c/common.h> 13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/minimal_logging.h> 14*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/version.h> 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 3) 17*89c4ff92SAndroid Build Coastguard Worker #define ARMNN_POST_TFLITE_2_3 18*89c4ff92SAndroid Build Coastguard Worker #endif 19*89c4ff92SAndroid Build Coastguard Worker 20*89c4ff92SAndroid Build Coastguard Worker #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 4) 21*89c4ff92SAndroid Build Coastguard Worker #define ARMNN_POST_TFLITE_2_4 22*89c4ff92SAndroid Build Coastguard Worker #endif 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker #if TF_MAJOR_VERSION > 2 || (TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION > 5) 25*89c4ff92SAndroid Build Coastguard Worker #define ARMNN_POST_TFLITE_2_5 26*89c4ff92SAndroid Build Coastguard Worker #endif 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker struct DelegateData 32*89c4ff92SAndroid Build Coastguard Worker { DelegateDataarmnnDelegate::DelegateData33*89c4ff92SAndroid Build Coastguard Worker DelegateData(const std::vector<armnn::BackendId>& backends) 34*89c4ff92SAndroid Build Coastguard Worker : m_Backends(backends) 35*89c4ff92SAndroid Build Coastguard Worker , m_Network(nullptr, nullptr) 36*89c4ff92SAndroid Build Coastguard Worker {} 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId> m_Backends; 39*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_Network; 40*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::IOutputSlot*> m_OutputSlotForNode; 41*89c4ff92SAndroid Build Coastguard Worker }; 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker // Forward decleration for functions initializing the ArmNN Delegate 44*89c4ff92SAndroid Build Coastguard Worker DelegateOptions TfLiteArmnnDelegateOptionsDefault(); 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Delegate 53*89c4ff92SAndroid Build Coastguard Worker class Delegate 54*89c4ff92SAndroid Build Coastguard Worker { 55*89c4ff92SAndroid Build Coastguard Worker friend class ArmnnSubgraph; 56*89c4ff92SAndroid Build Coastguard Worker public: 57*89c4ff92SAndroid Build Coastguard Worker explicit Delegate(armnnDelegate::DelegateOptions options); 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate* GetDelegate(); 62*89c4ff92SAndroid Build Coastguard Worker 63*89c4ff92SAndroid Build Coastguard Worker /// Retrieve version in X.Y.Z form 64*89c4ff92SAndroid Build Coastguard Worker static const std::string GetVersion(); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker private: 67*89c4ff92SAndroid Build Coastguard Worker /** 68*89c4ff92SAndroid Build Coastguard Worker * Returns a pointer to the armnn::IRuntime* this will be shared by all armnn_delegates. 69*89c4ff92SAndroid Build Coastguard Worker */ GetRuntime(const armnn::IRuntime::CreationOptions & options)70*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options) 71*89c4ff92SAndroid Build Coastguard Worker { 72*89c4ff92SAndroid Build Coastguard Worker static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options); 73*89c4ff92SAndroid Build Coastguard Worker // Instantiated on first use. 74*89c4ff92SAndroid Build Coastguard Worker return instance.get(); 75*89c4ff92SAndroid Build Coastguard Worker } 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker TfLiteDelegate m_Delegate = { 78*89c4ff92SAndroid Build Coastguard Worker reinterpret_cast<void*>(this), // .data_ 79*89c4ff92SAndroid Build Coastguard Worker DoPrepare, // .Prepare 80*89c4ff92SAndroid Build Coastguard Worker nullptr, // .CopyFromBufferHandle 81*89c4ff92SAndroid Build Coastguard Worker nullptr, // .CopyToBufferHandle 82*89c4ff92SAndroid Build Coastguard Worker nullptr, // .FreeBufferHandle 83*89c4ff92SAndroid Build Coastguard Worker kTfLiteDelegateFlagsNone, // .flags 84*89c4ff92SAndroid Build Coastguard Worker nullptr, // .opaque_delegate_builder 85*89c4ff92SAndroid Build Coastguard Worker }; 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Runtime pointer 88*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime; 89*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Delegate Options 90*89c4ff92SAndroid Build Coastguard Worker armnnDelegate::DelegateOptions m_Options; 91*89c4ff92SAndroid Build Coastguard Worker }; 92*89c4ff92SAndroid Build Coastguard Worker 93*89c4ff92SAndroid Build Coastguard Worker /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph 94*89c4ff92SAndroid Build Coastguard Worker class ArmnnSubgraph 95*89c4ff92SAndroid Build Coastguard Worker { 96*89c4ff92SAndroid Build Coastguard Worker public: 97*89c4ff92SAndroid Build Coastguard Worker static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext, 98*89c4ff92SAndroid Build Coastguard Worker const TfLiteDelegateParams* parameters, 99*89c4ff92SAndroid Build Coastguard Worker const Delegate* delegate); 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* tfLiteContext); 102*89c4ff92SAndroid Build Coastguard Worker 103*89c4ff92SAndroid Build Coastguard Worker TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus VisitNode(DelegateData& delegateData, 106*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext, 107*89c4ff92SAndroid Build Coastguard Worker TfLiteRegistration* tfLiteRegistration, 108*89c4ff92SAndroid Build Coastguard Worker TfLiteNode* tfLiteNode, 109*89c4ff92SAndroid Build Coastguard Worker int nodeIndex); 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker private: ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)112*89c4ff92SAndroid Build Coastguard Worker ArmnnSubgraph(armnn::NetworkId networkId, 113*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime, 114*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings, 115*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings) 116*89c4ff92SAndroid Build Coastguard Worker : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) 117*89c4ff92SAndroid Build Coastguard Worker {} 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus AddInputLayer(DelegateData& delegateData, 120*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext, 121*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* inputs, 122*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& inputBindings); 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker static TfLiteStatus AddOutputLayer(DelegateData& delegateData, 125*89c4ff92SAndroid Build Coastguard Worker TfLiteContext* tfLiteContext, 126*89c4ff92SAndroid Build Coastguard Worker const TfLiteIntArray* outputs, 127*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo>& outputBindings); 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker /// The Network Id 131*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkId; 132*89c4ff92SAndroid Build Coastguard Worker /// ArmNN Runtime 133*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime; 134*89c4ff92SAndroid Build Coastguard Worker 135*89c4ff92SAndroid Build Coastguard Worker // Binding information for inputs and outputs 136*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_InputBindings; 137*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BindingPointInfo> m_OutputBindings; 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker }; 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker } // armnnDelegate namespace