xref: /aosp_15_r20/external/armnn/src/armnn/Runtime.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "LoadedNetwork.hpp"
#include "DeviceSpec.hpp"

#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/BackendId.hpp>

#include <armnn/backends/DynamicBackend.hpp>

#include <client/include/IInitialiseProfilingService.hpp>
#include <client/include/IProfilingService.hpp>
#include <client/include/IReportStructure.hpp>

#include <mutex>
#include <unordered_map>

namespace armnn
{
using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;
using IReportStructure = arm::pipe::IReportStructure;
using IInitialiseProfilingService = arm::pipe::IInitialiseProfilingService;

struct RuntimeImpl final : public IReportStructure, public IInitialiseProfilingService
{
public:
    /// Loads a complete network into the Runtime.
    /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
    /// @param [in] network - Complete network to load into the Runtime.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network);
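
    // A minimal usage sketch (illustrative only): applications normally reach this through the
    // public IRuntime wrapper, which forwards to RuntimeImpl. The CpuRef backend and the INetworkPtr
    // named 'network' below are assumptions for the example.
    //
    //     armnn::IRuntime::CreationOptions options;
    //     armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
    //     armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*network,
    //                                                          { armnn::Compute::CpuRef },
    //                                                          runtime->GetDeviceSpec());
    //     armnn::NetworkId networkId;
    //     armnn::Status status = runtime->LoadNetwork(networkId, std::move(optNet));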

    /// Loads a complete network into the IRuntime.
    /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
    /// @param [in] network Complete network to load into the IRuntime.
    /// @param [out] errorMessage Error message if there were any errors.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage);

    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage,
                       const INetworkProperties& networkProperties);
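
    // A sketch of loading with explicit INetworkProperties (illustrative; the constructor arguments
    // shown - asynchronous execution disabled, Malloc import/export memory sources - are one possible
    // configuration and should be checked against the INetworkProperties definition in IRuntime.hpp):
    //
    //     armnn::INetworkProperties properties(false,
    //                                          armnn::MemorySource::Malloc,
    //                                          armnn::MemorySource::Malloc);
    //     std::string errorMessage;
    //     armnn::NetworkId networkId;
    //     armnn::Status status = runtime->LoadNetwork(networkId, std::move(optNet), errorMessage, properties);
    //     if (status != armnn::Status::Success)
    //     {
    //         // errorMessage describes why the load failed.
    //     }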

    armnn::TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
    armnn::TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;

    std::vector<ImportedInputId> ImportInputs(NetworkId networkId, const InputTensors& inputTensors,
                                              MemorySource forceImportMemorySource);
    std::vector<ImportedOutputId> ImportOutputs(NetworkId networkId, const OutputTensors& outputTensors,
                                                MemorySource forceImportMemorySource);

    void ClearImportedInputs(NetworkId networkId, const std::vector<ImportedInputId> inputIds);
    void ClearImportedOutputs(NetworkId networkId, const std::vector<ImportedOutputId> outputIds);
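
    // A sketch of the pre-import workflow (illustrative; assumes the tensors wrap memory that stays
    // valid across inferences and that the backend supports importing from MemorySource::Malloc):
    //
    //     std::vector<armnn::ImportedInputId> importedInputs =
    //         runtime->ImportInputs(networkId, inputTensors, armnn::MemorySource::Malloc);
    //     std::vector<armnn::ImportedOutputId> importedOutputs =
    //         runtime->ImportOutputs(networkId, outputTensors, armnn::MemorySource::Malloc);
    //
    //     // The returned ids can be passed to EnqueueWorkload (declared below) as
    //     // preImportedInputIds / preImportedOutputIds, and released again with
    //     // ClearImportedInputs / ClearImportedOutputs once they are no longer needed.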

    // Evaluates a network using the input in inputTensors; outputs are filled into outputTensors.
    Status EnqueueWorkload(NetworkId networkId,
                           const InputTensors& inputTensors,
                           const OutputTensors& outputTensors,
                           std::vector<ImportedInputId> preImportedInputIds = {},
                           std::vector<ImportedOutputId> preImportedOutputIds = {});
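
    // A sketch of a single synchronous inference (illustrative; binding id 0 and float data are
    // assumptions, and on recent Arm NN versions the input TensorInfo must be marked constant
    // before it is wrapped in a ConstTensor):
    //
    //     // inputData / outputData are application-owned buffers sized to match the bound tensors.
    //     armnn::TensorInfo inputInfo = runtime->GetInputTensorInfo(networkId, 0);
    //     inputInfo.SetConstant(true);
    //     armnn::InputTensors inputTensors { { 0, armnn::ConstTensor(inputInfo, inputData.data()) } };
    //     armnn::OutputTensors outputTensors { { 0, armnn::Tensor(runtime->GetOutputTensorInfo(networkId, 0),
    //                                                             outputData.data()) } };
    //     armnn::Status status = runtime->EnqueueWorkload(networkId, inputTensors, outputTensors);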

    /// This is an experimental function.
    /// Evaluates a network using the input in inputTensors; outputs are filled into outputTensors.
    /// This function performs a thread-safe execution of the network and returns once execution is complete.
    /// It will block until this and any other thread using the same workingMemHandle object completes.
    Status Execute(IWorkingMemHandle& workingMemHandle,
                   const InputTensors& inputTensors,
                   const OutputTensors& outputTensors,
                   std::vector<ImportedInputId> preImportedInputs,
                   std::vector<ImportedOutputId> preImportedOutputs);

    /// Unloads a network from the Runtime.
    /// At the moment this only removes the network from m_Impl->m_Network.
    /// This might need more work in the future to be AndroidNN compliant.
    /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
    /// @return armnn::Status
    Status UnloadNetwork(NetworkId networkId);

    const IDeviceSpec& GetDeviceSpec() const { return m_DeviceSpec; }

    /// Gets the profiler corresponding to the given network id.
    /// @param networkId The id of the network for which to get the profiler.
    /// @return A pointer to the requested profiler, or nullptr if not found.
    const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const;

    /// Creates a new unique WorkingMemHandle object. Create multiple handles if you wish to have
    /// overlapped execution by calling this function from different threads.
    std::unique_ptr<IWorkingMemHandle> CreateWorkingMemHandle(NetworkId networkId);
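
    // A sketch of overlapped execution with the experimental Execute() above (illustrative; assumes
    // the network was loaded with asynchronous execution enabled via INetworkProperties, and that
    // each worker thread creates and uses its own handle and tensors):
    //
    //     std::unique_ptr<armnn::IWorkingMemHandle> handle = runtime->CreateWorkingMemHandle(networkId);
    //     armnn::Status status = runtime->Execute(*handle, inputTensors, outputTensors, {}, {});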

    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
    /// @param networkId The id of the network to register the callback for.
    /// @param func The callback function to pass to the debug layer.
    void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func);
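
    // A sketch of registering a debug callback (illustrative; the lambda follows the
    // DebugCallbackFunction signature from armnn/Types.hpp, and the network needs debug layers,
    // e.g. from optimising with debug enabled, for the callback to be invoked):
    //
    //     runtime->RegisterDebugCallback(networkId,
    //         [](armnn::LayerGuid guid, unsigned int slotIndex, armnn::ITensorHandle* tensorHandle)
    //         {
    //             // Inspect the intermediate tensor here.
    //         });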

    /// Creates a runtime for workload execution.
    RuntimeImpl(const IRuntime::CreationOptions& options);

    ~RuntimeImpl();

    // NOTE: We won't need the profiling service reference here, but it is good to pass the service
    // in this way to facilitate other implementations down the road.
    void ReportStructure(arm::pipe::IProfilingService& profilingService) override;

    void InitialiseProfilingService(arm::pipe::IProfilingService& profilingService) override;

private:
    friend void RuntimeLoadedNetworksReserve(RuntimeImpl* runtime); // See RuntimeTests.cpp

    friend arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime); // See RuntimeTests.cpp

    int GenerateNetworkId();

    LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;

    /// Looks up the loaded network with the given id and, if found, invokes f on it while holding
    /// the runtime mutex (when threading is enabled).
    template<typename Func>
    void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
    {
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_Mutex);
#endif
        auto iter = m_LoadedNetworks.find(networkId);
        if (iter != m_LoadedNetworks.end())
        {
            f(iter->second.get());
        }
    }
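
    // For illustration only, a hypothetical internal call site could look like:
    //
    //     LoadedNetworkFuncSafe(networkId, [&func](LoadedNetwork* network)
    //         {
    //             network->RegisterDebugCallback(func);
    //         });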

    /// Loads any available/compatible dynamic backends into the runtime.
    void LoadDynamicBackends(const std::string& overrideBackendPath);

#if !defined(ARMNN_DISABLE_THREADS)
    mutable std::mutex m_Mutex;
#endif

    /// Map of loaded networks with the associated NetworkId as key.
    LoadedNetworks m_LoadedNetworks;

    std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;

    int m_NetworkIdCounter;

    DeviceSpec m_DeviceSpec;

    /// List of dynamic backends loaded in the runtime
    std::vector<DynamicBackendPtr> m_DynamicBackends;

    /// Profiling Service Instance
    std::unique_ptr<arm::pipe::IProfilingService> m_ProfilingService;

    /// Keep track of backend ids of the custom allocators that this instance of the runtime added. The
    /// destructor can then clean up for this runtime.
    std::set<BackendId> m_AllocatorsAddedByThisRuntime;
};

} // namespace armnn