// Copyright © 2022 Apple Inc.

#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <ATen/mps/MPSAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>

#include <iostream>

namespace at::mps {

C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

namespace HeapAllocator {

uint64_t BufferBlock::buffer_counter = 0;
uint64_t HeapBlock::heap_counter = 0;

void MPSHeapAllocatorImpl::init_allocator() {
  init_buffer_pools();

  // debug verbosity flags (see DebugVerbosity enum)
  static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
  m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;

  static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
  const double high_watermark_ratio =
      high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
  setHighWatermarkRatio(high_watermark_ratio);

  const double default_low_watermark_ratio =
      m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
  static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
  const double low_watermark_ratio =
      low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
  setLowWatermarkRatio(low_watermark_ratio);
}

void MPSHeapAllocatorImpl::init_buffer_pools() {
  // using a container for pools to simplify iterating over them
  // Pool of large buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of large buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode used to allocate and copy Scalars
  // from CPU to Metal buffers (see allocScalarBufferWithValue()).
  // no Hazard Tracking required for the Scalar pool (synchronized manually).
  m_pools.emplace(BufferPool::Kind::SCALAR,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::SCALAR));
}

BufferPool& MPSHeapAllocatorImpl::get_pool(size_t requested_size, size_t aligned_size, uint32_t usage) {
  BufferPool::Kind poolKind;

  if (usage & UsageFlags::SCALAR) {
    poolKind = BufferPool::Kind::SCALAR;
  } else if (requested_size <= kMaxScalarAlloc && m_device.hasUnifiedMemory) {
    poolKind = BufferPool::Kind::SHARED_SMALL;
  } else if (aligned_size <= kMaxSmallAlloc) {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_SMALL : BufferPool::Kind::PRIVATE_SMALL;
  } else {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_LARGE : BufferPool::Kind::PRIVATE_LARGE;
  }
  return *m_pools[poolKind];
}
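// Illustrative mapping of allocation requests to pools (a descriptive recap of the branches
// above; the exact cutoffs kMaxScalarAlloc / kMaxSmallAlloc are defined in MPSAllocator.h):
//   - UsageFlags::SCALAR requests                               -> SCALAR pool
//   - requests <= kMaxScalarAlloc on unified-memory devices     -> SHARED_SMALL pool
//   - aligned sizes <= kMaxSmallAlloc                           -> SHARED_SMALL or PRIVATE_SMALL pool
//   - anything larger                                           -> SHARED_LARGE or PRIVATE_LARGE pool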
size_t MPSHeapAllocatorImpl::get_allocation_size(size_t size, uint32_t usage) const {
  MTLSizeAndAlign sizeAlign = [m_device heapBufferSizeAndAlignWithLength:size options:HeapBlock::getOptions(usage)];
  return BufferBlock::alignUp(sizeAlign.size, sizeAlign.align);
}

void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
  TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
  m_max_total_allowed_size =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "\nHigh watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
  }
  m_high_watermark_ratio = ratio;
}

void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
  // used for comparison with the low watermark ratio
  const double high_watermark_limit =
      m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
  TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
  // we use this to detect if there's memory pressure
  m_low_watermark_limit =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "Low watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
  }
  m_low_watermark_ratio = ratio;
}
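// Worked example (hypothetical numbers): on a device where max_device_size() reports 32 GB,
// PYTORCH_MPS_HIGH_WATERMARK_RATIO=1.5 yields m_max_total_allowed_size of roughly 48 GB, while
// PYTORCH_MPS_LOW_WATERMARK_RATIO=1.0 sets m_low_watermark_limit to roughly 32 GB. A ratio of
// 0.0 disables the corresponding limit entirely (it becomes numeric_limits<size_t>::max()).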
HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
  BufferPool& pool = *params.pool;
  HeapBlock* heap_block = nullptr;
  HeapBlock search_key(params.size());

  auto it = pool.heaps.lower_bound(&search_key);
  if (it == pool.heaps.end()) {
    heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
    if (heap_block) {
      m_total_allocated_memory += heap_block->size.total;
      if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
        std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private") << " heap #"
                  << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                  << " (#heaps: " << (pool.heaps.size() + 1)
                  << ", current allocated: " << format_size(current_allocated_size()) << ")\n";
      }
    }
  } else {
    heap_block = *it;
    // remove and re-insert heap in the set later after a buffer is created.
    // this ensures updating the order of heaps based on their new available sizes
    pool.heaps.erase(it);
  }
  return heap_block;
}

bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
  if (m_max_total_allowed_size != std::numeric_limits<size_t>::max() &&
      current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  HeapBlock* heap = get_free_heap(params);
  if (!heap) {
    return false; // this will cause releasing pool buffers to free up memory
  }
  BufferPool& pool = *params.pool;

  id<MTLBuffer> buffer = heap->newMTLBuffer(params.size(), pool.usage);
  // this should never happen as the backing memory (i.e., heap) was allocated successfully.
  TORCH_INTERNAL_ASSERT(buffer);
  // insert heap after a buffer was created on it to update the order of the heaps set
  pool.heaps.insert(heap);
  params.buffer_block = new BufferBlock(params.size(), params.requested_size, buffer, heap);
  m_allocated_buffers[params.buffer_block->buffer] = params.buffer_block;
  pool.allocated_size += params.size();
  pool.n_buffers++;

  if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
              << params.buffer_block->buffer << " from heap #" << heap->heap_id
              << " (requested: " << format_size(params.requested_size)
              << ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory)
              << ")\n";
  }
  return true;
}
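// A note on the ordering invariant used above (descriptive only): pool.heaps is kept sorted by
// each heap's available size, so lower_bound() on a search key of params.size() finds the first
// heap that can still fit the request. Erasing a heap before sub-allocating from it and
// re-inserting it afterwards is what keeps that ordering correct once its available size shrinks.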
bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
  // this helps to monitor "implicit" allocations from the MPS backend and to prevent OOM and system failure.
  if (m_high_watermark_ratio > 0.0 && current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  BufferPool& pool = *params.pool;
  // track buffer reuse intervals only on the large pool when the low watermark limit is enabled.
  if (m_low_watermark_ratio > 0.0 && !(pool.usage & UsageFlags::SMALL)) {
    for (auto& b : pool.available_buffers) {
      ++b->gc_count;
    }
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;

    // the logic here is simple: keep reusing the existing heaps' capacity as long as possible (by splitting
    // or releasing oversized buffers, if required), and avoid 'new' heap allocations as much as possible.
    if (buffer_block->size <= params.size() + kLargeHeap) {
      // return the existing buffer if it already fits the requested size (i.e., it isn't oversized)
      params.buffer_block = buffer_block;
    } else {
      HeapBlock search_key(params.size());
      // if there's an 'existing' heap with enough capacity, then don't
      // return the oversized buffer and sub-allocate from that existing heap instead.
      if (pool.heaps.lower_bound(&search_key) != pool.heaps.end()) {
        params.buffer_block = nullptr;
      } else if (buffer_block->retainCount() <= 1) {
        // otherwise, if the buffer is releasable immediately, we make room by releasing the
        // buffer and reuse the freed space within its heap container for the new, smaller buffer allocation
        release_buffer(buffer_block, false);
        // this will skip unnecessary garbage collection as we'll reuse the newly released space
        params.has_memory_pressure = false;
      } else if (params.has_memory_pressure) {
        // the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap
        // container) in the allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
        release_buffer(buffer_block, true);
      } else {
        // only if there's no memory pressure do we reuse the oversized buffer
        params.buffer_block = buffer_block;
      }
    }
  }

  if (!params.buffer_block) {
    return false; // this will make the allocator allocate a new buffer
  }
  pool.available_buffers.erase(params.buffer_block);
  params.buffer_block->requested_size = params.requested_size;
  params.buffer_block->gc_count = 0;
  pool.available_size -= params.buffer_block->size;

  if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
              << params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
              << ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount()
              << ")\n";
  }
  return true;
}
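// Summary of the oversized-buffer policy above (a descriptive recap, not additional logic):
//   1. buffer fits within request + kLargeHeap         -> reuse it as-is
//   2. another existing heap can fit the request       -> skip this buffer, sub-allocate from that heap
//   3. buffer idle (retainCount <= 1)                  -> release it and reuse its heap space
//   4. buffer busy and under memory pressure           -> release it; ARC frees it once the GPU is done
//   5. buffer busy and no memory pressure              -> reuse the oversized buffer anyway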
BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usage) {
  TORCH_CHECK(size < m_max_buffer_size, "Invalid buffer size: ", format_size(size));

  size_t alloc_size = get_allocation_size(size, usage);
  auto& pool = get_pool(size, alloc_size, usage);
  AllocParams params(alloc_size, size, &pool);
  // we only care about memory pressure when we're allocating large buffers and the
  // low watermark limit has been reached
  params.has_memory_pressure = !(pool.usage & UsageFlags::SMALL) && getLowWatermarkValue() <= 0;
  params.has_unified_memory = m_device.hasUnifiedMemory;

  // first, try to get a block from the existing pool.
  bool block_found = get_free_buffer(params);
  if (!block_found) {
    // do garbage collection if memory pressure is high and there's enough memory in the pool
    if (params.has_memory_pressure && alloc_size < pool.available_size) {
      garbage_collect_cached_buffers(params);
    }

    block_found =
        // Attempt to allocate
        alloc_buffer(params) ||
        // Callbacks might release more memory (e.g., by forcing a GC in the host language), so
        // we can retry getting a free buffer from the pool before trying to alloc again.
        (trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
         get_free_buffer(params)) ||
        // Free enough available cached blocks to satisfy the alloc and retry the alloc.
        (release_available_cached_buffers(params) && alloc_buffer(params)) ||
        // Free all cached buffers and retry the alloc.
        (release_cached_buffers() && alloc_buffer(params));
  }

  BufferBlock* buffer_block = params.buffer_block;

  // the OOM could be triggered if:
  // 1- the High Watermark limit has been reached (if enabled)
  // 2- we ran out of device memory, or the memory fragmentation is so high that a contiguous
  //    chunk of the requested size couldn't be found.
  if (!block_found || !buffer_block) {
    if (m_high_watermark_ratio > 0.0) {
      TORCH_CHECK(
          false,
          "MPS backend out of memory (MPS allocated: ",
          format_size(m_total_allocated_memory),
          ", other allocations: ",
          format_size(current_allocated_size() - m_total_allocated_memory),
          ", max allowed: ",
          format_size(m_max_total_allowed_size),
          "). Tried to allocate ",
          format_size(alloc_size),
          " on ",
          ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
          " pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
    } else {
      TORCH_CHECK(false,
                  "MPS backend out of memory (MPS allocated: ",
                  format_size(m_total_allocated_memory),
                  ", other allocations: ",
                  format_size(current_allocated_size() - m_total_allocated_memory),
                  "). Tried to allocate ",
                  format_size(alloc_size),
                  " on ",
                  ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
                  " pool.");
    }
  }
  buffer_block->in_use = true;
  buffer_block->use_count++;
  m_current_allocated_memory += buffer_block->size;

  return buffer_block;
}
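// Example (hypothetical shell invocation) of the escape hatch mentioned in the OOM message above;
// disabling the high watermark removes the allocator-side cap and can destabilize the system:
//   PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 python train.py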
void MPSHeapAllocatorImpl::free_buffer(BufferBlock* buffer_block) {
  TORCH_INTERNAL_ASSERT(buffer_block->in_use);

  BufferPool& pool = *buffer_block->heap->pool;
  // Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into.
  TORCH_INTERNAL_ASSERT(pool.available_buffers.insert(buffer_block).second);
  pool.available_size += buffer_block->size;
  buffer_block->shape.clear(); // reset shape
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_current_allocated_memory >= buffer_block->size);
  m_current_allocated_memory -= buffer_block->size;
  if (buffer_block->event) {
    // returns the MPSEvent back to MPSEventPool
    buffer_block->event.reset(nullptr);
  }
  buffer_block->in_use = false;
}

BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(const void* ptr) {
  auto it = m_allocated_buffers.find(ptr);
  if (it == m_allocated_buffers.end()) {
    return nullptr;
  }
  return it->second;
}

bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) {
  HeapBlock* heap_block = buffer_block->heap;
  BufferPool& pool = *heap_block->pool;
  pool.allocated_size -= buffer_block->size;
  pool.available_size -= buffer_block->size;
  m_allocated_buffers.erase(buffer_block->buffer);
  pool.available_buffers.erase(buffer_block);
  pool.n_buffers--;
  // will re-insert later to keep the heaps list sorted based on the heap's new available size (if the heap isn't empty)
  pool.heaps.erase(heap_block);
  uint32_t retainCount = heap_block->releaseMTLBuffer(buffer_block->buffer);

  if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
              << " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
              << ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
              << ", gc#: " << buffer_block->gc_count << ")\n";
  }
  delete buffer_block;

  if (remove_empty_heap && heap_block->n_buffers == 0) {
    pool.heaps_pending_update.erase(heap_block);
    m_total_allocated_memory -= heap_block->size.total;
    retainCount = heap_block->releaseMTLHeap();
    if (m_debug_verbosity & DebugVerbosity::RELEASES) {
      std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                << " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
                << ")\n";
    }
    delete heap_block;
    return true;
  } else {
    pool.heaps.insert(heap_block);
    // if the heap wasn't released and its released buffer is still busy in the command buffer, the available
    // size of the heap cannot be updated yet, so we defer the update until the command buffer finishes.
    if (retainCount > 1) {
      pool.heaps_pending_update.insert(heap_block);
      m_mutex.unlock();
      m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
        std::lock_guard<std::recursive_mutex> lock(m_mutex);
        // check if the heap block still exists
        if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) {
          pool.heaps_pending_update.erase(heap_block);
          pool.heaps.erase(heap_block);
          heap_block->updateAvailableSize();
          pool.heaps.insert(heap_block);
        }
      });
      m_mutex.lock();
    }
  }
  return false;
}
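// As used throughout this file, a retainCount() greater than one is treated as "the GPU (an
// in-flight command buffer) may still reference this MTLBuffer/MTLHeap", while a retainCount()
// of one or less means only the allocator holds a reference and the memory can be recycled
// immediately; see get_free_buffer(), garbage_collect_cached_buffers() and freeInactiveBuffers().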
void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
  if (pool.available_buffers.empty()) {
    return;
  }
  if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
    std::cerr << "Releasing " << pool.available_buffers.size() << " buffers from "
              << ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
              << " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n";
  }
  auto it = pool.available_buffers.begin();
  while (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;
    ++it;
    release_buffer(buffer_block);
  }
}

bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params) {
  BufferPool& pool = *params.pool;

  if (pool.available_buffers.empty()) {
    return false;
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it == pool.available_buffers.end()) {
    size_t totalReleased = 0;
    --it;
    while (totalReleased < params.search_key.size) {
      auto cur = it;
      totalReleased += (*it)->size;
      if (it != pool.available_buffers.begin()) {
        --it;
        release_buffer(*cur);
      } else {
        release_buffer(*cur);
        break;
      }
    }
    if (totalReleased < params.search_key.size) {
      return false;
    }
  } else {
    release_buffer(*it);
  }
  return true;
}

bool MPSHeapAllocatorImpl::release_cached_buffers() {
  if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
    std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
              << ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
  }
  // before releasing the buffers make sure the command buffer has finished.
  // we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
  m_mutex.unlock();
  auto stream = getDefaultMPSStream();
  dispatch_sync(stream->queue(), ^() {
    stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  m_mutex.lock();
  // Free all cached blocks to system allocator
  for (const auto& poolIt : m_pools) {
    BufferPool& pool = *poolIt.second;
    release_buffers(pool);
  }
  return true;
}
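// release_cached_buffers() is also what backs emptyCache() further below; on the Python side this
// is typically reached through torch.mps.empty_cache() (an illustrative call path, not a contract):
//   torch.mps.empty_cache() -> IMPSAllocator::emptyCache() -> release_cached_buffers()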
void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
  // skip garbage collection if the memory pressure has already been relieved
  if (current_allocated_size() < m_low_watermark_limit) {
    return;
  }
  // attempt to collect garbage until we get below the low watermark limit
  const auto target_size = current_allocated_size() - m_low_watermark_limit;
  const BufferPool& pool = *params.pool;
  // calculate the total age of the free-able blocks. We'll use it later to get the average age threshold.
  double total_age = 0.0;
  unsigned int freeable_block_count = 0, freed_count = 0;
  size_t gc_reclaimed = 0;

  for (auto& b : pool.available_buffers) {
    if (b->retainCount() <= 1) {
      total_age += b->gc_count;
      ++freeable_block_count;
    }
  }
  if (freeable_block_count == 0) {
    return;
  }
  // repeat GC until the reclaimed size exceeds the target size.
  bool block_freed = true;
  while (gc_reclaimed < target_size && block_freed && freeable_block_count > 0) {
    // free blocks exceeding this age threshold first.
    double age_threshold = total_age / freeable_block_count;
    // stop iteration if we can no longer free a block.
    block_freed = false;
    // free blocks of > avg age. Stop garbage collection if we reach below the
    // low watermark limit since re-allocation or fragmentation could be costly.
    auto it = pool.available_buffers.begin();
    while (it != pool.available_buffers.end() && gc_reclaimed < target_size) {
      BufferBlock* buffer_block = *it++;
      if (buffer_block->gc_count >= age_threshold && buffer_block->retainCount() <= 1) {
        block_freed = true;
        gc_reclaimed += buffer_block->size;
        total_age -= buffer_block->gc_count;
        freeable_block_count--;
        freed_count++;
        release_buffer(buffer_block, !buffer_block->heap->is_split);
      }
    }
  }
  if (m_debug_verbosity & DebugVerbosity::RELEASES) {
    std::cerr << "Garbage collected " << freed_count << " buffers from large "
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << " pool (total reclaimed: " << format_size(gc_reclaimed)
              << ", #buffers: " << pool.available_buffers.size() << ")\n";
  }
}

// public interface to MPSAllocator
id<MTLBuffer> MPSHeapAllocatorImpl::malloc(size_t size, uint32_t usage) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = alloc_buffer_block(size, usage);
  return buffer_block ? buffer_block->buffer : nullptr;
}

bool MPSHeapAllocatorImpl::isSharedBuffer(const void* ptr) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  // it's OK for the buffer_block to not exist yet
  return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED);
}

id<MTLBuffer> MPSHeapAllocatorImpl::allocScalarBufferWithValue(void* value, size_t size) {
  BufferBlock* buffer_block = nullptr;
  {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);

    buffer_block = alloc_buffer_block(size, UsageFlags::SCALAR);
    if (!buffer_block) {
      return nullptr;
    }
    if (!buffer_block->cpu_ptr) {
      buffer_block->cpu_ptr = [buffer_block->buffer contents];
    }
  }
  // the buffer is out of the pool, so no mutex lock is needed
  memcpy(buffer_block->cpu_ptr, value, size);
  return buffer_block->buffer;
}

std::pair<const void*, uint32_t> MPSHeapAllocatorImpl::getSharedBufferPtr(const void* ptr) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  // return if the buffer was not allocated on MPSAllocator or isn't a shared buffer
  if (!buffer_block || !(buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
    return {nullptr, 0};
  }
  if (!buffer_block->cpu_ptr) {
    buffer_block->cpu_ptr = [buffer_block->buffer contents];
  }
  return {buffer_block->cpu_ptr, buffer_block->retainCount()};
}

bool MPSHeapAllocatorImpl::recordEvents(c10::ArrayRef<const void*> buffers) {
  bool recordedEvent = false;
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  for (const auto& buffer : buffers) {
    BufferBlock* buffer_block = get_allocated_buffer_block(buffer);
    // record only if the buffer was allocated on MPSAllocator and is a shared buffer
    if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
      if (!buffer_block->event) {
        buffer_block->event = m_event_pool->acquireEvent(false, nullptr);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer_block->event);
      }
      buffer_block->event->record(/*needsLock*/ false);
      recordedEvent = true;
    }
  }
  return recordedEvent;
}
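// recordEvents() above and waitForEvents() below are intended to be used as a pair on shared
// buffers. A hypothetical caller (illustrative sketch only; `shared_ptrs` is not a name from this
// file) would record after encoding GPU work that touches the buffers, then wait before reading
// the same memory from the CPU:
//   /* ... encode GPU work writing into the shared buffers ... */
//   allocator->recordEvents(shared_ptrs);
//   allocator->waitForEvents(shared_ptrs);  // CPU-side reads are safe once this returns true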
bool MPSHeapAllocatorImpl::waitForEvents(c10::ArrayRef<const void*> buffers) {
  std::vector<BufferBlock*> buffer_blocks;
  {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);
    for (const auto& buffer : buffers) {
      BufferBlock* buffer_block = get_allocated_buffer_block(buffer);
      // wait on the event only if the "shared" buffer was allocated on MPSAllocator and
      // actually needs waiting (based on its retainCount)
      if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED) && buffer_block->retainCount() > 1 &&
          buffer_block->event) {
        buffer_blocks.push_back(buffer_block);
      }
    }
  }
  bool waitedForEvent = false;

  for (const auto& buffer_block : buffer_blocks) {
    // check the retain count again as the previous wait might have released the buffer
    if (buffer_block->retainCount() > 1) {
      bool waitedOnCPU = buffer_block->event->synchronize();
      if (waitedOnCPU) {
        // after waiting, it's a good time to free some pending inactive buffers
        freeInactiveBuffers();
        waitedForEvent |= buffer_block->retainCount() <= 1;
      } else {
        // if one of the buffers wasn't recorded beforehand, we return without
        // continuing with the other buffers since its retainCount is still > 1
        waitedForEvent = false;
        break;
      }
    }
  }
  return waitedForEvent;
}

id_t MPSHeapAllocatorImpl::getBufferId(const void* ptr) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  return buffer_block ? buffer_block->buf_id : 0;
}

ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(const void* ptr) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  if (buffer_block) {
    return (ssize_t)buffer_block->requested_size;
  }
  // -1 indicates the passed buffer pointer wasn't found
  return -1;
}

void MPSHeapAllocatorImpl::setBufferShape(const void* ptr, const IntArrayRef& shape) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr);
  // note that the IntArrayRef doesn't own the underlying data, and the backing
  // memory for the shape data must persist as long as the buffer is in use.
  // So we need to copy it into a vector.
  buffer_block->shape = shape.vec();
}

IntArrayRef MPSHeapAllocatorImpl::getBufferShape(const void* ptr) {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  if (buffer_block && buffer_block->shape.size() > 0) {
    return IntArrayRef{buffer_block->shape};
  }
  return IntArrayRef();
}

void MPSHeapAllocatorImpl::free(void* ptr) {
  BufferBlock* buffer_block = nullptr;
  {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);

    buffer_block = get_allocated_buffer_block(ptr);
    TORCH_INTERNAL_ASSERT(buffer_block);
    const BufferPool& pool = *buffer_block->heap->pool;
    if (!(pool.usage & UsageFlags::SCALAR)) {
      free_buffer(buffer_block);
      return;
    }
  }
  // we sync the scalar pool manually with a completion handler at the time the buffer is
  // freed, i.e., when the MPSScalar instance goes out of scope
  m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
    std::lock_guard<std::recursive_mutex> lock(m_mutex);
    free_buffer(buffer_block);
  });
}

void MPSHeapAllocatorImpl::freeInactiveBuffers() {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);

  for (const auto& poolIt : m_pools) {
    BufferPool& pool = *poolIt.second;
    if (!pool.buffers_pending_free.empty()) {
      for (auto it = pool.buffers_pending_free.begin(), last = pool.buffers_pending_free.end(); it != last;) {
        BufferBlock* buffer_block = *it;
        if (buffer_block->retainCount() <= 1) {
          it = pool.buffers_pending_free.erase(it);
          free_buffer(buffer_block);
        } else {
          ++it;
        }
      }
    }
  }
}

void MPSHeapAllocatorImpl::emptyCache() {
  std::lock_guard<std::recursive_mutex> lock(m_mutex);
  release_cached_buffers();
}

ssize_t MPSHeapAllocatorImpl::getLowWatermarkValue() {
  // check if the low watermark limit is disabled
  if (m_low_watermark_ratio == 0.0) {
    return std::numeric_limits<ssize_t>::max();
  }
  // current_allocated_size could exceed m_low_watermark_limit (e.g., when swapping to disk)
  return std::max<ssize_t>(0, (ssize_t)(m_low_watermark_limit - current_allocated_size()) / 1048576L);
}
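// format_size() below renders the human-readable sizes used in the debug messages above.
// For instance, given the thresholds in the function: format_size(512) yields "512 bytes",
// format_size(1536) yields "1.50 KB", and format_size(3145728) yields "3.00 MB".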
inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
  std::ostringstream os;
  os.precision(2);
  os << std::fixed;
  if (size <= 1024UL) {
    os << size << " bytes";
  } else if (size <= 1048576UL) {
    os << ((float)size / 1024.0) << " KB";
  } else if (size <= 1073741824UL) {
    os << ((float)size / 1048576.0) << " MB";
  } else {
    os << ((float)size / 1073741824.0) << " GB";
  }
  return os.str();
}

} // namespace HeapAllocator

// Use "at::mps::GetMPSAllocator()" to acquire a handle to MPS Allocator
namespace {
HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
  static HeapAllocator::MPSHeapAllocatorImpl s_allocatorImpl;
  return s_allocatorImpl;
}
} // namespace

// MPS allocator struct to be registered with PyTorch
struct TORCH_API MPSAllocator final : public IMPSAllocator {
 public:
  explicit MPSAllocator(uint32_t Usage)
      : m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
    if (_getAllocImpl().getDebugVerbosity()) {
      if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
        std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
                  << " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
                  << " device memory of size "
                  << _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
      }
    }
  }

  ~MPSAllocator() override {
    _getAllocImpl().emptyCache();
  }
  DeleterFnPtr raw_deleter() const override {
    return &Delete;
  }

  DataPtr allocate(const size_t nbytes) override {
    __block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }

  // implementation of IMPSAllocator interface
  DataPtr allocScalarBufferWithValue(void* value, size_t size) const override {
    id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }
  std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const override {
    return _getAllocImpl().getSharedBufferPtr(ptr);
  }
  bool isSharedBuffer(const void* ptr) const override {
    return _getAllocImpl().isSharedBuffer(ptr);
  }
  bool isSharedStorageSupported() const override {
    return m_has_unified_memory;
  }
  void emptyCache() const override {
    _getAllocImpl().emptyCache();
  }
  void freeInactiveBuffers() const override {
    _getAllocImpl().freeInactiveBuffers();
  }
  ssize_t getUnalignedBufferSize(const void* ptr) const override {
    return _getAllocImpl().getUnalignedBufferSize(ptr);
  }
  id_t getBufferId(const void* ptr) const override {
    return _getAllocImpl().getBufferId(ptr);
  }
  IntArrayRef getBufferShape(const void* ptr) const override {
    return _getAllocImpl().getBufferShape(ptr);
  }
  void setBufferShape(const void* ptr, const IntArrayRef& shape) const override {
    _getAllocImpl().setBufferShape(ptr, shape);
  }
  size_t getTotalAllocatedMemory() const override {
    return _getAllocImpl().getTotalAllocatedMemory();
  }
  size_t getCurrentAllocatedMemory() const override {
    return _getAllocImpl().getCurrentAllocatedMemory();
  }
  size_t getDriverAllocatedMemory() const override {
    return _getAllocImpl().getDriverAllocatedMemory();
  }
  size_t getRecommendedMaxMemory() const override {
    return _getAllocImpl().getRecommendedMaxMemory();
  }
  ssize_t getLowWatermarkValue() const override {
    return _getAllocImpl().getLowWatermarkValue();
  }
  size_t getLowWatermarkLimit() const override {
    return _getAllocImpl().getLowWatermarkLimit();
  }
  size_t getHighWatermarkLimit() const override {
    return _getAllocImpl().getHighWatermarkLimit();
  }
  void setLowWatermarkRatio(double ratio) const override {
    _getAllocImpl().setLowWatermarkRatio(ratio);
  }
  void setHighWatermarkRatio(double ratio) const override {
    _getAllocImpl().setHighWatermarkRatio(ratio);
  }
  bool recordEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().recordEvents(buffers);
  }
  bool waitForEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().waitForEvents(buffers);
  }
  std::string formatSize(size_t size) const override {
    return _getAllocImpl().format_size(size);
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }

 private:
  bool m_has_unified_memory;
  uint32_t m_usage;

  static void Delete(void* ptr) {
    if (ptr) {
      _getAllocImpl().free(ptr);
    }
  }
};
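// Illustrative call-site sketch (hypothetical caller code, not part of this file): backend code
// typically acquires the allocator via the GetMPSAllocator() helper mentioned above and lets the
// returned DataPtr route deallocation back through Delete() -> MPSHeapAllocatorImpl::free():
//   auto* allocator = at::mps::GetMPSAllocator(/*useSharedAllocator=*/false);
//   at::DataPtr data = allocator->allocate(/*nbytes=*/4096);  // wraps an id<MTLBuffer>
//   // when `data` goes out of scope, its deleter returns the buffer to the private pool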
namespace {
MPSAllocator& _getSharedAllocator() {
  static MPSAllocator s_mps_shared_alloc(HeapAllocator::UsageFlags::SHARED);
  return s_mps_shared_alloc;
}

MPSAllocator& _getPrivateAllocator() {
  static MPSAllocator s_mps_private_alloc(HeapAllocator::UsageFlags::PRIVATE);
  return s_mps_private_alloc;
}
} // anonymous namespace

IMPSAllocator* getIMPSAllocator(bool sharedAllocator) {
  if (!sharedAllocator) {
    return &_getPrivateAllocator();
  }
  auto& sa = _getSharedAllocator();
  if (sa.isSharedStorageSupported()) {
    return &sa;
  }
  return nullptr;
}

// torch.is_pinned() implementation
// Pinned memory is helpful on Apple Silicon Macs with unified memory, since we can use the
// shared storage mode for MTLBuffer allocations and avoid extra copies during data loading.
bool isMPSPinnedPtr(const void* data) {
  return at::mps::_getSharedAllocator().isSharedBuffer(data);
}

} // namespace at::mps