xref: /aosp_15_r20/external/armnn/shim/sl/canonical/ArmnnPreparedModel.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriver.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "ModelToINetworkTransformer.hpp"
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <BufferTracker.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <CpuExecutor.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IExecution.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IPreparedModel.h>
18*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h>
19*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker #include <memory>
22*89c4ff92SAndroid Build Coastguard Worker #include <tuple>
23*89c4ff92SAndroid Build Coastguard Worker #include <utility>
24*89c4ff92SAndroid Build Coastguard Worker #include <vector>
25*89c4ff92SAndroid Build Coastguard Worker #include <string>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     struct CanonicalExecutionContext
30*89c4ff92SAndroid Build Coastguard Worker     {
31*89c4ff92SAndroid Build Coastguard Worker         ::android::nn::MeasureTiming    measureTimings =
32*89c4ff92SAndroid Build Coastguard Worker                 ::android::nn::MeasureTiming::NO;
33*89c4ff92SAndroid Build Coastguard Worker         android::nn::TimePoint driverStart;
34*89c4ff92SAndroid Build Coastguard Worker         android::nn::TimePoint driverEnd;
35*89c4ff92SAndroid Build Coastguard Worker         android::nn::TimePoint deviceStart;
36*89c4ff92SAndroid Build Coastguard Worker         android::nn::TimePoint deviceEnd;
37*89c4ff92SAndroid Build Coastguard Worker     };
38*89c4ff92SAndroid Build Coastguard Worker class ArmnnPreparedModel final : public IPreparedModel,
39*89c4ff92SAndroid Build Coastguard Worker                                  public std::enable_shared_from_this<ArmnnPreparedModel>
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker public:
42*89c4ff92SAndroid Build Coastguard Worker     ArmnnPreparedModel(armnn::NetworkId networkId,
43*89c4ff92SAndroid Build Coastguard Worker                        armnn::IRuntime* runtime,
44*89c4ff92SAndroid Build Coastguard Worker                        const Model& model,
45*89c4ff92SAndroid Build Coastguard Worker                        const std::string& requestInputsAndOutputsDumpDir,
46*89c4ff92SAndroid Build Coastguard Worker                        const bool gpuProfilingEnabled,
47*89c4ff92SAndroid Build Coastguard Worker                        Priority priority = Priority::MEDIUM);
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     ArmnnPreparedModel(armnn::NetworkId networkId,
50*89c4ff92SAndroid Build Coastguard Worker                        armnn::IRuntime* runtime,
51*89c4ff92SAndroid Build Coastguard Worker                        const std::string& requestInputsAndOutputsDumpDir,
52*89c4ff92SAndroid Build Coastguard Worker                        const bool gpuProfilingEnabled,
53*89c4ff92SAndroid Build Coastguard Worker                        Priority priority = Priority::MEDIUM,
54*89c4ff92SAndroid Build Coastguard Worker                        const bool prepareModelFromCache = false);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     virtual ~ArmnnPreparedModel();
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> execute(
59*89c4ff92SAndroid Build Coastguard Worker         const Request& request,
60*89c4ff92SAndroid Build Coastguard Worker         MeasureTiming measureTiming,
61*89c4ff92SAndroid Build Coastguard Worker         const OptionalTimePoint& deadline,
62*89c4ff92SAndroid Build Coastguard Worker         const OptionalDuration& loopTimeoutDuration,
63*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::TokenValuePair>& hints,
64*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> executeFenced(
67*89c4ff92SAndroid Build Coastguard Worker         const Request& request,
68*89c4ff92SAndroid Build Coastguard Worker         const std::vector<SyncFence>& waitFor,
69*89c4ff92SAndroid Build Coastguard Worker         MeasureTiming measureTiming,
70*89c4ff92SAndroid Build Coastguard Worker         const OptionalTimePoint& deadline,
71*89c4ff92SAndroid Build Coastguard Worker         const OptionalDuration& loopTimeoutDuration,
72*89c4ff92SAndroid Build Coastguard Worker         const OptionalDuration& timeoutDurationAfterFence,
73*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::TokenValuePair>& hints,
74*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     GeneralResult<android::nn::SharedExecution> createReusableExecution(
77*89c4ff92SAndroid Build Coastguard Worker         const Request& request,
78*89c4ff92SAndroid Build Coastguard Worker         MeasureTiming measureTiming,
79*89c4ff92SAndroid Build Coastguard Worker         const OptionalDuration& loopTimeoutDuration,
80*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::TokenValuePair>& hints,
81*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     GeneralResult<SharedBurst> configureExecutionBurst() const override;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     std::any getUnderlyingResource() const override;
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     /// execute the graph prepared from the request
88*89c4ff92SAndroid Build Coastguard Worker     ErrorStatus ExecuteGraph(
89*89c4ff92SAndroid Build Coastguard Worker         std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools,
90*89c4ff92SAndroid Build Coastguard Worker         armnn::InputTensors& inputTensors,
91*89c4ff92SAndroid Build Coastguard Worker         armnn::OutputTensors& outputTensors,
92*89c4ff92SAndroid Build Coastguard Worker         CanonicalExecutionContext  callback,
93*89c4ff92SAndroid Build Coastguard Worker         const bool pointerMemory = false) const;
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     Priority GetModelPriority() const;
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     /// Executes this model with dummy inputs (e.g. all zeroes).
98*89c4ff92SAndroid Build Coastguard Worker     /// \return false on failure, otherwise true
99*89c4ff92SAndroid Build Coastguard Worker     bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const;
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker private:
102*89c4ff92SAndroid Build Coastguard Worker     void Init();
103*89c4ff92SAndroid Build Coastguard Worker     ErrorStatus PrepareMemoryForInputs(
104*89c4ff92SAndroid Build Coastguard Worker         armnn::InputTensors& inputs,
105*89c4ff92SAndroid Build Coastguard Worker         const Request& request,
106*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::RunTimePoolInfo>& memPools) const;
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     ErrorStatus PrepareMemoryForOutputs(
109*89c4ff92SAndroid Build Coastguard Worker         armnn::OutputTensors& outputs,
110*89c4ff92SAndroid Build Coastguard Worker         std::vector<OutputShape> &outputShapes,
111*89c4ff92SAndroid Build Coastguard Worker         const Request& request,
112*89c4ff92SAndroid Build Coastguard Worker         const std::vector<android::nn::RunTimePoolInfo>& memPools) const;
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     ErrorStatus PrepareMemoryForIO(armnn::InputTensors& inputs,
115*89c4ff92SAndroid Build Coastguard Worker                                    armnn::OutputTensors& outputs,
116*89c4ff92SAndroid Build Coastguard Worker                                    std::vector<android::nn::RunTimePoolInfo>& memPools,
117*89c4ff92SAndroid Build Coastguard Worker                                    const Request& request,
118*89c4ff92SAndroid Build Coastguard Worker                                    const bool pointerMemory = false) const;
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker     template <typename TensorBindingCollection>
121*89c4ff92SAndroid Build Coastguard Worker     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) const;
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     /// schedule the graph prepared from the request for execution
124*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId                        m_NetworkId;
125*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime*                        m_Runtime;
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     const Model                             m_Model;
128*89c4ff92SAndroid Build Coastguard Worker     const std::string&                      m_RequestInputsAndOutputsDumpDir;
129*89c4ff92SAndroid Build Coastguard Worker     const bool                              m_GpuProfilingEnabled;
130*89c4ff92SAndroid Build Coastguard Worker     Priority                                m_ModelPriority;
131*89c4ff92SAndroid Build Coastguard Worker     const bool                              m_PrepareFromCache;
132*89c4ff92SAndroid Build Coastguard Worker };
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker }
135