1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriver.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp" 10*89c4ff92SAndroid Build Coastguard Worker #include "ModelToINetworkTransformer.hpp" 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <BufferTracker.h> 15*89c4ff92SAndroid Build Coastguard Worker #include <CpuExecutor.h> 16*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IExecution.h> 17*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IPreparedModel.h> 18*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h> 19*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h> 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker #include <memory> 22*89c4ff92SAndroid Build Coastguard Worker #include <tuple> 23*89c4ff92SAndroid Build Coastguard Worker #include <utility> 24*89c4ff92SAndroid Build Coastguard Worker #include <vector> 25*89c4ff92SAndroid Build Coastguard Worker #include <string> 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver 28*89c4ff92SAndroid Build Coastguard Worker { 29*89c4ff92SAndroid Build Coastguard Worker struct CanonicalExecutionContext 30*89c4ff92SAndroid Build Coastguard Worker { 31*89c4ff92SAndroid Build Coastguard Worker ::android::nn::MeasureTiming measureTimings = 32*89c4ff92SAndroid Build Coastguard Worker ::android::nn::MeasureTiming::NO; 33*89c4ff92SAndroid Build Coastguard Worker android::nn::TimePoint driverStart; 34*89c4ff92SAndroid Build Coastguard Worker android::nn::TimePoint driverEnd; 35*89c4ff92SAndroid Build Coastguard Worker android::nn::TimePoint deviceStart; 36*89c4ff92SAndroid Build Coastguard Worker android::nn::TimePoint deviceEnd; 37*89c4ff92SAndroid Build Coastguard Worker }; 38*89c4ff92SAndroid Build Coastguard Worker class ArmnnPreparedModel final : public IPreparedModel, 39*89c4ff92SAndroid Build Coastguard Worker public std::enable_shared_from_this<ArmnnPreparedModel> 40*89c4ff92SAndroid Build Coastguard Worker { 41*89c4ff92SAndroid Build Coastguard Worker public: 42*89c4ff92SAndroid Build Coastguard Worker ArmnnPreparedModel(armnn::NetworkId networkId, 43*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime, 44*89c4ff92SAndroid Build Coastguard Worker const Model& model, 45*89c4ff92SAndroid Build Coastguard Worker const std::string& requestInputsAndOutputsDumpDir, 46*89c4ff92SAndroid Build Coastguard Worker const bool gpuProfilingEnabled, 47*89c4ff92SAndroid Build Coastguard Worker Priority priority = Priority::MEDIUM); 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker ArmnnPreparedModel(armnn::NetworkId networkId, 50*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime, 51*89c4ff92SAndroid Build Coastguard Worker const std::string& requestInputsAndOutputsDumpDir, 52*89c4ff92SAndroid Build Coastguard Worker const bool gpuProfilingEnabled, 53*89c4ff92SAndroid Build Coastguard Worker Priority priority = Priority::MEDIUM, 54*89c4ff92SAndroid Build Coastguard Worker const bool prepareModelFromCache = false); 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker virtual ~ArmnnPreparedModel(); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> execute( 59*89c4ff92SAndroid Build Coastguard Worker const Request& request, 60*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming, 61*89c4ff92SAndroid Build Coastguard Worker const OptionalTimePoint& deadline, 62*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration& loopTimeoutDuration, 63*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints, 64*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override; 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> executeFenced( 67*89c4ff92SAndroid Build Coastguard Worker const Request& request, 68*89c4ff92SAndroid Build Coastguard Worker const std::vector<SyncFence>& waitFor, 69*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming, 70*89c4ff92SAndroid Build Coastguard Worker const OptionalTimePoint& deadline, 71*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration& loopTimeoutDuration, 72*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration& timeoutDurationAfterFence, 73*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints, 74*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override; 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker GeneralResult<android::nn::SharedExecution> createReusableExecution( 77*89c4ff92SAndroid Build Coastguard Worker const Request& request, 78*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming, 79*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration& loopTimeoutDuration, 80*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints, 81*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override; 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedBurst> configureExecutionBurst() const override; 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker std::any getUnderlyingResource() const override; 86*89c4ff92SAndroid Build Coastguard Worker 87*89c4ff92SAndroid Build Coastguard Worker /// execute the graph prepared from the request 88*89c4ff92SAndroid Build Coastguard Worker ErrorStatus ExecuteGraph( 89*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools, 90*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors& inputTensors, 91*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputTensors, 92*89c4ff92SAndroid Build Coastguard Worker CanonicalExecutionContext callback, 93*89c4ff92SAndroid Build Coastguard Worker const bool pointerMemory = false) const; 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker Priority GetModelPriority() const; 96*89c4ff92SAndroid Build Coastguard Worker 97*89c4ff92SAndroid Build Coastguard Worker /// Executes this model with dummy inputs (e.g. all zeroes). 98*89c4ff92SAndroid Build Coastguard Worker /// \return false on failure, otherwise true 99*89c4ff92SAndroid Build Coastguard Worker bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const; 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker private: 102*89c4ff92SAndroid Build Coastguard Worker void Init(); 103*89c4ff92SAndroid Build Coastguard Worker ErrorStatus PrepareMemoryForInputs( 104*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors& inputs, 105*89c4ff92SAndroid Build Coastguard Worker const Request& request, 106*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::RunTimePoolInfo>& memPools) const; 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker ErrorStatus PrepareMemoryForOutputs( 109*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputs, 110*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputShape> &outputShapes, 111*89c4ff92SAndroid Build Coastguard Worker const Request& request, 112*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::RunTimePoolInfo>& memPools) const; 113*89c4ff92SAndroid Build Coastguard Worker 114*89c4ff92SAndroid Build Coastguard Worker ErrorStatus PrepareMemoryForIO(armnn::InputTensors& inputs, 115*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputs, 116*89c4ff92SAndroid Build Coastguard Worker std::vector<android::nn::RunTimePoolInfo>& memPools, 117*89c4ff92SAndroid Build Coastguard Worker const Request& request, 118*89c4ff92SAndroid Build Coastguard Worker const bool pointerMemory = false) const; 119*89c4ff92SAndroid Build Coastguard Worker 120*89c4ff92SAndroid Build Coastguard Worker template <typename TensorBindingCollection> 121*89c4ff92SAndroid Build Coastguard Worker void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) const; 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Worker /// schedule the graph prepared from the request for execution 124*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkId; 125*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* m_Runtime; 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker const Model m_Model; 128*89c4ff92SAndroid Build Coastguard Worker const std::string& m_RequestInputsAndOutputsDumpDir; 129*89c4ff92SAndroid Build Coastguard Worker const bool m_GpuProfilingEnabled; 130*89c4ff92SAndroid Build Coastguard Worker Priority m_ModelPriority; 131*89c4ff92SAndroid Build Coastguard Worker const bool m_PrepareFromCache; 132*89c4ff92SAndroid Build Coastguard Worker }; 133*89c4ff92SAndroid Build Coastguard Worker 134*89c4ff92SAndroid Build Coastguard Worker } 135