//  Copyright © 2022 Apple Inc.

#include <ATen/CPUFunctions.h>
#include <ATen/EmptyTensor.h>
#include <ATen/mps/MPSAllocator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>

#include <iostream>

namespace at::mps {

C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

namespace HeapAllocator {

uint64_t BufferBlock::buffer_counter = 0;
uint64_t HeapBlock::heap_counter = 0;

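// Initializes the buffer pools and reads the allocator's configuration from the environment:
// PYTORCH_DEBUG_MPS_ALLOCATOR (debug verbosity bit-mask, see DebugVerbosity),
// PYTORCH_MPS_HIGH_WATERMARK_RATIO and PYTORCH_MPS_LOW_WATERMARK_RATIO
// (allocation limits expressed as a fraction of max_device_size()).
// Illustrative shell invocation (script name and ratio are hypothetical):
//   PYTORCH_DEBUG_MPS_ALLOCATOR=1 PYTORCH_MPS_HIGH_WATERMARK_RATIO=1.4 python train.py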
void MPSHeapAllocatorImpl::init_allocator() {
  init_buffer_pools();

  // debug verbosity flags (see DebugVerbosity enum)
  static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
  m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;

  static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
  const double high_watermark_ratio =
      high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
  setHighWatermarkRatio(high_watermark_ratio);

  const double default_low_watermark_ratio =
      m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
  static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
  const double low_watermark_ratio =
      low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
  setLowWatermarkRatio(low_watermark_ratio);
}

void MPSHeapAllocatorImpl::init_buffer_pools() {
  // using a container for pools to simplify iterating over them
  // Pool of large buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of large buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode used to allocate and copy Scalars
  // from CPU to Metal buffers (see allocScalarBufferWithValue()).
  // No hazard tracking is required for the Scalar pool (it is synchronized manually).
  m_pools.emplace(BufferPool::Kind::SCALAR,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::SCALAR));
}

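// Selects the buffer pool for an allocation: scalar allocations go to the dedicated
// Scalar pool; on unified-memory devices, requests up to kMaxScalarAlloc go to the
// shared small pool; other requests whose aligned size fits kMaxSmallAlloc go to the
// small pools; everything larger goes to the large pools (shared or private,
// depending on the usage flags).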
BufferPool& MPSHeapAllocatorImpl::get_pool(size_t requested_size, size_t aligned_size, uint32_t usage) {
  BufferPool::Kind poolKind;

  if (usage & UsageFlags::SCALAR) {
    poolKind = BufferPool::Kind::SCALAR;
  } else if (requested_size <= kMaxScalarAlloc && m_device.hasUnifiedMemory) {
    poolKind = BufferPool::Kind::SHARED_SMALL;
  } else if (aligned_size <= kMaxSmallAlloc) {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_SMALL : BufferPool::Kind::PRIVATE_SMALL;
  } else {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_LARGE : BufferPool::Kind::PRIVATE_LARGE;
  }
  return *m_pools[poolKind];
}

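// Computes the actual allocation size for a request by querying Metal for the heap's
// buffer size/alignment requirements and rounding the size up to that alignment.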
size_t MPSHeapAllocatorImpl::get_allocation_size(size_t size, uint32_t usage) const {
  MTLSizeAndAlign sizeAlign = [m_device heapBufferSizeAndAlignWithLength:size options:HeapBlock::getOptions(usage)];
  return BufferBlock::alignUp(sizeAlign.size, sizeAlign.align);
}

void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
  TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
  m_max_total_allowed_size =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "\nHigh watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
  }
  m_high_watermark_ratio = ratio;
}

void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
  // used as the upper bound when validating the low watermark ratio
  const double high_watermark_limit =
      m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
  TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
  // we use this to detect if there's memory pressure
  m_low_watermark_limit =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "Low watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
  }
  m_low_watermark_ratio = ratio;
}

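// Returns a heap with enough available space for the request. The pool's heaps are
// ordered by available size, so lower_bound() finds the smallest heap that fits; if
// none exists, a new heap is created. An existing heap is removed from the set here
// and re-inserted by the caller after a buffer is carved out of it, so the ordering
// reflects its updated available size.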
HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
  BufferPool& pool = *params.pool;
  HeapBlock* heap_block = nullptr;
  HeapBlock search_key(params.size());

  auto it = pool.heaps.lower_bound(&search_key);
  if (it == pool.heaps.end()) {
    heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
    if (heap_block) {
      m_total_allocated_memory += heap_block->size.total;
      if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
        std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private") << " heap #"
                  << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                  << " (#heaps: " << (pool.heaps.size() + 1)
                  << ", current allocated: " << format_size(current_allocated_size()) << ")\n";
      }
    }
  } else {
    heap_block = *it;
    // remove the heap from the set and re-insert it later, after a buffer is created on it.
    // this keeps the heaps ordered by their new available sizes
    pool.heaps.erase(it);
  }
  return heap_block;
}

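// Allocates a fresh MTLBuffer from a heap (creating a new heap if needed), registers
// it in m_allocated_buffers, and updates the pool's accounting. Returns false if the
// high watermark limit would be exceeded or no heap could be obtained, signaling the
// caller to release cached buffers and retry.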
bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
  if (m_max_total_allowed_size != std::numeric_limits<size_t>::max() &&
      current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  HeapBlock* heap = get_free_heap(params);
  if (!heap) {
    return false; // this will cause the pool's cached buffers to be released to free up memory
  }
  BufferPool& pool = *params.pool;

  id<MTLBuffer> buffer = heap->newMTLBuffer(params.size(), pool.usage);
  // this should never happen as the backing memory (i.e., heap) was allocated successfully.
  TORCH_INTERNAL_ASSERT(buffer);
  // re-insert the heap after a buffer was created on it to update the ordering of the heaps set
  pool.heaps.insert(heap);
  params.buffer_block = new BufferBlock(params.size(), params.requested_size, buffer, heap);
  m_allocated_buffers[params.buffer_block->buffer] = params.buffer_block;
  pool.allocated_size += params.size();
  pool.n_buffers++;

  if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
              << params.buffer_block->buffer << " from heap #" << heap->heap_id
              << " (requested: " << format_size(params.requested_size)
              << ", heap: " << format_size(heap->size.available) << ", total: " << format_size(m_total_allocated_memory)
              << ")\n";
  }
  return true;
}

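// Tries to satisfy the request with a cached buffer from the pool's available list.
// A buffer that is not oversized is reused directly; an oversized buffer is either
// skipped in favor of sub-allocating from an existing heap, released to make room
// within its heap, or reused as-is when there is no memory pressure. Also ages the
// large-pool buffers used by the garbage-collection heuristic.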
bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
  // this helps to monitor "implicit" allocations from the MPS backend and to prevent OOM and system failure.
  if (m_high_watermark_ratio > 0.0 && current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  BufferPool& pool = *params.pool;
  // track buffer reuse intervals only on the large pools when the low watermark limit is enabled.
  if (m_low_watermark_ratio > 0.0 && !(pool.usage & UsageFlags::SMALL)) {
    for (auto& b : pool.available_buffers) {
      ++b->gc_count;
    }
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;

    // the logic in here is simple: keep reusing the existing heaps' capacity as long as possible (by splitting
    // or releasing oversized buffers, if required), and avoid 'new' heap allocations as much as possible.
    if (buffer_block->size <= params.size() + kLargeHeap) {
      // return the existing buffer if it already fits the requested size (i.e., is not oversized)
      params.buffer_block = buffer_block;
    } else {
      HeapBlock search_key(params.size());
      // if there's an 'existing' heap with enough capacity, then don't
      // return the oversized buffer and sub-allocate from that existing heap instead.
      if (pool.heaps.lower_bound(&search_key) != pool.heaps.end()) {
        params.buffer_block = nullptr;
      } else if (buffer_block->retainCount() <= 1) {
        // otherwise, if the buffer is releasable immediately, make room by releasing it and
        // reuse the freed space within its heap container for the new, smaller buffer allocation
        release_buffer(buffer_block, false);
        // this will skip unnecessary garbage collection as we'll reuse the newly released space
        params.has_memory_pressure = false;
      } else if (params.has_memory_pressure) {
        // the oversized buffer is busy and not reusable at the moment, so release it (and potentially its heap
        // container) in the allocator; ARC will free its backing memory later when the busy command buffer finishes.
        release_buffer(buffer_block, true);
      } else {
        // only if there's no memory pressure do we reuse the oversized buffer
        params.buffer_block = buffer_block;
      }
    }
  }

  if (!params.buffer_block) {
    return false; // this will make the allocator allocate a new buffer
  }
  pool.available_buffers.erase(params.buffer_block);
  params.buffer_block->requested_size = params.requested_size;
  params.buffer_block->gc_count = 0;
  pool.available_size -= params.buffer_block->size;

  if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
              << params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
              << ", use#: " << params.buffer_block->use_count + 1 << ", retain#: " << params.buffer_block->retainCount()
              << ")\n";
  }
  return true;
}

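// Central allocation routine: picks the target pool, then tries (in this order) a
// cached buffer, garbage collection under memory pressure, a fresh allocation, the
// registered allocator callbacks, releasing available cached buffers, and finally
// releasing all cached buffers. Throws an OOM error if all of these fail.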
BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usage) {
  TORCH_CHECK(size < m_max_buffer_size, "Invalid buffer size: ", format_size(size));

  size_t alloc_size = get_allocation_size(size, usage);
  auto& pool = get_pool(size, alloc_size, usage);
  AllocParams params(alloc_size, size, &pool);
  // we only care about memory pressure when allocating large buffers and the
  // low watermark limit has been reached
  params.has_memory_pressure = !(pool.usage & UsageFlags::SMALL) && getLowWatermarkValue() <= 0;
  params.has_unified_memory = m_device.hasUnifiedMemory;

  // first, try to get a block from the existing pool.
  bool block_found = get_free_buffer(params);
  if (!block_found) {
    // do garbage collection if memory pressure is high and there's enough memory in the pool
    if (params.has_memory_pressure && alloc_size < pool.available_size) {
      garbage_collect_cached_buffers(params);
    }

    block_found =
        // Attempt to allocate
        alloc_buffer(params) ||
        // Callbacks might release more memory (e.g., by forcing a GC in the host language), so
        // we can retry getting a free buffer in the pool before trying to allocate again.
        (trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
         get_free_buffer(params)) ||
        // Free enough available cached blocks to satisfy the allocation and retry.
        (release_available_cached_buffers(params) && alloc_buffer(params)) ||
        // Free all cached buffers and retry the allocation.
        (release_cached_buffers() && alloc_buffer(params));
  }

  BufferBlock* buffer_block = params.buffer_block;

  // an OOM is triggered if:
  //   1- the high watermark limit has been reached (if enabled), or
  //   2- we ran out of device memory, or memory fragmentation is so high that a contiguous
  //      chunk of the requested size couldn't be found.
  if (!block_found || !buffer_block) {
    if (m_high_watermark_ratio > 0.0) {
      TORCH_CHECK(
          false,
          "MPS backend out of memory (MPS allocated: ",
          format_size(m_total_allocated_memory),
          ", other allocations: ",
          format_size(current_allocated_size() - m_total_allocated_memory),
          ", max allowed: ",
          format_size(m_max_total_allowed_size),
          "). Tried to allocate ",
          format_size(alloc_size),
          " on ",
          ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
          " pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
    } else {
      TORCH_CHECK(false,
                  "MPS backend out of memory (MPS allocated: ",
                  format_size(m_total_allocated_memory),
                  ", other allocations: ",
                  format_size(current_allocated_size() - m_total_allocated_memory),
                  "). Tried to allocate ",
                  format_size(alloc_size),
                  " on ",
                  ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
                  " pool.");
    }
  }
  buffer_block->in_use = true;
  buffer_block->use_count++;
  m_current_allocated_memory += buffer_block->size;

  return buffer_block;
}

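// Returns an in-use buffer to its pool's available list so it can be reused later;
// the underlying MTLBuffer stays alive and is only released lazily by
// release_buffer()/release_cached_buffers().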
void MPSHeapAllocatorImpl::free_buffer(BufferBlock* buffer_block) {
  TORCH_INTERNAL_ASSERT(buffer_block->in_use);

  BufferPool& pool = *buffer_block->heap->pool;
  // Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into.
  TORCH_INTERNAL_ASSERT(pool.available_buffers.insert(buffer_block).second);
  pool.available_size += buffer_block->size;
  buffer_block->shape.clear(); // reset shape
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_current_allocated_memory >= buffer_block->size);
  m_current_allocated_memory -= buffer_block->size;
  if (buffer_block->event) {
    // returns the MPSEvent back to MPSEventPool
    buffer_block->event.reset(nullptr);
  }
  buffer_block->in_use = false;
}

BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(const void* ptr) {
  auto it = m_allocated_buffers.find(ptr);
  if (it == m_allocated_buffers.end()) {
    return nullptr;
  }
  return it->second;
}

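// Releases a cached buffer's MTLBuffer back to its heap and, if requested, releases
// the heap itself once it becomes empty (returning true in that case). If the
// released buffer is still busy in a command buffer, updating the heap's available
// size is deferred to a completion handler on the stream.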
bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) {
  HeapBlock* heap_block = buffer_block->heap;
  BufferPool& pool = *heap_block->pool;
  pool.allocated_size -= buffer_block->size;
  pool.available_size -= buffer_block->size;
  m_allocated_buffers.erase(buffer_block->buffer);
  pool.available_buffers.erase(buffer_block);
  pool.n_buffers--;
  // the heap is re-inserted later to keep the heaps set sorted by its new available size (if the heap isn't empty)
  pool.heaps.erase(heap_block);
  uint32_t retainCount = heap_block->releaseMTLBuffer(buffer_block->buffer);

  if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
              << " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
              << ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
              << ", gc#: " << buffer_block->gc_count << ")\n";
  }
  delete buffer_block;

  if (remove_empty_heap && heap_block->n_buffers == 0) {
    pool.heaps_pending_update.erase(heap_block);
    m_total_allocated_memory -= heap_block->size.total;
    retainCount = heap_block->releaseMTLHeap();
    if (m_debug_verbosity & DebugVerbosity::RELEASES) {
      std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                << " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
                << ")\n";
    }
    delete heap_block;
    return true;
  } else {
    pool.heaps.insert(heap_block);
    // if the heap wasn't released and its released buffer is still busy in a command buffer, the heap's available
    // size cannot be updated yet, so defer the update until the command buffer finishes.
    if (retainCount > 1) {
      pool.heaps_pending_update.insert(heap_block);
      m_mutex.unlock();
      m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
        std::lock_guard<std::recursive_mutex> lock(m_mutex);
        // check if the heap block still exists
        if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) {
          pool.heaps_pending_update.erase(heap_block);
          pool.heaps.erase(heap_block);
          heap_block->updateAvailableSize();
          pool.heaps.insert(heap_block);
        }
      });
      m_mutex.lock();
    }
  }
  return false;
}

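// Releases every cached (available) buffer of the given pool.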
void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
  if (pool.available_buffers.empty()) {
    return;
  }
  if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
    std::cerr << "Releasing " << pool.available_buffers.size() << " buffers from "
              << ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
              << " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers << ")\n";
  }
  auto it = pool.available_buffers.begin();
  while (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;
    ++it;
    release_buffer(buffer_block);
  }
}

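// Frees enough cached buffers from the pool to satisfy the pending request: either a
// single cached buffer at least as large as the request or, when no such buffer
// exists, the largest cached buffers (walking downwards) until the total released
// size covers the request. Returns false if that isn't possible.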
bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params) {
  BufferPool& pool = *params.pool;

  if (pool.available_buffers.empty()) {
    return false;
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it == pool.available_buffers.end()) {
    size_t totalReleased = 0;
    --it;
    while (totalReleased < params.search_key.size) {
      auto cur = it;
      totalReleased += (*it)->size;
      if (it != pool.available_buffers.begin()) {
        --it;
        release_buffer(*cur);
      } else {
        release_buffer(*cur);
        break;
      }
    }
    if (totalReleased < params.search_key.size) {
      return false;
    }
  } else {
    release_buffer(*it);
  }
  return true;
}

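// Waits for the current command buffer to finish, then releases all cached buffers
// in every pool back to the system (used by emptyCache() and as a last-resort
// fallback in alloc_buffer_block()).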
bool MPSHeapAllocatorImpl::release_cached_buffers() {
  if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
    std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
              << ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory) << ")\n";
  }
  // before releasing the buffers, make sure the command buffer has finished.
  // we need to release the lock temporarily, as synchronizing may deadlock with completion handlers.
  m_mutex.unlock();
  auto stream = getDefaultMPSStream();
  dispatch_sync(stream->queue(), ^() {
    stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  m_mutex.lock();
  // Free all cached blocks back to the system allocator
  for (const auto& poolIt : m_pools) {
    BufferPool& pool = *poolIt.second;
    release_buffers(pool);
  }
  return true;
}

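// Age-based garbage collection of the pool's cached buffers, triggered once the low
// watermark limit has been crossed: repeatedly frees releasable buffers whose
// gc_count is at or above the running average age, until enough memory is reclaimed
// to get back below the limit (or no more blocks can be freed).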
void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
  // skip garbage collection if memory pressure has already been relieved
  if (current_allocated_size() < m_low_watermark_limit) {
    return;
  }
  // attempt to collect garbage until we get back below the low watermark limit
  const auto target_size = current_allocated_size() - m_low_watermark_limit;
  const BufferPool& pool = *params.pool;
  // calculate the total age of the free-able blocks. We'll use it later to get the average age threshold.
  double total_age = 0.0;
  unsigned int freeable_block_count = 0, freed_count = 0;
  size_t gc_reclaimed = 0;

  for (auto& b : pool.available_buffers) {
    if (b->retainCount() <= 1) {
      total_age += b->gc_count;
      ++freeable_block_count;
    }
  }
  if (freeable_block_count == 0) {
    return;
  }
  // repeat GC until the reclaimed size exceeds the target size.
  bool block_freed = true;
  while (gc_reclaimed < target_size && block_freed && freeable_block_count > 0) {
    // free blocks exceeding this age threshold first.
    double age_threshold = total_age / freeable_block_count;
    // stop iterating if we can no longer free a block.
    block_freed = false;
    // free blocks whose age is at or above the average. Stop garbage collection once we get
    // below the low watermark limit, since re-allocation or fragmentation could be costly.
    auto it = pool.available_buffers.begin();
    while (it != pool.available_buffers.end() && gc_reclaimed < target_size) {
      BufferBlock* buffer_block = *it++;
      if (buffer_block->gc_count >= age_threshold && buffer_block->retainCount() <= 1) {
        block_freed = true;
        gc_reclaimed += buffer_block->size;
        total_age -= buffer_block->gc_count;
        freeable_block_count--;
        freed_count++;
        release_buffer(buffer_block, !buffer_block->heap->is_split);
      }
    }
  }
  if (m_debug_verbosity & DebugVerbosity::RELEASES) {
    std::cerr << "Garbage collected " << freed_count << " buffers from large "
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << " pool (total reclaimed: " << format_size(gc_reclaimed)
              << ", #buffers: " << pool.available_buffers.size() << ")\n";
  }
}
505
506// public interface to MPSAllocator
507id<MTLBuffer> MPSHeapAllocatorImpl::malloc(size_t size, uint32_t usage) {
508  std::lock_guard<std::recursive_mutex> lock(m_mutex);
509
510  BufferBlock* buffer_block = alloc_buffer_block(size, usage);
511  return buffer_block ? buffer_block->buffer : nullptr;
512}
513
514bool MPSHeapAllocatorImpl::isSharedBuffer(const void* ptr) {
515  std::lock_guard<std::recursive_mutex> lock(m_mutex);
516
517  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
518  // it's OK for the buffer_block to not exist yet
519  return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED);
520}
521
522id<MTLBuffer> MPSHeapAllocatorImpl::allocScalarBufferWithValue(void* value, size_t size) {
523  BufferBlock* buffer_block = nullptr;
524  {
525    std::lock_guard<std::recursive_mutex> lock(m_mutex);
526
527    buffer_block = alloc_buffer_block(size, UsageFlags::SCALAR);
528    if (!buffer_block) {
529      return nullptr;
530    }
531    if (!buffer_block->cpu_ptr) {
532      buffer_block->cpu_ptr = [buffer_block->buffer contents];
533    }
534  }
535  // buffer is out of the pool, so no mutex lock is needed
536  memcpy(buffer_block->cpu_ptr, value, size);
537  return buffer_block->buffer;
538}
539
540std::pair<const void*, uint32_t> MPSHeapAllocatorImpl::getSharedBufferPtr(const void* ptr) {
541  std::lock_guard<std::recursive_mutex> lock(m_mutex);
542
543  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
544  // return if buffer was not allocated on MPSAllocator or isn't a Shared buffer
545  if (!buffer_block || !(buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
546    return {nullptr, 0};
547  }
548  if (!buffer_block->cpu_ptr) {
549    buffer_block->cpu_ptr = [buffer_block->buffer contents];
550  }
551  return {buffer_block->cpu_ptr, buffer_block->retainCount()};
552}
553
554bool MPSHeapAllocatorImpl::recordEvents(c10::ArrayRef<const void*> buffers) {
555  bool recordedEvent = false;
556  std::lock_guard<std::recursive_mutex> lock(m_mutex);
557
558  for (const auto& buffer : buffers) {
559    BufferBlock* buffer_block = get_allocated_buffer_block(buffer);
560    // return if buffer was not allocated on MPSAllocator or isn't a Shared buffer
561    if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
562      if (!buffer_block->event) {
563        buffer_block->event = m_event_pool->acquireEvent(false, nullptr);
564        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer_block->event);
565      }
566      buffer_block->event->record(/*needsLock*/ false);
567      recordedEvent = true;
568    }
569  }
570  return recordedEvent;
571}
572
573bool MPSHeapAllocatorImpl::waitForEvents(c10::ArrayRef<const void*> buffers) {
574  std::vector<BufferBlock*> buffer_blocks;
575  {
576    std::lock_guard<std::recursive_mutex> lock(m_mutex);
577    for (const auto& buffer : buffers) {
578      BufferBlock* buffer_block = get_allocated_buffer_block(buffer);
579      // wait on event if "shared" buffer was allocated on MPSAllocator and
580      // or actually needs waiting (based on retainCount)
581      if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED) && buffer_block->retainCount() > 1 &&
582          buffer_block->event) {
583        buffer_blocks.push_back(buffer_block);
584      }
585    }
586  }
587  bool waitedForEvent = false;
588
589  for (const auto& buffer_block : buffer_blocks) {
590    // check for retain count again as the previous wait might have released the buffer
591    if (buffer_block->retainCount() > 1) {
592      bool waitedOnCPU = buffer_block->event->synchronize();
593      if (waitedOnCPU) {
594        // after waiting, it's a good time to free some pending inactive buffers
595        freeInactiveBuffers();
596        waitedForEvent |= buffer_block->retainCount() <= 1;
597      } else {
598        // even if one of the buffers weren't recorded beforehand, we return
599        // without continuing with other buffers since retainCount > 1
600        waitedForEvent = false;
601        break;
602      }
603    }
604  }
605  return waitedForEvent;
606}
607
608id_t MPSHeapAllocatorImpl::getBufferId(const void* ptr) {
609  std::lock_guard<std::recursive_mutex> lock(m_mutex);
610
611  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
612  return buffer_block ? buffer_block->buf_id : 0;
613}
614
615ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(const void* ptr) {
616  std::lock_guard<std::recursive_mutex> lock(m_mutex);
617
618  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
619  if (buffer_block) {
620    return (ssize_t)buffer_block->requested_size;
621  }
622  // -1 indicates the passed buffer pointer wasn't found
623  return -1;
624}
625
626void MPSHeapAllocatorImpl::setBufferShape(const void* ptr, const IntArrayRef& shape) {
627  std::lock_guard<std::recursive_mutex> lock(m_mutex);
628
629  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
630  TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr);
631  // note that the IntArrayRef doesn't own the underlying data, and the backing
632  // memory for shape data must persist as long as the buffer is in use.
633  // So we need to copy to vector.
634  buffer_block->shape = shape.vec();
635}
636
637IntArrayRef MPSHeapAllocatorImpl::getBufferShape(const void* ptr) {
638  std::lock_guard<std::recursive_mutex> lock(m_mutex);
639
640  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
641  if (buffer_block && buffer_block->shape.size() > 0) {
642    return IntArrayRef{buffer_block->shape};
643  }
644  return IntArrayRef();
645}
646
647void MPSHeapAllocatorImpl::free(void* ptr) {
648  BufferBlock* buffer_block = nullptr;
649  {
650    std::lock_guard<std::recursive_mutex> lock(m_mutex);
651
652    buffer_block = get_allocated_buffer_block(ptr);
653    TORCH_INTERNAL_ASSERT(buffer_block);
654    const BufferPool& pool = *buffer_block->heap->pool;
655    if (!(pool.usage & UsageFlags::SCALAR)) {
656      free_buffer(buffer_block);
657      return;
658    }
659  }
660  // we sync the scalar pool manually with completion handler at the time buffer is
661  // freed when the MPSScalar instance goes our of scope
662  m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
663    std::lock_guard<std::recursive_mutex> lock(m_mutex);
664    free_buffer(buffer_block);
665  });
666}
667
668void MPSHeapAllocatorImpl::freeInactiveBuffers() {
669  std::lock_guard<std::recursive_mutex> lock(m_mutex);
670
671  for (const auto& poolIt : m_pools) {
672    BufferPool& pool = *poolIt.second;
673    if (!pool.buffers_pending_free.empty()) {
674      for (auto it = pool.buffers_pending_free.begin(), last = pool.buffers_pending_free.end(); it != last;) {
675        BufferBlock* buffer_block = *it;
676        if (buffer_block->retainCount() <= 1) {
677          it = pool.buffers_pending_free.erase(it);
678          free_buffer(buffer_block);
679        } else {
680          ++it;
681        }
682      }
683    }
684  }
685}
686
687void MPSHeapAllocatorImpl::emptyCache() {
688  std::lock_guard<std::recursive_mutex> lock(m_mutex);
689  release_cached_buffers();
690}
691
692ssize_t MPSHeapAllocatorImpl::getLowWatermarkValue() {
693  // check if low watermark limit is disabled
694  if (m_low_watermark_ratio == 0.0) {
695    return std::numeric_limits<ssize_t>::max();
696  }
697  // current_allocated_size could exceed m_low_watermark_limit (e.g., when swapping to disk)
698  return std::max<ssize_t>(0, (ssize_t)(m_low_watermark_limit - current_allocated_size()) / 1048576L);
699}
700
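// Formats a byte count as a human-readable string (bytes, KB, MB or GB with two
// decimal places) for the allocator's debug and error messages.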
inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
  std::ostringstream os;
  os.precision(2);
  os << std::fixed;
  if (size <= 1024UL) {
    os << size << " bytes";
  } else if (size <= 1048576UL) {
    os << ((float)size / 1024.0) << " KB";
  } else if (size <= 1073741824UL) {
    os << ((float)size / 1048576.0) << " MB";
  } else {
    os << ((float)size / 1073741824.0) << " GB";
  }
  return os.str();
}

} // namespace HeapAllocator

// Use "at::mps::GetMPSAllocator()" to acquire a handle to the MPS Allocator
namespace {
HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
  static HeapAllocator::MPSHeapAllocatorImpl s_allocatorImpl;
  return s_allocatorImpl;
}
} // namespace

// MPS allocator struct to be registered with PyTorch
struct TORCH_API MPSAllocator final : public IMPSAllocator {
 public:
  explicit MPSAllocator(uint32_t Usage)
      : m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
    if (_getAllocImpl().getDebugVerbosity()) {
      if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
        std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
                  << " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
                  << " device memory of size "
                  << _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
      }
    }
  }

  ~MPSAllocator() override {
    _getAllocImpl().emptyCache();
  }
  DeleterFnPtr raw_deleter() const override {
    return &Delete;
  }

  DataPtr allocate(const size_t nbytes) override {
    __block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }

  // implementation of IMPSAllocator interface
  DataPtr allocScalarBufferWithValue(void* value, size_t size) const override {
    id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }
  std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const override {
    return _getAllocImpl().getSharedBufferPtr(ptr);
  }
  bool isSharedBuffer(const void* ptr) const override {
    return _getAllocImpl().isSharedBuffer(ptr);
  }
  bool isSharedStorageSupported() const override {
    return m_has_unified_memory;
  }
  void emptyCache() const override {
    _getAllocImpl().emptyCache();
  }
  void freeInactiveBuffers() const override {
    _getAllocImpl().freeInactiveBuffers();
  }
  ssize_t getUnalignedBufferSize(const void* ptr) const override {
    return _getAllocImpl().getUnalignedBufferSize(ptr);
  }
  id_t getBufferId(const void* ptr) const override {
    return _getAllocImpl().getBufferId(ptr);
  }
  IntArrayRef getBufferShape(const void* ptr) const override {
    return _getAllocImpl().getBufferShape(ptr);
  }
  void setBufferShape(const void* ptr, const IntArrayRef& shape) const override {
    _getAllocImpl().setBufferShape(ptr, shape);
  }
  size_t getTotalAllocatedMemory() const override {
    return _getAllocImpl().getTotalAllocatedMemory();
  }
  size_t getCurrentAllocatedMemory() const override {
    return _getAllocImpl().getCurrentAllocatedMemory();
  }
  size_t getDriverAllocatedMemory() const override {
    return _getAllocImpl().getDriverAllocatedMemory();
  }
  size_t getRecommendedMaxMemory() const override {
    return _getAllocImpl().getRecommendedMaxMemory();
  }
  ssize_t getLowWatermarkValue() const override {
    return _getAllocImpl().getLowWatermarkValue();
  }
  size_t getLowWatermarkLimit() const override {
    return _getAllocImpl().getLowWatermarkLimit();
  }
  size_t getHighWatermarkLimit() const override {
    return _getAllocImpl().getHighWatermarkLimit();
  }
  void setLowWatermarkRatio(double ratio) const override {
    _getAllocImpl().setLowWatermarkRatio(ratio);
  }
  void setHighWatermarkRatio(double ratio) const override {
    _getAllocImpl().setHighWatermarkRatio(ratio);
  }
  bool recordEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().recordEvents(buffers);
  }
  bool waitForEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().waitForEvents(buffers);
  }
  std::string formatSize(size_t size) const override {
    return _getAllocImpl().format_size(size);
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }

 private:
  bool m_has_unified_memory;
  uint32_t m_usage;

  static void Delete(void* ptr) {
    if (ptr) {
      _getAllocImpl().free(ptr);
    }
  }
};

namespace {
MPSAllocator& _getSharedAllocator() {
  static MPSAllocator s_mps_shared_alloc(HeapAllocator::UsageFlags::SHARED);
  return s_mps_shared_alloc;
}

MPSAllocator& _getPrivateAllocator() {
  static MPSAllocator s_mps_private_alloc(HeapAllocator::UsageFlags::PRIVATE);
  return s_mps_private_alloc;
}
} // anonymous namespace

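// Returns the private-mode allocator by default, or the shared-mode allocator when
// sharedAllocator is true and shared storage (unified memory) is supported; returns
// nullptr if a shared allocator is requested on a device without unified memory.
// Hypothetical usage sketch from a caller in the MPS backend:
//   IMPSAllocator* alloc = at::mps::getIMPSAllocator(/*sharedAllocator=*/true);
//   if (alloc) {
//     at::DataPtr data = alloc->allocate(nbytes); // freed via the allocator's deleter
//   }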
IMPSAllocator* getIMPSAllocator(bool sharedAllocator) {
  if (!sharedAllocator) {
    return &_getPrivateAllocator();
  }
  auto& sa = _getSharedAllocator();
  if (sa.isSharedStorageSupported()) {
    return &sa;
  }
  return nullptr;
}

// torch.is_pinned() implementation
// Pinned memory is helpful on Apple Silicon Macs with unified memory, since we can
// use SharedStorageMode for MTLBuffer allocations. This avoids extra copies during
// data-loading operations.
bool isMPSPinnedPtr(const void* data) {
  return at::mps::_getSharedAllocator().isSharedBuffer(data);
}

} // namespace at::mps