//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "ModelToINetworkTransformer.hpp"

#include <armnn/ArmNN.hpp>

#include <BufferTracker.h>
#include <CpuExecutor.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>

#include <any>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace armnn_driver
{

/// Timing bookkeeping for a single execution: whether timing was requested,
/// plus the driver-side and device-side start/end points.
struct CanonicalExecutionContext
{
    ::android::nn::MeasureTiming measureTimings = ::android::nn::MeasureTiming::NO;
    android::nn::TimePoint driverStart;
    android::nn::TimePoint driverEnd;
    android::nn::TimePoint deviceStart;
    android::nn::TimePoint deviceEnd;
};
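
/// ArmnnPreparedModel bridges the NNAPI canonical IPreparedModel interface and a
/// network that has already been loaded into an armnn::IRuntime: each execution
/// call binds the request's memory pools to Arm NN input/output tensors and runs
/// the network identified by its NetworkId.
///
/// A minimal usage sketch (the surrounding setup names - networkId, runtime,
/// model, request - are illustrative, not declared in this header):
/// \code
/// auto preparedModel = std::make_shared<ArmnnPreparedModel>(
///     networkId,                                 // from runtime->LoadNetwork(...)
///     runtime.get(),
///     model,
///     /*requestInputsAndOutputsDumpDir=*/"",
///     /*gpuProfilingEnabled=*/false);
/// auto result = preparedModel->execute(request,
///                                      MeasureTiming::NO,
///                                      /*deadline=*/{},
///                                      /*loopTimeoutDuration=*/{},
///                                      /*hints=*/{},
///                                      /*extensionNameToPrefix=*/{});
/// \endcode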
class ArmnnPreparedModel final : public IPreparedModel,
                                 public std::enable_shared_from_this<ArmnnPreparedModel>
{
public:
    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const Model& model,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM);

    /// Overload used when the model is re-prepared from a compilation cache and
    /// no Model object is available.
    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       Priority priority = Priority::MEDIUM,
                       const bool prepareModelFromCache = false);

    virtual ~ArmnnPreparedModel();

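    /// Runs a synchronous execution of the loaded network for the given request.
    /// On success returns the output shapes (resolved dynamically where needed)
    /// and, if measureTiming is requested, the measured Timing.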
    ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> execute(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

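    /// Fenced variant of execute(): computation may start only after every sync
    /// fence in waitFor has been signalled, and completion is reported through
    /// the returned SyncFence and ExecuteFencedInfoCallback pair.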
    GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> executeFenced(
        const Request& request,
        const std::vector<SyncFence>& waitFor,
        MeasureTiming measureTiming,
        const OptionalTimePoint& deadline,
        const OptionalDuration& loopTimeoutDuration,
        const OptionalDuration& timeoutDurationAfterFence,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

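    /// Creates a reusable execution object for this request, allowing repeated
    /// computations on the same inputs and outputs without re-validating and
    /// re-binding the request on every call.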
    GeneralResult<android::nn::SharedExecution> createReusableExecution(
        const Request& request,
        MeasureTiming measureTiming,
        const OptionalDuration& loopTimeoutDuration,
        const std::vector<android::nn::TokenValuePair>& hints,
        const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

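    /// Creates a burst object for sequences of executions with lower per-call
    /// overhead on this prepared model.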
    GeneralResult<SharedBurst> configureExecutionBurst() const override;

    std::any getUnderlyingResource() const override;

    /// Executes the graph prepared from the request.
    ErrorStatus ExecuteGraph(
        std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& inputTensors,
        armnn::OutputTensors& outputTensors,
        CanonicalExecutionContext callback,
        const bool pointerMemory = false) const;

    Priority GetModelPriority() const;

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const;

private:
    void Init();
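
    /// The PrepareMemoryFor* helpers map the memory pools referenced by the
    /// request into RunTimePoolInfo objects and bind each request input/output
    /// to an Arm NN tensor backed by that memory, ready for ExecuteGraph().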
    ErrorStatus PrepareMemoryForInputs(
        armnn::InputTensors& inputs,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

    ErrorStatus PrepareMemoryForOutputs(
        armnn::OutputTensors& outputs,
        std::vector<OutputShape>& outputShapes,
        const Request& request,
        const std::vector<android::nn::RunTimePoolInfo>& memPools) const;

    ErrorStatus PrepareMemoryForIO(armnn::InputTensors& inputs,
                                   armnn::OutputTensors& outputs,
                                   std::vector<android::nn::RunTimePoolInfo>& memPools,
                                   const Request& request,
                                   const bool pointerMemory = false) const;

    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) const;

    /// State of the network loaded into the runtime, used to schedule and run
    /// the graph prepared from the request.
    armnn::NetworkId                        m_NetworkId;
    armnn::IRuntime*                        m_Runtime;

    const Model                             m_Model;
    const std::string&                      m_RequestInputsAndOutputsDumpDir;
    const bool                              m_GpuProfilingEnabled;
    Priority                                m_ModelPriority;
    const bool                              m_PrepareFromCache;
};

} // namespace armnn_driver
135