//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "LoadedNetwork.hpp"
#include "DeviceSpec.hpp"

#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/BackendId.hpp>

#include <armnn/backends/DynamicBackend.hpp>

#include <client/include/IInitialiseProfilingService.hpp>
#include <client/include/IProfilingService.hpp>
#include <client/include/IReportStructure.hpp>

#include <mutex>
#include <unordered_map>

namespace armnn
{
using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;
using IReportStructure = arm::pipe::IReportStructure;
using IInitialiseProfilingService = arm::pipe::IInitialiseProfilingService;

struct RuntimeImpl final : public IReportStructure, public IInitialiseProfilingService
{
public:
    /// Loads a complete network into the Runtime.
    /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
    /// @param [in] network - Complete network to load into the Runtime.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network);

    /// Load a complete network into the IRuntime.
    /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
    /// @param [in] network Complete network to load into the IRuntime.
    /// @param [out] errorMessage Error message if there were any errors.
    /// The runtime takes ownership of the network once passed in.
    /// @return armnn::Status
    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage);

    Status LoadNetwork(NetworkId& networkIdOut,
                       IOptimizedNetworkPtr network,
                       std::string& errorMessage,
                       const INetworkProperties& networkProperties);

    armnn::TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;
    armnn::TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const;

    std::vector<ImportedInputId> ImportInputs(NetworkId networkId, const InputTensors& inputTensors,
                                              MemorySource forceImportMemorySource);
    std::vector<ImportedOutputId> ImportOutputs(NetworkId networkId, const OutputTensors& outputTensors,
                                                MemorySource forceImportMemorySource);

    void ClearImportedInputs(NetworkId networkId, const std::vector<ImportedInputId> inputIds);
    void ClearImportedOutputs(NetworkId networkId, const std::vector<ImportedOutputId> outputIds);

    // Evaluates network using input in inputTensors, outputs filled into outputTensors.
    Status EnqueueWorkload(NetworkId networkId,
                           const InputTensors& inputTensors,
                           const OutputTensors& outputTensors,
                           std::vector<ImportedInputId> preImportedInputIds = {},
                           std::vector<ImportedOutputId> preImportedOutputIds = {});

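    // Illustrative sketch only (not part of this header's API): a minimal synchronous inference flow,
    // assuming an already optimised network in `optNet` (IOptimizedNetworkPtr), a runtime object named
    // `runtime` exposing the methods declared above, caller-owned `inputData`/`outputData` buffers whose
    // shapes match the bound layers, and binding id 0 for both bindings. All of these names are
    // placeholders chosen for the sketch.
    //
    //     NetworkId netId{};
    //     std::string error;
    //     if (runtime.LoadNetwork(netId, std::move(optNet), error) != Status::Success)
    //     {
    //         // inspect `error` and bail out
    //     }
    //
    //     TensorInfo inputInfo = runtime.GetInputTensorInfo(netId, 0);
    //     inputInfo.SetConstant(true); // input ConstTensors are expected to carry constant TensorInfos
    //     InputTensors  inputs  = { { 0, ConstTensor(inputInfo, inputData.data()) } };
    //     OutputTensors outputs = { { 0, Tensor(runtime.GetOutputTensorInfo(netId, 0), outputData.data()) } };
    //
    //     runtime.EnqueueWorkload(netId, inputs, outputs);
    //     runtime.UnloadNetwork(netId);
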
    /// This is an experimental function.
    /// Evaluates a network using input in inputTensors and outputs filled into outputTensors.
    /// This function performs a thread-safe execution of the network. Returns once execution is complete.
    /// Will block until this and any other thread using the same working memory handle completes.
    Status Execute(IWorkingMemHandle& workingMemHandle,
                   const InputTensors& inputTensors,
                   const OutputTensors& outputTensors,
                   std::vector<ImportedInputId> preImportedInputs,
                   std::vector<ImportedOutputId> preImportedOutputs);

    /// Unloads a network from the Runtime.
    /// At the moment this only removes the network from the m_Impl->m_Network.
    /// This might need more work in the future to be AndroidNN compliant.
    /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
    /// @return armnn::Status
    Status UnloadNetwork(NetworkId networkId);

    const IDeviceSpec& GetDeviceSpec() const { return m_DeviceSpec; }

    /// Gets the profiler corresponding to the given network id.
    /// @param networkId The id of the network for which to get the profiler.
    /// @return A pointer to the requested profiler, or nullptr if not found.
    const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const;

    /// Creates a new unique WorkingMemHandle object. Create multiple handles if you wish to have
    /// overlapped execution by calling this function from different threads.
    std::unique_ptr<IWorkingMemHandle> CreateWorkingMemHandle(NetworkId networkId);

    /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
    /// @param networkId The id of the network to register the callback for.
    /// @param func Callback function to pass to the debug layer.
    void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func);

    /// Creates a runtime for workload execution.
    RuntimeImpl(const IRuntime::CreationOptions& options);

    ~RuntimeImpl();

    // NOTE: we won't need the profiling service reference here, but it is good to pass the service
    //       in this way to facilitate other implementations down the road.
    void ReportStructure(arm::pipe::IProfilingService& profilingService) override;

    void InitialiseProfilingService(arm::pipe::IProfilingService& profilingService) override;

private:
    friend void RuntimeLoadedNetworksReserve(RuntimeImpl* runtime); // See RuntimeTests.cpp

    friend arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime); // See RuntimeTests.cpp

    int GenerateNetworkId();

    LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;

    template<typename Func>
    void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
    {
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_Mutex);
#endif
        auto iter = m_LoadedNetworks.find(networkId);
        if (iter != m_LoadedNetworks.end())
        {
            f(iter->second.get());
        }
    }

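    // Illustrative sketch only: LoadedNetworkFuncSafe is a lock-then-lookup helper. As the definition
    // above shows, it takes the runtime mutex (unless ARMNN_DISABLE_THREADS is defined), looks the
    // network id up, and invokes the callable only when that id is still loaded. A hypothetical caller
    // (the LoadedNetwork method used in the body is a placeholder, not something declared in this
    // header) might look like:
    //
    //     LoadedNetworkFuncSafe(networkId, [&](LoadedNetwork* loadedNetwork)
    //     {
    //         loadedNetwork->RegisterDebugCallback(func); // hypothetical body
    //     });
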
    /// Loads any available/compatible dynamic backend in the runtime.
    void LoadDynamicBackends(const std::string& overrideBackendPath);

#if !defined(ARMNN_DISABLE_THREADS)
    mutable std::mutex m_Mutex;
#endif

    /// Map of Loaded Networks with associated GUID as key
    LoadedNetworks m_LoadedNetworks;

    std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;

    int m_NetworkIdCounter;

    DeviceSpec m_DeviceSpec;

    /// List of dynamic backends loaded in the runtime
    std::vector<DynamicBackendPtr> m_DynamicBackends;

    /// Profiling Service Instance
    std::unique_ptr<arm::pipe::IProfilingService> m_ProfilingService;

    /// Keep track of backend ids of the custom allocators that this instance of the runtime added. The
    /// destructor can then clean up for this runtime.
    std::set<BackendId> m_AllocatorsAddedByThisRuntime;
};

} // namespace armnn

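// Illustrative sketch only: constructing the runtime and overlapping execution with the experimental
// working-memory API declared above. `optNet`, `inputs1/outputs1` and `inputs2/outputs2` are placeholders
// prepared as in the earlier sketch; the dynamic backends path is an assumption about the caller's
// configuration (assumed to feed LoadDynamicBackends()), not a requirement of this header.
//
//     armnn::IRuntime::CreationOptions options;
//     options.m_DynamicBackendsPath = "/path/to/dynamic/backends";
//     armnn::RuntimeImpl runtime(options);
//
//     armnn::NetworkId netId{};
//     std::string error;
//     runtime.LoadNetwork(netId, std::move(optNet), error);
//
//     // One IWorkingMemHandle per thread allows overlapped, thread-safe Execute() calls.
//     auto handle1 = runtime.CreateWorkingMemHandle(netId);
//     auto handle2 = runtime.CreateWorkingMemHandle(netId);
//     std::thread t1([&] { runtime.Execute(*handle1, inputs1, outputs1, {}, {}); });
//     std::thread t2([&] { runtime.Execute(*handle2, inputs2, outputs2, {}, {}); });
//     t1.join();
//     t2.join();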