#include <c10/core/Allocator.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/llvmMathExtras.h>
#include <optional>

#include <deque>
#include <mutex>
#include <set>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
namespace at {

/**
 * HostBlock is the fundamental memory block used for pinned host memory. It
 * tracks the device-runtime Streams and Events associated with the block, and
 * serves as a base struct that each backend can inherit and extend.
 */
template <typename S>
struct HostBlock {
  // Constructor for a search key.
  HostBlock(size_t size) : size_(size) {}

  HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {}

  std::mutex mutex_;
  size_t size_{0}; // block size in bytes
  void* ptr_{nullptr}; // memory address
  bool allocated_{false}; // in-use flag
  size_t event_count_{0}; // number of related events
  ska::flat_hash_set<S> streams_; // streams on which the block was used
};

template <typename B>
struct alignas(64) FreeBlockList {
  std::mutex mutex_;
  std::deque<B*> list_;
};

namespace {
// Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes
constexpr size_t MAX_SIZE_INDEX = 64;
} // namespace
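// As an illustration of the power-of-two bucketing used below: a
// 1,000,000-byte request is rounded up by allocate() to 1,048,576 bytes
// (2^20), and both the free-list lookup in get_free_block() and the
// reinsertion in free() resolve to bucket index 20 via size_index().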
/**
 * Note [HostAllocator design]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * We have three key data structures - the free list which stores blocks that
 * are not currently used, the block list which stores all blocks that have
 * been allocated, and the event queue which stores runtime events and their
 * corresponding blocks.
 *
 * Each of these is protected by a separate mutex. The key design principles
 * are to 1) only hold each mutex for the minimal amount of time possible, and
 * 2) never perform potentially expensive operations (such as CUDA runtime API
 * calls) while holding a lock.
 *
 * There are four public methods: allocate, free, record_event and empty_cache.
 * 1) In the allocate path, we first check whether we can service the request
 * from the free list, and otherwise we create a new block with
 * allocate_host_memory.
 * 2) In the free path, we insert events (if required) into the event queue,
 * and if possible insert our block back into the free list. In allocate, we
 * first eagerly query events until we find one that is not ready, and insert
 * the corresponding block onto the free list once all the events recorded for
 * that block are ready.
 * 3) In the record_event path, we simply insert the given stream into the set
 * of streams tracked by the specified block. This set of streams is then
 * consumed in the free path.
 * 4) In the empty_cache path, we flush any available blocks into the free
 * list, then remove all elements of the free list from the block list and
 * release the associated pinned memory allocation via free_block.
 *
 * We split the caching host allocator into two parts: an interface and an
 * implementation. Any new backend looking to integrate with the host allocator
 * and reuse the caching mechanism must specialize both parts.
 *
 * For the implementation, we provide a CachingHostAllocatorImpl struct that
 * abstracts the caching mechanism. A backend provides a customized
 * implementation by specializing its own public functions and the related
 * runtime functions. Its template parameter S represents a runtime Stream, E
 * denotes a runtime Event, and B the fundamental memory block.
 *
 * For the interface, we provide a CachingHostAllocatorInterface struct. Any
 * backend needs to derive its own host allocator from this interface. Its
 * template parameter T refers to an implementation that inherits from
 * CachingHostAllocatorImpl.
 *
 * This design shares the caching mechanism across backends while leaving each
 * backend the flexibility to follow the implementation as-is, or to extend
 * and override it as necessary. Taking CUDA as an example, it specializes the
 * runtime-related functions to reuse the caching mechanism, and additionally
 * extends the allocator by adding an allocWithCudaHostRegister function to
 * support page-locking the memory range used by CUDA. You can also refer to
 * XPUCachingHostAllocator, the host caching allocator for the XPU backend, for
 * a basic implementation.
 *
 * Some of the invariants here are less strict than they could be - for
 * example, we do not enforce that free(Block* block) => block->event_count ==
 * 0. This is for compatibility reasons, and we can explore enforcing these in
 * subsequent versions.
 *
 * Note that this caching host allocator does not split larger allocations into
 * smaller blocks, unlike the caching device allocator.
 */
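/*
 * The sketch below is illustrative only: it shows roughly what a backend
 * specialization of the two parts could look like. The names `MyStream`,
 * `MyEvent`, `MyCachingHostAllocatorImpl`, `MyCachingHostAllocator`, and
 * `my_host_deleter` are hypothetical placeholders, and the bodies stand in
 * for calls into the backend's runtime; see the CUDA and XPU caching host
 * allocators for real implementations.
 *
 *   struct MyCachingHostAllocatorImpl
 *       : public CachingHostAllocatorImpl<MyStream, MyEvent> {
 *    private:
 *     void allocate_host_memory(size_t size, void** ptr) override {
 *       // Allocate `size` bytes of pinned memory via the backend runtime
 *       // and store the resulting pointer in *ptr.
 *     }
 *     void free_block(HostBlock<MyStream>* block) override {
 *       // Release the pinned allocation at block->ptr_.
 *     }
 *     void record_stream(
 *         std::optional<std::vector<MyEvent>>& events,
 *         MyStream stream) override {
 *       MyEvent event; // record an event on `stream`...
 *       events->push_back(std::move(event)); // ...and hand it to the cache
 *     }
 *     bool query_event(MyEvent& event) override {
 *       return true; // return whether `event` has completed
 *     }
 *   };
 *
 *   // `my_host_deleter(void* ctx)` would forward to the allocator's
 *   // free(ctx).
 *   struct MyCachingHostAllocator
 *       : public CachingHostAllocatorInterface<MyCachingHostAllocatorImpl> {
 *     at::DataPtr allocate(size_t size) override {
 *       auto ptr_and_ctx = impl_->allocate(size);
 *       return {
 *           ptr_and_ctx.first,
 *           ptr_and_ctx.second,
 *           &my_host_deleter,
 *           at::DeviceType::CPU};
 *     }
 *   };
 */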
template <
    typename S,
    typename E,
    typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
  virtual ~CachingHostAllocatorImpl() = default;

 public:
  // Return the data_ptr and block as a pair.
  virtual std::pair<void*, void*> allocate(size_t size) {
    if (size == 0) {
      return {nullptr, nullptr};
    }

    process_events();

    // First, try to allocate from the free list.
    auto* block = get_free_block(size);
    if (block) {
      return {block->ptr_, reinterpret_cast<void*>(block)};
    }

    // Round up the allocation to the nearest power of two to improve reuse.
    // These power of two sizes are also used to index into the free list.
    size_t roundSize = c10::llvm::PowerOf2Ceil(size);
    void* ptr = nullptr;
    allocate_host_memory(roundSize, &ptr);

    // Then, create a new block.
    block = new B(roundSize, ptr);
    block->allocated_ = true;

    add_allocated_block(block);
    return {block->ptr_, reinterpret_cast<void*>(block)};
  }

  virtual void free(void* ctx) {
    if (!ctx) {
      return;
    }

    // Note: we can assume that free is correctly paired with alloc, and thus
    // we do not need to look up the ctx in blocks_.
    auto* block = reinterpret_cast<B*>(ctx);

    std::optional<std::vector<E>> events;
    {
      std::lock_guard<std::mutex> g(block->mutex_);
      block->allocated_ = false;
      if (block->streams_.empty()) {
        TORCH_INTERNAL_ASSERT(block->event_count_ == 0);
      } else {
        events = std::vector<E>();
        events->reserve(block->streams_.size());
        for (auto stream : block->streams_) {
          record_stream(events, stream);
        }
        block->event_count_ += events->size();
        block->streams_.clear();
      }
    }

    if (!events) {
      auto index = size_index(block->size_);
      std::lock_guard<std::mutex> g(free_list_[index].mutex_);
      free_list_[index].list_.push_back(block);
    } else {
      // Queue the events recorded on the streams that used this block.
      std::lock_guard<std::mutex> g(events_mutex_);
      for (auto&& event : *events) {
        events_.emplace_front(std::move(event), block);
      }
    }
  }

  virtual bool record_event(void* ptr, void* ctx, S stream) {
    auto* block = reinterpret_cast<B*>(ctx);

    // Note: we need to check if the passed-in `ctx` is valid. This is because
    // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked
    // on an arbitrary tensor, and is not guaranteed to correspond to a pinned
    // memory allocation. Therefore, we need to check that `ctx` is valid
    // before proceeding.
    {
      std::lock_guard<std::mutex> g(blocks_mutex_);
      if (blocks_.find(block) != blocks_.end()) {
        // Now we know this object is safe to access.
        std::lock_guard<std::mutex> gb(block->mutex_);
        TORCH_INTERNAL_ASSERT(block->allocated_);
        block->streams_.insert(stream);
        return true;
      }
      auto it = ptr_to_block_.find(ptr);
      if (it != ptr_to_block_.end()) {
        block = it->second;
        std::lock_guard<std::mutex> g(block->mutex_);
        TORCH_INTERNAL_ASSERT(block->allocated_);
        block->streams_.insert(stream);
        return true;
      }
    }

    return false;
  }

  virtual void empty_cache() {
    // Flush any available blocks into the free_list.
    process_events();

    // Remove all elements from the free list, remove them from the blocks
    // list, and free the associated pinned memory allocation. This requires
    // concurrently holding both the free list mutexes and the blocks mutex,
    // and is the only function that concurrently holds multiple mutexes.
    for (size_t i = 0; i < free_list_.size(); ++i) {
      std::lock(free_list_[i].mutex_, blocks_mutex_);
      std::lock_guard<std::mutex> gf(free_list_[i].mutex_, std::adopt_lock);
      std::lock_guard<std::mutex> gb(blocks_mutex_, std::adopt_lock);

      std::vector<B*> blocks_to_remove(
          free_list_[i].list_.begin(), free_list_[i].list_.end());
      free_list_[i].list_.clear();
      for (auto* block : blocks_to_remove) {
        blocks_.erase(block);
        ptr_to_block_.erase(block->ptr_);
        free_block(block);
        delete block;
      }
    }
  }

  inline size_t size_index(size_t size) {
    return c10::llvm::Log2_64_Ceil(size);
  }

  virtual void copy_data(
      void* dest [[maybe_unused]],
      const void* src [[maybe_unused]],
      std::size_t count [[maybe_unused]]) const {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data");
  }

 private:
  virtual void add_allocated_block(B* block) {
    std::lock_guard<std::mutex> g(blocks_mutex_);
    blocks_.insert(block);
    ptr_to_block_.insert({block->ptr_, block});
  }

  virtual B* get_free_block(size_t size) {
    auto index = size_index(size);
    std::lock_guard<std::mutex> g(free_list_[index].mutex_);
    if (free_list_[index].list_.size() > 0) {
      B* block = free_list_[index].list_.back();
      free_list_[index].list_.pop_back();
      block->allocated_ = true;
      return block;
    }
    return nullptr;
  }

  virtual void process_events() {
    while (true) {
      // Avoid calling cudaEventDestroy while holding a mutex, so move
      // intermediate events out of the lock into this local variable, then
      // process the last event.
      std::optional<std::pair<E, B*>> processed;
      {
        std::lock_guard<std::mutex> g(events_mutex_);
        if (!events_.empty()) {
          processed = std::move(events_.back());
          events_.pop_back();
        }
      }

      if (!processed) {
        return;
      }

      // Otherwise, query the event.
      {
        // Now, see if we can handle this element.
        auto& event = processed->first;
        if (!query_event(event)) {
          // Push the event back onto the queue if it's not ready.
          {
            std::lock_guard<std::mutex> g(events_mutex_);
            events_.push_back(std::move(*processed));
          }
          return;
        }
      }

      // Process the event: the block becomes available once all of its
      // outstanding events have completed.
      TORCH_INTERNAL_ASSERT(processed);
      auto* block = processed->second;
      bool available = false;
      {
        std::lock_guard<std::mutex> g(block->mutex_);
        TORCH_INTERNAL_ASSERT(!block->allocated_);
        block->event_count_--;
        if (block->event_count_ == 0) {
          available = true;
        }
      }

      if (available) {
        auto index = size_index(block->size_);
        std::lock_guard<std::mutex> g(free_list_[index].mutex_);
        free_list_[index].list_.push_back(block);
      }
    }
  }

  /* The following functions are runtime-related. */

  // Allocate page-locked memory on the host.
  virtual void allocate_host_memory(size_t size, void** ptr) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "Not implemented for allocate_host_memory");
  }

  // Free the block and release the pointer contained in it.
  virtual void free_block(B* block) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block");
  }

  // Record an event on the given stream and store it into events.
  virtual void record_stream(std::optional<std::vector<E>>& events, S stream) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream");
  }

  // Query whether the event has completed.
  virtual bool query_event(E& event) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event");
  }

  alignas(64) std::mutex blocks_mutex_;
  ska::flat_hash_set<B*> blocks_; // block list
  ska::flat_hash_map<void*, B*> ptr_to_block_;

  // We keep the free list as a vector of free lists, one for each power of two
  // size. This allows us to quickly find a free block of the right size.
  // We use a deque to store each per-size free list and guard it with its own
  // mutex.
  alignas(64) std::vector<FreeBlockList<B>> free_list_ =
      std::vector<FreeBlockList<B>>(MAX_SIZE_INDEX);

  alignas(64) std::mutex events_mutex_;
  std::deque<std::pair<E, B*>> events_; // event queue paired with block
};

template <typename T>
struct CachingHostAllocatorInterface : public at::Allocator {
  CachingHostAllocatorInterface() : impl_(std::make_unique<T>()) {}

  at::DataPtr allocate(size_t size) override {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for allocate");
  }

  void free(void* ctx) {
    impl_->free(ctx);
  }

  template <typename S>
  bool record_event(void* ptr, void* ctx, S stream) {
    return impl_->record_event(ptr, ctx, stream);
  }

  void empty_cache() {
    impl_->empty_cache();
  }

  void copy_data(void* dest, const void* src, std::size_t count)
      const override {
    impl_->copy_data(dest, src, count);
  }

  std::unique_ptr<T> impl_;
};

} // namespace at
C10_DIAGNOSTIC_POP()