xref: /aosp_15_r20/external/android-nn-driver/ArmnnPreparedModel_1_2.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
19 
20 namespace armnn_driver
21 {
22 
23 using CallbackAsync_1_2 = std::function<
24                                 void(V1_0::ErrorStatus errorStatus,
25                                      std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
26                                      const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
27                                      std::string callingFunction)>;
28 
29 struct ExecutionContext_1_2
30 {
31     ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings =
32         ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
33     TimePoint driverStart;
34 };
35 
36 using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
37 
38 template <typename HalVersion>
39 class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
40 {
41 public:
42     using HalModel = typename V1_2::Model;
43 
44     ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
45                            armnn::IRuntime* runtime,
46                            const HalModel& model,
47                            const std::string& requestInputsAndOutputsDumpDir,
48                            const bool gpuProfilingEnabled,
49                            const bool asyncModelExecutionEnabled = false,
50                            const unsigned int numberOfThreads = 1,
51                            const bool importEnabled = false,
52                            const bool exportEnabled = false);
53 
54     ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
55                            armnn::IRuntime* runtime,
56                            const std::string& requestInputsAndOutputsDumpDir,
57                            const bool gpuProfilingEnabled,
58                            const bool asyncModelExecutionEnabled = false,
59                            const unsigned int numberOfThreads = 1,
60                            const bool importEnabled = false,
61                            const bool exportEnabled = false,
62                            const bool preparedFromCache = false);
63 
64     virtual ~ArmnnPreparedModel_1_2();
65 
66     virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
67                                               const ::android::sp<V1_0::IExecutionCallback>& callback) override;
68 
69     virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
70                                                   const ::android::sp<V1_2::IExecutionCallback>& callback) override;
71 
72     virtual Return<void> executeSynchronously(const V1_0::Request &request,
73                                               V1_2::MeasureTiming measure,
74                                               V1_2::IPreparedModel::executeSynchronously_cb cb) override;
75 
76     virtual Return<void> configureExecutionBurst(
77             const ::android::sp<V1_2::IBurstCallback>& callback,
78             const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
79             const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
80             configureExecutionBurst_cb cb) override;
81 
82     /// execute the graph prepared from the request
83     template<typename CallbackContext>
84     bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
85                       armnn::InputTensors& inputTensors,
86                       armnn::OutputTensors& outputTensors,
87                       CallbackContext callback);
88 
89     /// Executes this model with dummy inputs (e.g. all zeroes).
90     /// \return false on failure, otherwise true
91     bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs);
92 
93 private:
94 
95     template<typename CallbackContext>
96     class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
97     {
98     public:
ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::vector<V1_2::OutputShape> outputShapes,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)99         ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
100                                     std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
101                                     std::vector<V1_2::OutputShape> outputShapes,
102                                     std::shared_ptr<armnn::InputTensors>& inputTensors,
103                                     std::shared_ptr<armnn::OutputTensors>& outputTensors,
104                                     CallbackContext callbackContext) :
105                 m_Model(model),
106                 m_MemPools(pMemPools),
107                 m_OutputShapes(outputShapes),
108                 m_InputTensors(inputTensors),
109                 m_OutputTensors(outputTensors),
110                 m_CallbackContext(callbackContext)
111         {}
112 
113         void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
114 
115         ArmnnPreparedModel_1_2<HalVersion>* m_Model;
116         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
117         std::vector<V1_2::OutputShape> m_OutputShapes;
118         std::shared_ptr<armnn::InputTensors> m_InputTensors;
119         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
120         CallbackContext m_CallbackContext;
121     };
122 
123     Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
124                                       V1_2::MeasureTiming measureTiming,
125                                       CallbackAsync_1_2 callback);
126 
127     Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
128             armnn::InputTensors& inputs,
129             const V1_0::Request& request,
130             const std::vector<android::nn::RunTimePoolInfo>& memPools);
131 
132     Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
133             armnn::OutputTensors& outputs,
134             std::vector<V1_2::OutputShape> &outputShapes,
135             const V1_0::Request& request,
136             const std::vector<android::nn::RunTimePoolInfo>& memPools);
137 
138     Return <V1_0::ErrorStatus> PrepareMemoryForIO(
139             armnn::InputTensors& inputs,
140             armnn::OutputTensors& outputs,
141             std::vector<android::nn::RunTimePoolInfo>& memPools,
142             const V1_0::Request& request,
143             CallbackAsync_1_2 callback);
144 
145     template <typename TensorBindingCollection>
146     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
147 
148     /// schedule the graph prepared from the request for execution
149     template<typename CallbackContext>
150     void ScheduleGraphForExecution(
151             std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
152             std::shared_ptr<armnn::InputTensors>& inputTensors,
153             std::shared_ptr<armnn::OutputTensors>& outputTensors,
154             CallbackContext m_CallbackContext);
155 
156     armnn::NetworkId                          m_NetworkId;
157     armnn::IRuntime*                          m_Runtime;
158     V1_2::Model                               m_Model;
159     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
160     // It is specific to this class, so it is declared as static here
161     static RequestThread<ArmnnPreparedModel_1_2,
162                          HalVersion,
163                          CallbackContext_1_2> m_RequestThread;
164     uint32_t                                  m_RequestCount;
165     const std::string&                        m_RequestInputsAndOutputsDumpDir;
166     const bool                                m_GpuProfilingEnabled;
167     // Static to allow sharing of threadpool between ArmnnPreparedModel instances
168     static std::unique_ptr<armnn::Threadpool> m_Threadpool;
169     std::shared_ptr<IWorkingMemHandle>        m_WorkingMemHandle;
170     const bool                                m_AsyncModelExecutionEnabled;
171     const bool                                m_EnableImport;
172     const bool                                m_EnableExport;
173     const bool                                m_PreparedFromCache;
174 };
175 
176 }
177