xref: /aosp_15_r20/external/android-nn-driver/ArmnnPreparedModel.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <string>
#include <vector>

19*3e777be0SXin Li namespace armnn_driver
20*3e777be0SXin Li {
21*3e777be0SXin Li using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;
22*3e777be0SXin Li 
23*3e777be0SXin Li struct ArmnnCallback_1_0
24*3e777be0SXin Li {
25*3e777be0SXin Li     armnnExecuteCallback_1_0 callback;
26*3e777be0SXin Li };
27*3e777be0SXin Li 
28*3e777be0SXin Li struct ExecutionContext_1_0 {};
29*3e777be0SXin Li 
30*3e777be0SXin Li using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
31*3e777be0SXin Li 
32*3e777be0SXin Li template <typename HalVersion>
33*3e777be0SXin Li class ArmnnPreparedModel : public V1_0::IPreparedModel
34*3e777be0SXin Li {
35*3e777be0SXin Li public:
36*3e777be0SXin Li     using HalModel = typename HalVersion::Model;
37*3e777be0SXin Li 
38*3e777be0SXin Li     ArmnnPreparedModel(armnn::NetworkId networkId,
39*3e777be0SXin Li                        armnn::IRuntime* runtime,
40*3e777be0SXin Li                        const HalModel& model,
41*3e777be0SXin Li                        const std::string& requestInputsAndOutputsDumpDir,
42*3e777be0SXin Li                        const bool gpuProfilingEnabled,
43*3e777be0SXin Li                        const bool asyncModelExecutionEnabled = false,
44*3e777be0SXin Li                        const unsigned int numberOfThreads = 1,
45*3e777be0SXin Li                        const bool importEnabled = false,
46*3e777be0SXin Li                        const bool exportEnabled = false);
47*3e777be0SXin Li 
48*3e777be0SXin Li     virtual ~ArmnnPreparedModel();
49*3e777be0SXin Li 
50*3e777be0SXin Li     virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
51*3e777be0SXin Li                                               const ::android::sp<V1_0::IExecutionCallback>& callback) override;
52*3e777be0SXin Li 
53*3e777be0SXin Li     /// execute the graph prepared from the request
54*3e777be0SXin Li     void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
55*3e777be0SXin Li                       armnn::InputTensors& inputTensors,
56*3e777be0SXin Li                       armnn::OutputTensors& outputTensors,
57*3e777be0SXin Li                       CallbackContext_1_0 callback);
58*3e777be0SXin Li 
59*3e777be0SXin Li     /// Executes this model with dummy inputs (e.g. all zeroes).
60*3e777be0SXin Li     /// \return false on failure, otherwise true
61*3e777be0SXin Li     bool ExecuteWithDummyInputs();
62*3e777be0SXin Li 
63*3e777be0SXin Li private:
64*3e777be0SXin Li 
65*3e777be0SXin Li     template<typename CallbackContext>
66*3e777be0SXin Li     class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
67*3e777be0SXin Li     {
68*3e777be0SXin Li     public:
ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion> * model,std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)69*3e777be0SXin Li         ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
70*3e777be0SXin Li                                 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
71*3e777be0SXin Li                                 std::shared_ptr<armnn::InputTensors>& inputTensors,
72*3e777be0SXin Li                                 std::shared_ptr<armnn::OutputTensors>& outputTensors,
73*3e777be0SXin Li                                 CallbackContext callbackContext) :
74*3e777be0SXin Li                 m_Model(model),
75*3e777be0SXin Li                 m_MemPools(pMemPools),
76*3e777be0SXin Li                 m_InputTensors(inputTensors),
77*3e777be0SXin Li                 m_OutputTensors(outputTensors),
78*3e777be0SXin Li                 m_CallbackContext(callbackContext)
79*3e777be0SXin Li         {}
80*3e777be0SXin Li 
81*3e777be0SXin Li         void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
82*3e777be0SXin Li 
83*3e777be0SXin Li         ArmnnPreparedModel<HalVersion>* m_Model;
84*3e777be0SXin Li         std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
85*3e777be0SXin Li         std::shared_ptr<armnn::InputTensors> m_InputTensors;
86*3e777be0SXin Li         std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
87*3e777be0SXin Li         CallbackContext m_CallbackContext;
88*3e777be0SXin Li     };
89*3e777be0SXin Li 
90*3e777be0SXin Li     template <typename TensorBindingCollection>
91*3e777be0SXin Li     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
92*3e777be0SXin Li 
93*3e777be0SXin Li     /// schedule the graph prepared from the request for execution
94*3e777be0SXin Li     template<typename CallbackContext>
95*3e777be0SXin Li     void ScheduleGraphForExecution(
96*3e777be0SXin Li             std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
97*3e777be0SXin Li             std::shared_ptr<armnn::InputTensors>& inputTensors,
98*3e777be0SXin Li             std::shared_ptr<armnn::OutputTensors>& outputTensors,
99*3e777be0SXin Li             CallbackContext m_CallbackContext);
100*3e777be0SXin Li 
101*3e777be0SXin Li     armnn::NetworkId                          m_NetworkId;
102*3e777be0SXin Li     armnn::IRuntime*                          m_Runtime;
103*3e777be0SXin Li     HalModel                                  m_Model;
104*3e777be0SXin Li     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
105*3e777be0SXin Li     // It is specific to this class, so it is declared as static here
106*3e777be0SXin Li     static RequestThread<ArmnnPreparedModel,
107*3e777be0SXin Li                          HalVersion,
108*3e777be0SXin Li                          CallbackContext_1_0> m_RequestThread;
109*3e777be0SXin Li     uint32_t                                  m_RequestCount;
110*3e777be0SXin Li     const std::string&                        m_RequestInputsAndOutputsDumpDir;
111*3e777be0SXin Li     const bool                                m_GpuProfilingEnabled;
112*3e777be0SXin Li     // Static to allow sharing of threadpool between ArmnnPreparedModel instances
113*3e777be0SXin Li     static std::unique_ptr<armnn::Threadpool> m_Threadpool;
114*3e777be0SXin Li     std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
115*3e777be0SXin Li     const bool m_AsyncModelExecutionEnabled;
116*3e777be0SXin Li     const bool m_EnableImport;
117*3e777be0SXin Li     const bool m_EnableExport;
118*3e777be0SXin Li };
119*3e777be0SXin Li 
120*3e777be0SXin Li }
121