1*3e777be0SXin Li // 2*3e777be0SXin Li // Copyright © 2017 Arm Ltd. All rights reserved. 3*3e777be0SXin Li // SPDX-License-Identifier: MIT 4*3e777be0SXin Li // 5*3e777be0SXin Li 6*3e777be0SXin Li #pragma once 7*3e777be0SXin Li 8*3e777be0SXin Li #include "ArmnnDriver.hpp" 9*3e777be0SXin Li #include "ArmnnDriverImpl.hpp" 10*3e777be0SXin Li #include "RequestThread.hpp" 11*3e777be0SXin Li 12*3e777be0SXin Li #include <NeuralNetworks.h> 13*3e777be0SXin Li #include <armnn/ArmNN.hpp> 14*3e777be0SXin Li #include <armnn/Threadpool.hpp> 15*3e777be0SXin Li 16*3e777be0SXin Li #include <string> 17*3e777be0SXin Li #include <vector> 18*3e777be0SXin Li 19*3e777be0SXin Li namespace armnn_driver 20*3e777be0SXin Li { 21*3e777be0SXin Li using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>; 22*3e777be0SXin Li 23*3e777be0SXin Li struct ArmnnCallback_1_0 24*3e777be0SXin Li { 25*3e777be0SXin Li armnnExecuteCallback_1_0 callback; 26*3e777be0SXin Li }; 27*3e777be0SXin Li 28*3e777be0SXin Li struct ExecutionContext_1_0 {}; 29*3e777be0SXin Li 30*3e777be0SXin Li using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>; 31*3e777be0SXin Li 32*3e777be0SXin Li template <typename HalVersion> 33*3e777be0SXin Li class ArmnnPreparedModel : public V1_0::IPreparedModel 34*3e777be0SXin Li { 35*3e777be0SXin Li public: 36*3e777be0SXin Li using HalModel = typename HalVersion::Model; 37*3e777be0SXin Li 38*3e777be0SXin Li ArmnnPreparedModel(armnn::NetworkId networkId, 39*3e777be0SXin Li armnn::IRuntime* runtime, 40*3e777be0SXin Li const HalModel& model, 41*3e777be0SXin Li const std::string& requestInputsAndOutputsDumpDir, 42*3e777be0SXin Li const bool gpuProfilingEnabled, 43*3e777be0SXin Li const bool asyncModelExecutionEnabled = false, 44*3e777be0SXin Li const unsigned int numberOfThreads = 1, 45*3e777be0SXin Li const bool importEnabled = false, 46*3e777be0SXin Li const bool exportEnabled = false); 47*3e777be0SXin Li 48*3e777be0SXin Li virtual ~ArmnnPreparedModel(); 49*3e777be0SXin Li 50*3e777be0SXin Li virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request, 51*3e777be0SXin Li const ::android::sp<V1_0::IExecutionCallback>& callback) override; 52*3e777be0SXin Li 53*3e777be0SXin Li /// execute the graph prepared from the request 54*3e777be0SXin Li void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 55*3e777be0SXin Li armnn::InputTensors& inputTensors, 56*3e777be0SXin Li armnn::OutputTensors& outputTensors, 57*3e777be0SXin Li CallbackContext_1_0 callback); 58*3e777be0SXin Li 59*3e777be0SXin Li /// Executes this model with dummy inputs (e.g. all zeroes). 60*3e777be0SXin Li /// \return false on failure, otherwise true 61*3e777be0SXin Li bool ExecuteWithDummyInputs(); 62*3e777be0SXin Li 63*3e777be0SXin Li private: 64*3e777be0SXin Li 65*3e777be0SXin Li template<typename CallbackContext> 66*3e777be0SXin Li class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback 67*3e777be0SXin Li { 68*3e777be0SXin Li public: ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)69*3e777be0SXin Li ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model, 70*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 71*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors, 72*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors, 73*3e777be0SXin Li CallbackContext callbackContext) : 74*3e777be0SXin Li m_Model(model), 75*3e777be0SXin Li m_MemPools(pMemPools), 76*3e777be0SXin Li m_InputTensors(inputTensors), 77*3e777be0SXin Li m_OutputTensors(outputTensors), 78*3e777be0SXin Li m_CallbackContext(callbackContext) 79*3e777be0SXin Li {} 80*3e777be0SXin Li 81*3e777be0SXin Li void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override; 82*3e777be0SXin Li 83*3e777be0SXin Li ArmnnPreparedModel<HalVersion>* m_Model; 84*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; 85*3e777be0SXin Li std::shared_ptr<armnn::InputTensors> m_InputTensors; 86*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors> m_OutputTensors; 87*3e777be0SXin Li CallbackContext m_CallbackContext; 88*3e777be0SXin Li }; 89*3e777be0SXin Li 90*3e777be0SXin Li template <typename TensorBindingCollection> 91*3e777be0SXin Li void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); 92*3e777be0SXin Li 93*3e777be0SXin Li /// schedule the graph prepared from the request for execution 94*3e777be0SXin Li template<typename CallbackContext> 95*3e777be0SXin Li void ScheduleGraphForExecution( 96*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 97*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors, 98*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors, 99*3e777be0SXin Li CallbackContext m_CallbackContext); 100*3e777be0SXin Li 101*3e777be0SXin Li armnn::NetworkId m_NetworkId; 102*3e777be0SXin Li armnn::IRuntime* m_Runtime; 103*3e777be0SXin Li HalModel m_Model; 104*3e777be0SXin Li // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads 105*3e777be0SXin Li // It is specific to this class, so it is declared as static here 106*3e777be0SXin Li static RequestThread<ArmnnPreparedModel, 107*3e777be0SXin Li HalVersion, 108*3e777be0SXin Li CallbackContext_1_0> m_RequestThread; 109*3e777be0SXin Li uint32_t m_RequestCount; 110*3e777be0SXin Li const std::string& m_RequestInputsAndOutputsDumpDir; 111*3e777be0SXin Li const bool m_GpuProfilingEnabled; 112*3e777be0SXin Li // Static to allow sharing of threadpool between ArmnnPreparedModel instances 113*3e777be0SXin Li static std::unique_ptr<armnn::Threadpool> m_Threadpool; 114*3e777be0SXin Li std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle; 115*3e777be0SXin Li const bool m_AsyncModelExecutionEnabled; 116*3e777be0SXin Li const bool m_EnableImport; 117*3e777be0SXin Li const bool m_EnableExport; 118*3e777be0SXin Li }; 119*3e777be0SXin Li 120*3e777be0SXin Li } 121