1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ArmnnDriver.hpp" 9 #include "ArmnnDriverImpl.hpp" 10 #include "RequestThread.hpp" 11 12 #include <NeuralNetworks.h> 13 #include <armnn/ArmNN.hpp> 14 #include <armnn/Threadpool.hpp> 15 16 #include <string> 17 #include <vector> 18 19 namespace armnn_driver 20 { 21 using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>; 22 23 struct ArmnnCallback_1_0 24 { 25 armnnExecuteCallback_1_0 callback; 26 }; 27 28 struct ExecutionContext_1_0 {}; 29 30 using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>; 31 32 template <typename HalVersion> 33 class ArmnnPreparedModel : public V1_0::IPreparedModel 34 { 35 public: 36 using HalModel = typename HalVersion::Model; 37 38 ArmnnPreparedModel(armnn::NetworkId networkId, 39 armnn::IRuntime* runtime, 40 const HalModel& model, 41 const std::string& requestInputsAndOutputsDumpDir, 42 const bool gpuProfilingEnabled, 43 const bool asyncModelExecutionEnabled = false, 44 const unsigned int numberOfThreads = 1, 45 const bool importEnabled = false, 46 const bool exportEnabled = false); 47 48 virtual ~ArmnnPreparedModel(); 49 50 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request, 51 const ::android::sp<V1_0::IExecutionCallback>& callback) override; 52 53 /// execute the graph prepared from the request 54 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 55 armnn::InputTensors& inputTensors, 56 armnn::OutputTensors& outputTensors, 57 CallbackContext_1_0 callback); 58 59 /// Executes this model with dummy inputs (e.g. all zeroes). 60 /// \return false on failure, otherwise true 61 bool ExecuteWithDummyInputs(); 62 63 private: 64 65 template<typename CallbackContext> 66 class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback 67 { 68 public: ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)69 ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model, 70 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 71 std::shared_ptr<armnn::InputTensors>& inputTensors, 72 std::shared_ptr<armnn::OutputTensors>& outputTensors, 73 CallbackContext callbackContext) : 74 m_Model(model), 75 m_MemPools(pMemPools), 76 m_InputTensors(inputTensors), 77 m_OutputTensors(outputTensors), 78 m_CallbackContext(callbackContext) 79 {} 80 81 void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override; 82 83 ArmnnPreparedModel<HalVersion>* m_Model; 84 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools; 85 std::shared_ptr<armnn::InputTensors> m_InputTensors; 86 std::shared_ptr<armnn::OutputTensors> m_OutputTensors; 87 CallbackContext m_CallbackContext; 88 }; 89 90 template <typename TensorBindingCollection> 91 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); 92 93 /// schedule the graph prepared from the request for execution 94 template<typename CallbackContext> 95 void ScheduleGraphForExecution( 96 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 97 std::shared_ptr<armnn::InputTensors>& inputTensors, 98 std::shared_ptr<armnn::OutputTensors>& outputTensors, 99 CallbackContext m_CallbackContext); 100 101 armnn::NetworkId m_NetworkId; 102 armnn::IRuntime* m_Runtime; 103 HalModel m_Model; 104 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads 105 // It is specific to this class, so it is declared as static here 106 static RequestThread<ArmnnPreparedModel, 107 HalVersion, 108 CallbackContext_1_0> m_RequestThread; 109 uint32_t m_RequestCount; 110 const std::string& m_RequestInputsAndOutputsDumpDir; 111 const bool m_GpuProfilingEnabled; 112 // Static to allow sharing of threadpool between ArmnnPreparedModel instances 113 static std::unique_ptr<armnn::Threadpool> m_Threadpool; 114 std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle; 115 const bool m_AsyncModelExecutionEnabled; 116 const bool m_EnableImport; 117 const bool m_EnableExport; 118 }; 119 120 } 121