// xref: /aosp_15_r20/external/android-nn-driver/ArmnnPreparedModel_1_2.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*3e777be0SXin Li 
#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "ModelToINetworkConverter.hpp"
#include "RequestThread.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
19*3e777be0SXin Li 
20*3e777be0SXin Li namespace armnn_driver
21*3e777be0SXin Li {
22*3e777be0SXin Li 
23*3e777be0SXin Li using CallbackAsync_1_2 = std::function<
24*3e777be0SXin Li                                 void(V1_0::ErrorStatus errorStatus,
25*3e777be0SXin Li                                      std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
26*3e777be0SXin Li                                      const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
27*3e777be0SXin Li                                      std::string callingFunction)>;
28*3e777be0SXin Li 
29*3e777be0SXin Li struct ExecutionContext_1_2
30*3e777be0SXin Li {
31*3e777be0SXin Li     ::android::hardware::neuralnetworks::V1_2::MeasureTiming    measureTimings =
32*3e777be0SXin Li         ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
33*3e777be0SXin Li     TimePoint driverStart;
34*3e777be0SXin Li };
35*3e777be0SXin Li 
36*3e777be0SXin Li using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
37*3e777be0SXin Li 
38*3e777be0SXin Li template <typename HalVersion>
39*3e777be0SXin Li class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
40*3e777be0SXin Li {
41*3e777be0SXin Li public:
42*3e777be0SXin Li     using HalModel = typename V1_2::Model;
43*3e777be0SXin Li 
44*3e777be0SXin Li     ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
45*3e777be0SXin Li                            armnn::IRuntime* runtime,
46*3e777be0SXin Li                            const HalModel& model,
47*3e777be0SXin Li                            const std::string& requestInputsAndOutputsDumpDir,
48*3e777be0SXin Li                            const bool gpuProfilingEnabled,
49*3e777be0SXin Li                            const bool asyncModelExecutionEnabled = false,
50*3e777be0SXin Li                            const unsigned int numberOfThreads = 1,
51*3e777be0SXin Li                            const bool importEnabled = false,
52*3e777be0SXin Li                            const bool exportEnabled = false);
53*3e777be0SXin Li 
54*3e777be0SXin Li     ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
55*3e777be0SXin Li                            armnn::IRuntime* runtime,
56*3e777be0SXin Li                            const std::string& requestInputsAndOutputsDumpDir,
57*3e777be0SXin Li                            const bool gpuProfilingEnabled,
58*3e777be0SXin Li                            const bool asyncModelExecutionEnabled = false,
59*3e777be0SXin Li                            const unsigned int numberOfThreads = 1,
60*3e777be0SXin Li                            const bool importEnabled = false,
61*3e777be0SXin Li                            const bool exportEnabled = false,
62*3e777be0SXin Li                            const bool preparedFromCache = false);
63*3e777be0SXin Li 
64*3e777be0SXin Li     virtual ~ArmnnPreparedModel_1_2();
65*3e777be0SXin Li 
66*3e777be0SXin Li     virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
67*3e777be0SXin Li                                               const ::android::sp<V1_0::IExecutionCallback>& callback) override;
68*3e777be0SXin Li 
69*3e777be0SXin Li     virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
70*3e777be0SXin Li                                                   const ::android::sp<V1_2::IExecutionCallback>& callback) override;
71*3e777be0SXin Li 
72*3e777be0SXin Li     virtual Return<void> executeSynchronously(const V1_0::Request &request,
73*3e777be0SXin Li                                               V1_2::MeasureTiming measure,
74*3e777be0SXin Li                                               V1_2::IPreparedModel::executeSynchronously_cb cb) override;
75*3e777be0SXin Li 
76*3e777be0SXin Li     virtual Return<void> configureExecutionBurst(
77*3e777be0SXin Li             const ::android::sp<V1_2::IBurstCallback>& callback,
78*3e777be0SXin Li             const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
79*3e777be0SXin Li             const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
80*3e777be0SXin Li             configureExecutionBurst_cb cb) override;
81*3e777be0SXin Li 
82*3e777be0SXin Li     /// execute the graph prepared from the request
83*3e777be0SXin Li     template<typename CallbackContext>
84*3e777be0SXin Li     bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
85*3e777be0SXin Li                       armnn::InputTensors& inputTensors,
86*3e777be0SXin Li                       armnn::OutputTensors& outputTensors,
87*3e777be0SXin Li                       CallbackContext callback);
88*3e777be0SXin Li 
89*3e777be0SXin Li     /// Executes this model with dummy inputs (e.g. all zeroes).
90*3e777be0SXin Li     /// \return false on failure, otherwise true
91*3e777be0SXin Li     bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs);
92*3e777be0SXin Li 
93*3e777be0SXin Li private:
94*3e777be0SXin Li 
95*3e777be0SXin Li     template<typename CallbackContext>
96*3e777be0SXin Li     class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
97*3e777be0SXin Li     {
98*3e777be0SXin Li     public:
ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::vector<V1_2::OutputShape> outputShapes,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)99*3e777be0SXin Li         ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
100*3e777be0SXin Li                                     std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
101*3e777be0SXin Li                                     std::vector<V1_2::OutputShape> outputShapes,
102*3e777be0SXin Li                                     std::shared_ptr<armnn::InputTensors>& inputTensors,
103*3e777be0SXin Li                                     std::shared_ptr<armnn::OutputTensors>& outputTensors,
104*3e777be0SXin Li                                     CallbackContext callbackContext) :
105*3e777be0SXin Li                 m_Model(model),
106*3e777be0SXin Li                 m_MemPools(pMemPools),
107*3e777be0SXin Li                 m_OutputShapes(outputShapes),
108*3e777be0SXin Li                 m_InputTensors(inputTensors),
109*3e777be0SXin Li                 m_OutputTensors(outputTensors),
110*3e777be0SXin Li                 m_CallbackContext(callbackContext)
111*3e777be0SXin Li         {}
112*3e777be0SXin Li 
113*3e777be0SXin Li         void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
114*3e777be0SXin Li 
115*3e777be0SXin Li         ArmnnPreparedModel_1_2<HalVersion>* m_Model;
116*3e777be0SXin Li         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
117*3e777be0SXin Li         std::vector<V1_2::OutputShape> m_OutputShapes;
118*3e777be0SXin Li         std::shared_ptr<armnn::InputTensors> m_InputTensors;
119*3e777be0SXin Li         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
120*3e777be0SXin Li         CallbackContext m_CallbackContext;
121*3e777be0SXin Li     };
122*3e777be0SXin Li 
123*3e777be0SXin Li     Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
124*3e777be0SXin Li                                       V1_2::MeasureTiming measureTiming,
125*3e777be0SXin Li                                       CallbackAsync_1_2 callback);
126*3e777be0SXin Li 
127*3e777be0SXin Li     Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
128*3e777be0SXin Li             armnn::InputTensors& inputs,
129*3e777be0SXin Li             const V1_0::Request& request,
130*3e777be0SXin Li             const std::vector<android::nn::RunTimePoolInfo>& memPools);
131*3e777be0SXin Li 
132*3e777be0SXin Li     Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
133*3e777be0SXin Li             armnn::OutputTensors& outputs,
134*3e777be0SXin Li             std::vector<V1_2::OutputShape> &outputShapes,
135*3e777be0SXin Li             const V1_0::Request& request,
136*3e777be0SXin Li             const std::vector<android::nn::RunTimePoolInfo>& memPools);
137*3e777be0SXin Li 
138*3e777be0SXin Li     Return <V1_0::ErrorStatus> PrepareMemoryForIO(
139*3e777be0SXin Li             armnn::InputTensors& inputs,
140*3e777be0SXin Li             armnn::OutputTensors& outputs,
141*3e777be0SXin Li             std::vector<android::nn::RunTimePoolInfo>& memPools,
142*3e777be0SXin Li             const V1_0::Request& request,
143*3e777be0SXin Li             CallbackAsync_1_2 callback);
144*3e777be0SXin Li 
145*3e777be0SXin Li     template <typename TensorBindingCollection>
146*3e777be0SXin Li     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
147*3e777be0SXin Li 
148*3e777be0SXin Li     /// schedule the graph prepared from the request for execution
149*3e777be0SXin Li     template<typename CallbackContext>
150*3e777be0SXin Li     void ScheduleGraphForExecution(
151*3e777be0SXin Li             std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
152*3e777be0SXin Li             std::shared_ptr<armnn::InputTensors>& inputTensors,
153*3e777be0SXin Li             std::shared_ptr<armnn::OutputTensors>& outputTensors,
154*3e777be0SXin Li             CallbackContext m_CallbackContext);
155*3e777be0SXin Li 
156*3e777be0SXin Li     armnn::NetworkId                          m_NetworkId;
157*3e777be0SXin Li     armnn::IRuntime*                          m_Runtime;
158*3e777be0SXin Li     V1_2::Model                               m_Model;
159*3e777be0SXin Li     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
160*3e777be0SXin Li     // It is specific to this class, so it is declared as static here
161*3e777be0SXin Li     static RequestThread<ArmnnPreparedModel_1_2,
162*3e777be0SXin Li                          HalVersion,
163*3e777be0SXin Li                          CallbackContext_1_2> m_RequestThread;
164*3e777be0SXin Li     uint32_t                                  m_RequestCount;
165*3e777be0SXin Li     const std::string&                        m_RequestInputsAndOutputsDumpDir;
166*3e777be0SXin Li     const bool                                m_GpuProfilingEnabled;
167*3e777be0SXin Li     // Static to allow sharing of threadpool between ArmnnPreparedModel instances
168*3e777be0SXin Li     static std::unique_ptr<armnn::Threadpool> m_Threadpool;
169*3e777be0SXin Li     std::shared_ptr<IWorkingMemHandle>        m_WorkingMemHandle;
170*3e777be0SXin Li     const bool                                m_AsyncModelExecutionEnabled;
171*3e777be0SXin Li     const bool                                m_EnableImport;
172*3e777be0SXin Li     const bool                                m_EnableExport;
173*3e777be0SXin Li     const bool                                m_PreparedFromCache;
174*3e777be0SXin Li };
175*3e777be0SXin Li 
176*3e777be0SXin Li }
177