//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <armnn/ArmNN.hpp>

#include <BufferTracker.h>
#include <CpuExecutor.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>

#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace armnn_driver
{

/// Timing context recorded around a single execution.
struct CanonicalExecutionContext
{
    ::android::nn::MeasureTiming measureTimings =
        ::android::nn::MeasureTiming::NO;
    android::nn::TimePoint driverStart;
    android::nn::TimePoint driverEnd;
    android::nn::TimePoint deviceStart;
    android::nn::TimePoint deviceEnd;
};

class ArmnnPreparedModel final : public IPreparedModel,
                                 public std::enable_shared_from_this<ArmnnPreparedModel>
{
public:
    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const Model& model,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM);

    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM,
                       const bool prepareModelFromCache = false);

    virtual ~ArmnnPreparedModel();

    /// Performs a synchronous execution of the prepared network for the given request.
    ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> execute(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    /// Performs an execution that starts only after the given sync fences have signalled.
    GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> executeFenced(
        const Request& request,
        const std::vector<SyncFence>& waitFor,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const OptionalDuration& timeoutDurationAfterFence,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    /// Creates an execution object that can be reused for repeated runs of the same request.
    GeneralResult<android::nn::SharedExecution> createReusableExecution(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    GeneralResult<SharedBurst> configureExecutionBurst() const override;

    std::any getUnderlyingResource() const override;

    /// Executes the graph prepared from the request.
    ErrorStatus ExecuteGraph(
        std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& inputTensors,
        armnn::OutputTensors& outputTensors,
        CanonicalExecutionContext callback,
        const bool pointerMemory = false) const;

    Priority GetModelPriority() const;

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const;

private:
    void Init();

    /// Binds the request's input memory pools to Arm NN input tensors.
    ErrorStatus PrepareMemoryForInputs(
        armnn::InputTensors& inputs,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

    /// Binds the request's output memory pools to Arm NN output tensors and
    /// records the resulting output shapes.
    ErrorStatus PrepareMemoryForOutputs(
        armnn::OutputTensors& outputs,
        std::vector<OutputShape>& outputShapes,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

    /// Prepares memory for both the inputs and the outputs of a request.
    ErrorStatus PrepareMemoryForIO(armnn::InputTensors& inputs,
                                   armnn::OutputTensors& outputs,
                                   std::vector<android::nn::RunTimePoolInfo>& memPools,
                                   const Request& request,
                                   const bool pointerMemory = false) const;

    /// Dumps tensor data to m_RequestInputsAndOutputsDumpDir, if one is configured.
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) const;

    armnn::NetworkId m_NetworkId;
    armnn::IRuntime* m_Runtime;

    const Model m_Model;
    const std::string& m_RequestInputsAndOutputsDumpDir;
    const bool m_GpuProfilingEnabled;
    Priority m_ModelPriority;
    const bool m_PrepareFromCache;
};

} // namespace armnn_driver
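
// Illustrative usage sketch (an assumption, not part of the driver sources):
// how a caller holding a prepared model might run one inference through the
// canonical IPreparedModel interface. `preparedModel` and `request` are
// hypothetical names, assumed to have been created elsewhere (e.g. by the
// driver's prepareModel path, and by binding request inputs/outputs to
// memory pools).
//
//     std::shared_ptr<const ArmnnPreparedModel> preparedModel = /* ... */;
//     Request request = /* inputs and outputs bound to memory pools */;
//
//     auto result = preparedModel->execute(request,
//                                          MeasureTiming::YES,
//                                          {},   // no deadline
//                                          {},   // no loop timeout duration
//                                          {},   // no execution hints
//                                          {});  // no extension name/prefix pairs
//     if (result.has_value())
//     {
//         const auto& [outputShapes, timing] = result.value();
//         // outputShapes holds the final dimensions of any dynamically
//         // shaped outputs; timing is populated because measurement was
//         // requested above.
//     }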