1*3e777be0SXin Li // 2*3e777be0SXin Li // Copyright © 2017 Arm Ltd. All rights reserved. 3*3e777be0SXin Li // SPDX-License-Identifier: MIT 4*3e777be0SXin Li // 5*3e777be0SXin Li 6*3e777be0SXin Li #pragma once 7*3e777be0SXin Li 8*3e777be0SXin Li #include "ArmnnDriver.hpp" 9*3e777be0SXin Li #include "ArmnnDriverImpl.hpp" 10*3e777be0SXin Li #include "RequestThread.hpp" 11*3e777be0SXin Li #include "ModelToINetworkConverter.hpp" 12*3e777be0SXin Li 13*3e777be0SXin Li #include <NeuralNetworks.h> 14*3e777be0SXin Li #include <armnn/ArmNN.hpp> 15*3e777be0SXin Li #include <armnn/Threadpool.hpp> 16*3e777be0SXin Li 17*3e777be0SXin Li #include <string> 18*3e777be0SXin Li #include <vector> 19*3e777be0SXin Li 20*3e777be0SXin Li namespace armnn_driver 21*3e777be0SXin Li { 22*3e777be0SXin Li 23*3e777be0SXin Li using CallbackAsync_1_2 = std::function< 24*3e777be0SXin Li void(V1_0::ErrorStatus errorStatus, 25*3e777be0SXin Li std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, 26*3e777be0SXin Li const ::android::hardware::neuralnetworks::V1_2::Timing& timing, 27*3e777be0SXin Li std::string callingFunction)>; 28*3e777be0SXin Li 29*3e777be0SXin Li struct ExecutionContext_1_2 30*3e777be0SXin Li { 31*3e777be0SXin Li ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings = 32*3e777be0SXin Li ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; 33*3e777be0SXin Li TimePoint driverStart; 34*3e777be0SXin Li }; 35*3e777be0SXin Li 36*3e777be0SXin Li using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>; 37*3e777be0SXin Li 38*3e777be0SXin Li template <typename HalVersion> 39*3e777be0SXin Li class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel 40*3e777be0SXin Li { 41*3e777be0SXin Li public: 42*3e777be0SXin Li using HalModel = typename V1_2::Model; 43*3e777be0SXin Li 44*3e777be0SXin Li ArmnnPreparedModel_1_2(armnn::NetworkId networkId, 45*3e777be0SXin Li armnn::IRuntime* runtime, 46*3e777be0SXin Li const HalModel& model, 47*3e777be0SXin Li const std::string& requestInputsAndOutputsDumpDir, 48*3e777be0SXin Li const bool gpuProfilingEnabled, 49*3e777be0SXin Li const bool asyncModelExecutionEnabled = false, 50*3e777be0SXin Li const unsigned int numberOfThreads = 1, 51*3e777be0SXin Li const bool importEnabled = false, 52*3e777be0SXin Li const bool exportEnabled = false); 53*3e777be0SXin Li 54*3e777be0SXin Li ArmnnPreparedModel_1_2(armnn::NetworkId networkId, 55*3e777be0SXin Li armnn::IRuntime* runtime, 56*3e777be0SXin Li const std::string& requestInputsAndOutputsDumpDir, 57*3e777be0SXin Li const bool gpuProfilingEnabled, 58*3e777be0SXin Li const bool asyncModelExecutionEnabled = false, 59*3e777be0SXin Li const unsigned int numberOfThreads = 1, 60*3e777be0SXin Li const bool importEnabled = false, 61*3e777be0SXin Li const bool exportEnabled = false, 62*3e777be0SXin Li const bool preparedFromCache = false); 63*3e777be0SXin Li 64*3e777be0SXin Li virtual ~ArmnnPreparedModel_1_2(); 65*3e777be0SXin Li 66*3e777be0SXin Li virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request, 67*3e777be0SXin Li const ::android::sp<V1_0::IExecutionCallback>& callback) override; 68*3e777be0SXin Li 69*3e777be0SXin Li virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure, 70*3e777be0SXin Li const ::android::sp<V1_2::IExecutionCallback>& callback) override; 71*3e777be0SXin Li 72*3e777be0SXin Li virtual Return<void> executeSynchronously(const V1_0::Request &request, 73*3e777be0SXin Li V1_2::MeasureTiming measure, 74*3e777be0SXin Li V1_2::IPreparedModel::executeSynchronously_cb cb) override; 75*3e777be0SXin Li 76*3e777be0SXin Li virtual Return<void> configureExecutionBurst( 77*3e777be0SXin Li const ::android::sp<V1_2::IBurstCallback>& callback, 78*3e777be0SXin Li const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, 79*3e777be0SXin Li const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, 80*3e777be0SXin Li configureExecutionBurst_cb cb) override; 81*3e777be0SXin Li 82*3e777be0SXin Li /// execute the graph prepared from the request 83*3e777be0SXin Li template<typename CallbackContext> 84*3e777be0SXin Li bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 85*3e777be0SXin Li armnn::InputTensors& inputTensors, 86*3e777be0SXin Li armnn::OutputTensors& outputTensors, 87*3e777be0SXin Li CallbackContext callback); 88*3e777be0SXin Li 89*3e777be0SXin Li /// Executes this model with dummy inputs (e.g. all zeroes). 90*3e777be0SXin Li /// \return false on failure, otherwise true 91*3e777be0SXin Li bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs); 92*3e777be0SXin Li 93*3e777be0SXin Li private: 94*3e777be0SXin Li 95*3e777be0SXin Li template<typename CallbackContext> 96*3e777be0SXin Li class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback 97*3e777be0SXin Li { 98*3e777be0SXin Li public: ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::vector<V1_2::OutputShape> outputShapes,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)99*3e777be0SXin Li ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model, 100*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 101*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes, 102*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors, 103*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors, 104*3e777be0SXin Li CallbackContext callbackContext) : 105*3e777be0SXin Li m_Model(model), 106*3e777be0SXin Li m_MemPools(pMemPools), 107*3e777be0SXin Li m_OutputShapes(outputShapes), 108*3e777be0SXin Li m_InputTensors(inputTensors), 109*3e777be0SXin Li m_OutputTensors(outputTensors), 110*3e777be0SXin Li m_CallbackContext(callbackContext) 111*3e777be0SXin Li {} 112*3e777be0SXin Li 113*3e777be0SXin Li void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override; 114*3e777be0SXin Li 115*3e777be0SXin Li ArmnnPreparedModel_1_2<HalVersion>* m_Model; 116*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; 117*3e777be0SXin Li std::vector<V1_2::OutputShape> m_OutputShapes; 118*3e777be0SXin Li std::shared_ptr<armnn::InputTensors> m_InputTensors; 119*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors> m_OutputTensors; 120*3e777be0SXin Li CallbackContext m_CallbackContext; 121*3e777be0SXin Li }; 122*3e777be0SXin Li 123*3e777be0SXin Li Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, 124*3e777be0SXin Li V1_2::MeasureTiming measureTiming, 125*3e777be0SXin Li CallbackAsync_1_2 callback); 126*3e777be0SXin Li 127*3e777be0SXin Li Return<V1_0::ErrorStatus> PrepareMemoryForInputs( 128*3e777be0SXin Li armnn::InputTensors& inputs, 129*3e777be0SXin Li const V1_0::Request& request, 130*3e777be0SXin Li const std::vector<android::nn::RunTimePoolInfo>& memPools); 131*3e777be0SXin Li 132*3e777be0SXin Li Return<V1_0::ErrorStatus> PrepareMemoryForOutputs( 133*3e777be0SXin Li armnn::OutputTensors& outputs, 134*3e777be0SXin Li std::vector<V1_2::OutputShape> &outputShapes, 135*3e777be0SXin Li const V1_0::Request& request, 136*3e777be0SXin Li const std::vector<android::nn::RunTimePoolInfo>& memPools); 137*3e777be0SXin Li 138*3e777be0SXin Li Return <V1_0::ErrorStatus> PrepareMemoryForIO( 139*3e777be0SXin Li armnn::InputTensors& inputs, 140*3e777be0SXin Li armnn::OutputTensors& outputs, 141*3e777be0SXin Li std::vector<android::nn::RunTimePoolInfo>& memPools, 142*3e777be0SXin Li const V1_0::Request& request, 143*3e777be0SXin Li CallbackAsync_1_2 callback); 144*3e777be0SXin Li 145*3e777be0SXin Li template <typename TensorBindingCollection> 146*3e777be0SXin Li void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); 147*3e777be0SXin Li 148*3e777be0SXin Li /// schedule the graph prepared from the request for execution 149*3e777be0SXin Li template<typename CallbackContext> 150*3e777be0SXin Li void ScheduleGraphForExecution( 151*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 152*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors, 153*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors, 154*3e777be0SXin Li CallbackContext m_CallbackContext); 155*3e777be0SXin Li 156*3e777be0SXin Li armnn::NetworkId m_NetworkId; 157*3e777be0SXin Li armnn::IRuntime* m_Runtime; 158*3e777be0SXin Li V1_2::Model m_Model; 159*3e777be0SXin Li // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads 160*3e777be0SXin Li // It is specific to this class, so it is declared as static here 161*3e777be0SXin Li static RequestThread<ArmnnPreparedModel_1_2, 162*3e777be0SXin Li HalVersion, 163*3e777be0SXin Li CallbackContext_1_2> m_RequestThread; 164*3e777be0SXin Li uint32_t m_RequestCount; 165*3e777be0SXin Li const std::string& m_RequestInputsAndOutputsDumpDir; 166*3e777be0SXin Li const bool m_GpuProfilingEnabled; 167*3e777be0SXin Li // Static to allow sharing of threadpool between ArmnnPreparedModel instances 168*3e777be0SXin Li static std::unique_ptr<armnn::Threadpool> m_Threadpool; 169*3e777be0SXin Li std::shared_ptr<IWorkingMemHandle> m_WorkingMemHandle; 170*3e777be0SXin Li const bool m_AsyncModelExecutionEnabled; 171*3e777be0SXin Li const bool m_EnableImport; 172*3e777be0SXin Li const bool m_EnableExport; 173*3e777be0SXin Li const bool m_PreparedFromCache; 174*3e777be0SXin Li }; 175*3e777be0SXin Li 176*3e777be0SXin Li } 177