/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H
#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H

#include "nnapi/hal/1.2/BurstUtils.h"

#include <android-base/scope_guard.h>
#include <android-base/thread_annotations.h>
#include <android/hardware/neuralnetworks/1.0/types.h>
#include <android/hardware/neuralnetworks/1.2/IBurstCallback.h>
#include <android/hardware/neuralnetworks/1.2/IBurstContext.h>
#include <android/hardware/neuralnetworks/1.2/IPreparedModel.h>
#include <android/hardware/neuralnetworks/1.2/types.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>
#include <nnapi/IBurst.h>
#include <nnapi/IExecution.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/hal/1.0/ProtectCallback.h>
#include <nnapi/hal/CommonUtils.h>

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <stack>
#include <tuple>
#include <utility>
#include <vector>

namespace android::hardware::neuralnetworks::V1_2::utils {

/**
 * The Burst class manages both the serialization and deserialization of data across FMQ, making it
 * appear to the runtime as a regular synchronous inference. Additionally, this class manages the
 * burst's memory cache.
 */
class Burst final : public nn::IBurst, public std::enable_shared_from_this<Burst> {
    struct PrivateConstructorTag {};

  public:
    using FallbackFunction = std::function<
            nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>>()>;

    /**
     * NN runtime memory cache.
     *
     * MemoryCache associates a Memory object with a slot number to be passed across FMQ. The
     * ExecutionBurstServer can use this callback to retrieve a hidl_memory corresponding to the
     * slot via HIDL.
     *
     * Whenever a hidl_memory object is copied, it will duplicate the underlying file descriptor.
     * Because the NN runtime currently copies the hidl_memory on each execution, it is difficult to
     * associate hidl_memory objects with previously cached hidl_memory objects. For this reason,
     * callers of this class must pair each hidl_memory object with an associated key. For
     * efficiency, if two hidl_memory objects represent the same underlying buffer, they must use
     * the same key.
     *
     * This class is thread-safe.
     */
    class MemoryCache : public std::enable_shared_from_this<MemoryCache> {
        struct PrivateConstructorTag {};

      public:
        using Task = std::function<void()>;
        using Cleanup = base::ScopeGuard<Task>;
        using SharedCleanup = std::shared_ptr<const Cleanup>;
        using WeakCleanup = std::weak_ptr<const Cleanup>;

        // Custom constructor to pre-allocate cache sizes.
        MemoryCache();

        /**
         * Add a burst context to the MemoryCache object.
         *
         * If this method is called, it must be called before MemoryCache::cacheMemory or
         * MemoryCache::getMemory is used.
         *
         * @param burstContext Burst context to be added to the MemoryCache object.
         */
        void setBurstContext(sp<IBurstContext> burstContext);

        /**
         * Cache a memory object in the MemoryCache object.
         *
         * @param memory Memory object to be cached while the returned `SharedCleanup` is alive.
         * @return A pair of (1) a unique slot identifier for the cache entry and (2) a ref-counted
         *     "hold" object which preserves the cache entry as long as the hold object is alive.
         */
        std::pair<int32_t, SharedCleanup> cacheMemory(const nn::SharedMemory& memory);

        /**
         * Get the memory object corresponding to a slot identifier.
         *
         * @param slot Slot which identifies the memory object to retrieve.
         * @return The memory object corresponding to slot, otherwise GeneralError.
         */
        nn::GeneralResult<nn::SharedMemory> getMemory(int32_t slot);

      private:
        void freeMemory(const nn::SharedMemory& memory);
        int32_t allocateSlotLocked() REQUIRES(mMutex);

        std::mutex mMutex;
        std::condition_variable mCond;
        sp<IBurstContext> mBurstContext GUARDED_BY(mMutex);
        std::stack<int32_t, std::vector<int32_t>> mFreeSlots GUARDED_BY(mMutex);
        std::map<nn::SharedMemory, int32_t> mMemoryIdToSlot GUARDED_BY(mMutex);
        std::vector<nn::SharedMemory> mMemoryCache GUARDED_BY(mMutex);
        std::vector<WeakCleanup> mCacheCleaner GUARDED_BY(mMutex);
    };
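
    // Illustrative usage sketch of MemoryCache; `memoryCache`, `burstContext`, and `memory` below
    // are hypothetical local variables. Caching a memory object yields a slot identifier that can
    // be serialized across FMQ, and the returned hold keeps the cache entry alive.
    //
    //     memoryCache->setBurstContext(burstContext);  // must precede cacheMemory/getMemory
    //     const auto [slot, hold] = memoryCache->cacheMemory(memory);
    //     nn::GeneralResult<nn::SharedMemory> cached = memoryCache->getMemory(slot);
    //     // Once `hold` and all copies of it are destroyed, the cache entry may be freed.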
130 
131     /**
132      * HIDL Callback class to pass memory objects to the Burst server when given corresponding
133      * slots.
134      */
135     class ExecutionBurstCallback : public IBurstCallback {
136       public:
137         // Precondition: memoryCache must be non-null.
138         explicit ExecutionBurstCallback(const std::shared_ptr<MemoryCache>& memoryCache);
139 
140         // See IBurstCallback::getMemories for information on this method.
141         Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
142 
143       private:
144         const std::weak_ptr<MemoryCache> kMemoryCache;
145     };
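
    // Illustrative sketch of how the service side might resolve unknown slots through this
    // callback; `callback` and `slots` are hypothetical, and the lambda parameters are assumed to
    // match the generated getMemories_cb signature (a status plus a vector of hidl_memory
    // buffers):
    //
    //     callback->getMemories(slots, [&](V1_0::ErrorStatus status,
    //                                      const hidl_vec<hidl_memory>& buffers) {
    //         // On V1_0::ErrorStatus::NONE, buffers[i] corresponds to slots[i].
    //     });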

    /**
     * Creates a burst controller on a prepared model.
     *
     * @param preparedModel Prepared model to execute on.
     * @param hidlPreparedModel HIDL prepared model on which the execution burst is configured.
     * @param pollingTimeWindow How much time (in microseconds) the Burst is allowed to poll the FMQ
     *     before waiting on the blocking futex. Polling may result in lower latencies at the
     *     potential cost of more power usage.
     * @return Execution burst controller object, otherwise GeneralError.
     */
    static nn::GeneralResult<std::shared_ptr<const Burst>> create(
            nn::SharedPreparedModel preparedModel, const sp<IPreparedModel>& hidlPreparedModel,
            std::chrono::microseconds pollingTimeWindow);
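
    // Illustrative sketch of constructing a burst controller; `preparedModel` and
    // `hidlPreparedModel` are hypothetical handles obtained elsewhere, and the 50us polling window
    // is an arbitrary example value:
    //
    //     auto burst = Burst::create(preparedModel, hidlPreparedModel,
    //                                std::chrono::microseconds{50});
    //     if (!burst.has_value()) { /* handle burst.error() */ }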

    Burst(PrivateConstructorTag tag, nn::SharedPreparedModel preparedModel,
          std::unique_ptr<RequestChannelSender> requestChannelSender,
          std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
          sp<ExecutionBurstCallback> callback, sp<IBurstContext> burstContext,
          std::shared_ptr<MemoryCache> memoryCache,
          neuralnetworks::utils::DeathHandler deathHandler);

    // See IBurst::cacheMemory for information on this method.
    OptionalCacheHold cacheMemory(const nn::SharedMemory& memory) const override;

    // See IBurst::execute for information on this method.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
            const nn::Request& request, nn::MeasureTiming measure,
            const nn::OptionalTimePoint& deadline, const nn::OptionalDuration& loopTimeoutDuration,
            const std::vector<nn::TokenValuePair>& hints,
            const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;
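
    // Illustrative sketch of a single synchronous execution through the burst; `burst` and
    // `request` are hypothetical, and the empty deadline/timeout/hint/extension arguments are
    // example values:
    //
    //     const auto result = burst->execute(request, nn::MeasureTiming::YES, {}, {}, {}, {});
    //     if (result.has_value()) {
    //         const auto& [outputShapes, timing] = result.value();
    //     }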

    // See IBurst::createReusableExecution for information on this method.
    nn::GeneralResult<nn::SharedExecution> createReusableExecution(
            const nn::Request& request, nn::MeasureTiming measure,
            const nn::OptionalDuration& loopTimeoutDuration,
            const std::vector<nn::TokenValuePair>& hints,
            const std::vector<nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override;

    // If fallback is not nullptr, this method will invoke the fallback function to try another
    // execution path if the packet could not be sent. Otherwise, failing to send the packet will
    // result in an error.
    nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> executeInternal(
            const std::vector<FmqRequestDatum>& requestPacket,
            const hal::utils::RequestRelocation& relocation, FallbackFunction fallback) const;
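
    // Illustrative sketch of a FallbackFunction: if the FMQ packet cannot be sent, the burst can
    // fall back to another execution path, for example an ordinary synchronous execution on the
    // prepared model. `preparedModel` and `request` are hypothetical, and the remaining arguments
    // are example values:
    //
    //     const FallbackFunction fallback = [&]() {
    //         return preparedModel->execute(request, nn::MeasureTiming::NO, {}, {}, {}, {});
    //     };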

  private:
    mutable std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
    const nn::SharedPreparedModel kPreparedModel;
    const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
    const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
    const sp<ExecutionBurstCallback> mBurstCallback;
    const sp<IBurstContext> mBurstContext;
    const std::shared_ptr<MemoryCache> mMemoryCache;
    // `kDeathHandler` must come after `mRequestChannelSender` and `mResultChannelReceiver` because
    // it holds references to both objects.
    const neuralnetworks::utils::DeathHandler kDeathHandler;
};

}  // namespace android::hardware::neuralnetworks::V1_2::utils

#endif  // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_1_2_UTILS_BURST_H