xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/CachingHostAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/Allocator.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/llvmMathExtras.h>
#include <optional>

#include <deque>
#include <memory>
#include <mutex>
#include <set>
#include <vector>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {

/**
 * HostBlock is the fundamental memory block used for pinned host memory. It
 * tracks the device runtime Streams and Events associated with the block, and
 * serves as a base struct that each backend can inherit and extend.
 */
template <typename S>
struct HostBlock {
  // constructor for search key
  HostBlock(size_t size) : size_(size) {}

  HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {}

  std::mutex mutex_;
  size_t size_{0}; // block size in bytes
  void* ptr_{nullptr}; // memory address
  bool allocated_{false}; // in-use flag
  size_t event_count_{0}; // number of related events
  ska::flat_hash_set<S> streams_; // streams on which the block was used
};

template <typename B>
struct alignas(64) FreeBlockList {
  std::mutex mutex_;
  std::deque<B*> list_;
};

namespace {
  // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes
  constexpr size_t MAX_SIZE_INDEX = 64;
}

/**
 * Note [HostAllocator design]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * We have three key data structures - the free list, which stores blocks that
 * are not currently in use, the block list, which stores all blocks that have
 * been allocated, and the event queue, which stores runtime events and their
 * corresponding blocks.
 *
 * Each of these is protected by a separate mutex. The key design principles
 * are to 1) only hold each mutex for the minimal amount of time possible, and
 * 2) never perform potentially expensive operations (such as CUDA runtime API
 * calls) while holding a lock.
 *
 * There are four public methods: allocate, free, record_event and empty_cache.
 *   1) In the allocate path, we first check whether we can service the request
 * from the free list, and otherwise create a new block with
 * allocate_host_memory.
 *   2) In the free path, we insert events (if required) into the event queue,
 * and if possible insert the block back into the free list. In allocate, we
 * first eagerly query events until we find one that is not ready, and insert
 * the corresponding block into the free list once all the events recorded for
 * that block are ready.
 *   3) In the record_event path, we simply insert the given stream into the set
 * of streams tracked by the specified block. This set of streams is then
 * consumed in the free path.
 *   4) In the empty_cache path, we flush any available blocks into the free
 * list, remove all elements from the free list, then remove them from the
 * block list and release the associated pinned memory allocation via
 * free_block.
 *
 * We split the caching host allocator into two parts: an interface and an
 * implementation. Any new backend looking to integrate with the host allocator
 * and reuse the caching mechanism must specialize both parts.
 *
 * For the implementation, we provide a CachingHostAllocatorImpl struct that
 * abstracts the caching mechanism. A backend provides a customized
 * implementation by specializing its own public functions and the related
 * runtime functions. Its template parameter S represents the runtime Stream, E
 * denotes the runtime Event, and B indicates the fundamental memory block.
 *
 * For the interface, we provide a CachingHostAllocatorInterface struct. Any
 * backend needs to derive its own host allocator from this interface. Its
 * template parameter T refers to an implementation that inherits from
 * CachingHostAllocatorImpl.
 *
 * This design shares the caching mechanism across backends while leaving each
 * backend the flexibility to follow this implementation as-is or to extend and
 * override it as necessary. Taking CUDA as an example, it specializes the
 * runtime-related functions to reuse the caching mechanism, and additionally
 * extends the allocator's functionality by adding allocWithCudaHostRegister to
 * support page-locking the memory range used by CUDA. You can also refer to
 * XPUCachingHostAllocator, a host caching allocator for the XPU backend, for a
 * basic host caching allocator implementation.
 *
 * Some of the invariants here are less strict than they could be - for example,
 * we do not enforce that free(Block* block) => block->event_count == 0. This is
 * for compatibility reasons, and we can explore enforcing these in subsequent
 * versions.
 *
 * Note that this caching host allocator does not split larger allocations into
 * smaller blocks, unlike the caching device allocator.
 */

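/*
 * Illustrative sketch (not part of this header): a hypothetical backend could
 * specialize the implementation and interface roughly as below. The names
 * MyStream, MyEvent, my_pinned_malloc, my_pinned_free, my_record_event,
 * my_query_event and my_delete_fn are placeholders for the backend's runtime
 * types and calls, not real APIs.
 *
 *   struct MyHostAllocatorImpl
 *       : public CachingHostAllocatorImpl<MyStream, MyEvent> {
 *     void allocate_host_memory(size_t size, void** ptr) override {
 *       *ptr = my_pinned_malloc(size); // backend page-locked allocation
 *     }
 *     void free_block(HostBlock<MyStream>* block) override {
 *       my_pinned_free(block->ptr_);
 *     }
 *     void record_stream(
 *         std::optional<std::vector<MyEvent>>& events, MyStream stream) override {
 *       events->push_back(my_record_event(stream)); // event marks pending work
 *     }
 *     bool query_event(MyEvent& event) override {
 *       return my_query_event(event); // true once the stream passed the event
 *     }
 *   };
 *
 *   struct MyHostAllocator
 *       : public CachingHostAllocatorInterface<MyHostAllocatorImpl> {
 *     at::DataPtr allocate(size_t size) override {
 *       auto ptr_and_ctx = impl_->allocate(size);
 *       // The block pointer travels as the DataPtr context so that the
 *       // deleter can hand it back to impl_->free().
 *       return {ptr_and_ctx.first, ptr_and_ctx.second, &my_delete_fn,
 *               at::DeviceType::CPU};
 *     }
 *   };
 */
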
template <
    typename S,
    typename E,
    typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
  virtual ~CachingHostAllocatorImpl() = default;

 public:
  // return data_ptr and block pair.
  virtual std::pair<void*, void*> allocate(size_t size) {
    if (size == 0) {
      return {nullptr, nullptr};
    }

    process_events();

    // First, try to allocate from the free list
    auto* block = get_free_block(size);
    if (block) {
      return {block->ptr_, reinterpret_cast<void*>(block)};
    }

    // Round up the allocation to the nearest power of two to improve reuse.
    // These power of two sizes are also used to index into the free list.
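    // For example (illustrative numbers): a 5000-byte request is rounded up to
    // 8192 bytes (2^13) and indexed into free-list bucket 13 by size_index().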
    size_t roundSize = c10::llvm::PowerOf2Ceil(size);
    void* ptr = nullptr;
    allocate_host_memory(roundSize, &ptr);

    // Then, create a new block.
    block = new B(roundSize, ptr);
    block->allocated_ = true;

    add_allocated_block(block);
    return {block->ptr_, reinterpret_cast<void*>(block)};
  }

  virtual void free(void* ctx) {
    if (!ctx) {
      return;
    }

    // Note: we can assume that free is correctly paired with alloc, and thus we
    // do not need to look up the ctx in blocks_.
    auto* block = reinterpret_cast<B*>(ctx);

    std::optional<std::vector<E>> events;
    {
      std::lock_guard<std::mutex> g(block->mutex_);
      block->allocated_ = false;
      if (block->streams_.empty()) {
        TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
      } else {
        events = std::vector<E>();
        events->reserve(block->streams_.size());
        for (auto stream : block->streams_) {
          record_stream(events, stream);
        }
        block->event_count_ += events->size();
        block->streams_.clear();
      }
    }

    if (!events) {
      auto index = size_index(block->size_);
      std::lock_guard<std::mutex> g(free_list_[index].mutex_);
      free_list_[index].list_.push_back(block);
    } else {
      // Queue the events recorded on the streams that used this block.
      std::lock_guard<std::mutex> g(events_mutex_);
      for (auto&& event : *events) {
        events_.emplace_front(std::move(event), block);
      }
    }
  }

  virtual bool record_event(void* ptr, void* ctx, S stream) {
    auto* block = reinterpret_cast<B*>(ctx);

    // Note: we need to check if the passed-in `ctx` is valid. This is because
    // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on
    // an arbitrary tensor, and is not guaranteed to correspond to a pinned
    // memory allocation. Therefore, we need to check that `ctx` is valid before
    // proceeding.
    {
      std::lock_guard<std::mutex> g(blocks_mutex_);
      if (blocks_.find(block) != blocks_.end()) {
        // Now we know this object is safe to access.
        std::lock_guard<std::mutex> gb(block->mutex_);
        TORCH_INTERNAL_ASSERT(block->allocated_);
        block->streams_.insert(stream);
        return true;
      }
      auto it = ptr_to_block_.find(ptr);
      if (it != ptr_to_block_.end()) {
        block = it->second;
        std::lock_guard<std::mutex> g(block->mutex_);
        TORCH_INTERNAL_ASSERT(block->allocated_);
        block->streams_.insert(stream);
        return true;
      }
    }

    return false;
  }

  virtual void empty_cache() {
    // Flush any available blocks into the free_list.
    process_events();

    // Remove all elements from the free list, remove them from the blocks
    // list, and free the associated pinned memory allocation. This requires
    // concurrently holding both the free list mutexes and the blocks mutex, and
    // is the only function that concurrently holds multiple mutexes.
    for (size_t i = 0; i < free_list_.size(); ++i) {
      std::lock(free_list_[i].mutex_, blocks_mutex_);
      std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);

      std::vector<B*> blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end());
      free_list_[i].list_.clear();
      for (auto* block : blocks_to_remove) {
        blocks_.erase(block);
        ptr_to_block_.erase(block->ptr_);
        free_block(block);
        delete block;
      }
    }
  }

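  // Map a size to its free-list bucket: the index is ceil(log2(size)), e.g.
  // 8192 bytes falls into bucket 13. Blocks store their power-of-two rounded
  // size, so allocate() and free() resolve a block to the same bucket.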
  inline size_t size_index(size_t size) {
    return c10::llvm::Log2_64_Ceil(size);
  }

  virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
  }

 private:
  virtual void add_allocated_block(B* block) {
    std::lock_guard<std::mutex> g(blocks_mutex_);
    blocks_.insert(block);
    ptr_to_block_.insert({block->ptr_, block});
  }

  virtual B* get_free_block(size_t size) {
    auto index = size_index(size);
    std::lock_guard<std::mutex> g(free_list_[index].mutex_);
    if (free_list_[index].list_.size() > 0) {
      B* block = free_list_[index].list_.back();
      free_list_[index].list_.pop_back();
      block->allocated_ = true;
      return block;
    }
    return nullptr;
  }

  virtual void process_events() {
    while (true) {
      // Avoid calling cudaEventDestroy while holding a mutex, so move
      // intermediate events out of the lock into this object.
      // Process the oldest event (at the back of the queue) first.
      std::optional<std::pair<E, B*>> processed;
      {
        std::lock_guard<std::mutex> g(events_mutex_);
        if (!events_.empty()) {
          processed = std::move(events_.back());
          events_.pop_back();
        }
      }

      if (!processed) {
        return;
      }

      // Otherwise, query the event.
      {
        // Now, see if we can handle this element.
        auto& event = processed->first;
        if (!query_event(event)) {
          // Push the event back onto the queue if it's not ready.
          {
            std::lock_guard<std::mutex> g(events_mutex_);
            events_.push_back(std::move(*processed));
          }
          return;
        }
      }

      // Process the event's block.
      TORCH_INTERNAL_ASSERT(processed);
      auto* block = processed->second;
      bool available = false;
      {
        std::lock_guard<std::mutex> g(block->mutex_);
        TORCH_INTERNAL_ASSERT(!block->allocated_);
        block->event_count_--;
        if (block->event_count_ == 0) {
          available = true;
        }
      }

      if (available) {
        auto index = size_index(block->size_);
        std::lock_guard<std::mutex> g(free_list_[index].mutex_);
        free_list_[index].list_.push_back(block);
      }
    }
  }

  /* The following functions are runtime-related. */

  // Allocate page-locked memory on the host.
  virtual void allocate_host_memory(size_t size, void** ptr) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "Not implemented for allocate_host_memory");
  }

  // Free the block and release the pointer it contains.
  virtual void free_block(B* block) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
  }

  // Record an event on the given stream and store it into events.
  virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
  }

  // Query whether the event has completed.
  virtual bool query_event(E& event) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
  }

  alignas(64) std::mutex blocks_mutex_;
  ska::flat_hash_set<B*> blocks_; // block list
  ska::flat_hash_map<void*, B*> ptr_to_block_;

  // We keep the free list as a vector of free lists, one for each power-of-two
  // size. This allows us to quickly find a free block of the right size. We
  // use a deque for each per-size free list and guard it with its own mutex.
  alignas(64) std::vector<FreeBlockList<B>> free_list_ = std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);

  alignas(64) std::mutex events_mutex_;
  std::deque<std::pair<E, B*>> events_; // event queue paired with block
};

template <typename T>
struct CachingHostAllocatorInterface : public at::Allocator {
  CachingHostAllocatorInterface() : impl_(std::make_unique<T>()) {}

  at::DataPtr allocate(size_t size) override {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for allocate");
  }

  void free(void* ctx) {
    impl_->free(ctx);
  }

  template <typename S>
  bool record_event(void* ptr, void* ctx, S stream) {
    return impl_->record_event(ptr, ctx, stream);
  }

  void empty_cache() {
    impl_->empty_cache();
  }

  void copy_data(void* dest, const void* src, std::size_t count)
      const override {
    impl_->copy_data(dest, src, count);
  }

  std::unique_ptr<T> impl_;
};

} // namespace at
C10_DIAGNOSTIC_POP()