/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H
#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H

#include "nnapi/hal/1.2/BurstUtils.h"

#include <android-base/scopeguard.h>
#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
#include <android/hardware/neuralnetworks/1.2/IBurstContext.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
#include <nnapi/IBurst.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/hal/1.0/ProtectCallback.h>
#include <nnapi/hal/CommonUtils.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

namespace android::hardware::neuralnetworks::V1_2::utils {

/**
 * The Burst class manages both the serialization and deserialization of data across FMQ, making it
 * appear to the runtime as a regular synchronous inference. Additionally, this class manages the
 * burst's memory cache.
 */
class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst> {
    struct PrivateConstructorTag {};

  public:
    using FallbackFunction = std::function<
            nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>()>;

    /**
     * NN runtime memory cache.
     *
     * MemoryCache associates a Memory object with a slot number to be passed across FMQ. The
     * ExecutionBurstServer can use this callback to retrieve a hidl_memory corresponding to the
     * slot via HIDL.
     *
     * Whenever a hidl_memory object is copied, it will duplicate the underlying file descriptor.
     * Because the NN runtime currently copies the hidl_memory on each execution, it is difficult
     * to associate hidl_memory objects with previously cached hidl_memory objects. For this
     * reason, callers of this class must pair each hidl_memory object with an associated key. For
     * efficiency, if two hidl_memory objects represent the same underlying buffer, they must use
     * the same key.
     *
     * This class is thread-safe.
     */
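    // Example (illustrative sketch only; `memoryCache`, `burstContext`, and `memory` are assumed
    // to already exist as a std::shared_ptr<MemoryCache>, an sp<IBurstContext>, and an
    // nn::SharedMemory, respectively):
    //
    //     memoryCache->setBurstContext(burstContext);  // must precede cacheMemory/getMemory
    //     const auto [slot, hold] = memoryCache->cacheMemory(memory);
    //     // ... pass `slot` across the FMQ; the entry stays cached while `hold` remains alive ...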
    class MemoryCache : public std::enable_shared_from_this<MemoryCache> {
        struct PrivateConstructorTag {};

      public:
        using Task = std::function<void()>;
        using Cleanup = base::ScopeGuard<Task>;
        using SharedCleanup = std::shared_ptr<const Cleanup>;
        using WeakCleanup = std::weak_ptr<const Cleanup>;

        // Custom constructor to pre-allocate cache sizes.
        MemoryCache();

        /**
         * Add a burst context to the MemoryCache object.
         *
         * If this method is called, it must be called before MemoryCache::cacheMemory or
         * MemoryCache::getMemory is used.
         *
         * @param burstContext Burst context to be added to the MemoryCache object.
         */
        void setBurstContext(sp<IBurstContext> burstContext);

        /**
         * Cache a memory object in the MemoryCache object.
         *
         * @param memory Memory object to be cached while the returned `SharedCleanup` is alive.
         * @return A pair of (1) a unique identifier for the cache entry and (2) a ref-counted
         *     "hold" object which preserves the cache entry as long as the hold object is alive.
         */
        std::pair<int32_t, SharedCleanup> cacheMemory(const nn::SharedMemory& memory);

        /**
         * Get the memory object corresponding to a slot identifier.
         *
         * @param slot Slot which identifies the memory object to retrieve.
         * @return The memory object corresponding to slot, otherwise GeneralError.
         */
        nn::GeneralResult<nn::SharedMemory> getMemory(int32_t slot);

      private:
        void freeMemory(const nn::SharedMemory& memory);
        int32_t allocateSlotLocked() REQUIRES(mMutex);

        std::mutex mMutex;
        std::condition_variable mCond;
        sp<IBurstContext> mBurstContext GUARDED_BY(mMutex);
        std::stack<int32_t, std::vector<int32_t>> mFreeSlots GUARDED_BY(mMutex);
        std::map<nn::SharedMemory, int32_t> mMemoryIdToSlot GUARDED_BY(mMutex);
        std::vector<nn::SharedMemory> mMemoryCache GUARDED_BY(mMutex);
        std::vector<WeakCleanup> mCacheCleaner GUARDED_BY(mMutex);
    };

    /**
     * HIDL Callback class to pass memory objects to the Burst server when given corresponding
     * slots.
     */
    class ExecutionBurstCallback : public IBurstCallback {
      public:
        // Precondition: memoryCache must be non-null.
        explicit ExecutionBurstCallback(const std::shared_ptr<MemoryCache>& memoryCache);

        // See IBurstCallback::getMemories for information on this method.
        Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;

      private:
        const std::weak_ptr<MemoryCache> kMemoryCache;
    };

    /**
     * Creates a burst controller on a prepared model.
     *
     * @param preparedModel Model prepared for execution, on which the burst executes.
     * @param hidlPreparedModel HIDL prepared model used to configure the execution burst.
     * @param pollingTimeWindow How much time (in microseconds) the Burst is allowed to poll the
     *     FMQ before waiting on the blocking futex. Polling may result in lower latencies at the
     *     potential cost of more power usage.
     * @return Execution burst controller object.
     */
    static nn::GeneralResult<std::shared_ptr<const Burst>> create(
            nn::SharedPreparedModel preparedModel, const sp<IPreparedModel>& hidlPreparedModel,
            std::chrono::microseconds pollingTimeWindow);

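    // Example (illustrative sketch only; `preparedModel`, `hidlPreparedModel`, `memory`, and
    // `request` are assumed to already exist, and the 50 us polling window is an arbitrary value):
    //
    //     auto burst = Burst::create(preparedModel, hidlPreparedModel,
    //                                std::chrono::microseconds(50)).value();
    //     const auto hold = burst->cacheMemory(memory);  // optional: pin `memory` in the cache
    //     auto result = burst->execute(request, nn::MeasureTiming::NO, /*deadline=*/{},
    //                                  /*loopTimeoutDuration=*/{}, /*hints=*/{},
    //                                  /*extensionNameToPrefix=*/{});
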
    Burst(PrivateConstructorTag tag, nn::SharedPreparedModel preparedModel,
          std::unique_ptr<RequestChannelSender> requestChannelSender,
          std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
          sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
          std::shared_ptr<MemoryCache> memoryCache,
          neuralnetworks::utils::DeathHandler deathHandler);

    // See IBurst::cacheMemory for information on this method.
    OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;

    // See IBurst::execute for information on this method.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
            const nn::Request& request, nn::MeasureTiming measure,
            const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
            const std::vector<nn::TokenValuePair>& hints,
            const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    // See IBurst::createReusableExecution for information on this method.
    nn::GeneralResult<nn::SharedExecution> createReusableExecution(
            const nn::Request& request, nn::MeasureTiming measure,
            const nn::OptionalDuration& loopTimeoutDuration,
            const std::vector<nn::TokenValuePair>& hints,
            const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    // If fallback is not nullptr, this method will invoke the fallback function to try another
    // execution path if the packet could not be sent. Otherwise, failing to send the packet will
    // result in an error.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
            const std::vector<FmqRequestDatum>& requestPacket,
            const hal::utils::RequestRelocation& relocation, FallbackFunction fallback) const;

  private:
    mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    const nn::SharedPreparedModel kPreparedModel;
    const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
    const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
    const sp<ExecutionBurstCallback> mBurstCallback;
    const sp<IBurstContext> mBurstContext;
    const std::shared_ptr<MemoryCache> mMemoryCache;
    // `kDeathHandler` must come after `mRequestChannelSender` and `mResultChannelReceiver` because
    // it holds references to both objects.
    const neuralnetworks::utils::DeathHandler kDeathHandler;
};

}  // namespace android::hardware::neuralnetworks::V1_2::utils

#endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H