xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
//  Copyright © 2022 Apple Inc.

#pragma once

#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>

#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>

// This implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve the performance of buffer allocation.
// Do not include this header. Use MPSAllocatorInterface.h instead.
// TODO: Unify the logic with CUDACachingAllocator and remove redundant code.
namespace at::mps::HeapAllocator {

static const size_t kMaxSmallAlloc = MB(1);    // largest "small" allocation is 1 MiB
static const size_t kMinLargeAlloc = MB(10);   // allocations between 1 and 10 MiB may use kLargeHeap
static const size_t kRoundLarge    = MB(2);    // round up large allocations to 2 MiB
static const size_t kSmallHeap     = MB(8);    // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap     = MB(32);   // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD   = MB(128);  // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU   = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
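
// Illustration of the buckets above (not part of the interface): a 512 KiB request counts
// as a "small" allocation (<= kMaxSmallAlloc) and would be served from an 8 MiB kSmallHeap;
// a 4 MiB request sits between kMaxSmallAlloc and kMinLargeAlloc and may be packed into a
// 32 MiB kLargeHeap; larger requests may land in an "extra large" heap or get a dedicated
// heap rounded up to a multiple of kRoundLarge (see HeapBlock::createHeapBlock below for
// the exact selection logic).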

// buffer pools could be customized with a combination of usage flags
enum UsageFlags : uint32_t {
  PRIVATE = 0,
  SMALL   = (1 << 0), // small heaps have sizes of kSmallHeap, and large ones kLargeHeap
  SHARED  = (1 << 1), // shared pools allocated on devices with unified memory; otherwise, private between host/device
  MANAGED = (1 << 2), // managed storage mode
  HAZARD  = (1 << 3), // enables Automatic Hazard Tracking for the resources allocated on the pool
  SCALAR  = (1 << 4), // used to import CPU scalar values to GPU and use them in MPS Stream
};
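
// Example (illustrative only): a pool of small, shared, hazard-tracked buffers would be
// described by a usage mask such as
//   uint32_t usage = UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::HAZARD;
// while a private large-buffer pool simply omits the SMALL and SHARED bits (PRIVATE == 0
// contributes no bit of its own). The concrete pool setup lives in the allocator
// implementation (MPSAllocator.mm), not in this header.
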
// debug verbosity flags
enum DebugVerbosity : uint32_t {
  SILENT      = 0,
  PROFILING   = (1 << 0), // print generic profiling data for total system memory usage
  ALLOCATIONS = (1 << 1), // print buffer allocations
  RECYCLES    = (1 << 2), // print buffer recycling
  RELEASES    = (1 << 3), // print buffer releases
  LARGE_ONLY  = (1 << 4), // only log large buffer pool transactions
};
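
// Example (a hedged sketch; the env-var parsing happens in the allocator implementation,
// see m_debug_verbosity below): assuming "PYTORCH_DEBUG_MPS_ALLOCATOR" is read as a numeric
// bitmask of these flags, a value of 3 (PROFILING | ALLOCATIONS) would print total memory
// usage plus individual buffer allocations, and 19 additionally sets LARGE_ONLY to restrict
// the logging to large-pool transactions:
//   PYTORCH_DEBUG_MPS_ALLOCATOR=19 python your_script.py   // 'your_script.py' is a placeholder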

struct HeapBlock;

struct BufferBlock {
  id<MTLBuffer> buffer;
  void* cpu_ptr = nullptr; // stores the pointer to the CPU mapping of a Shared MTLBuffer
  size_t size; // size after alignment
  size_t requested_size; // requested size (before alignment)
  // buffer shape is used for retrieving the base of views in cached graphs
  std::vector<int64_t> shape;
  bool in_use = false;
  HeapBlock* heap;
  id_t buf_id;
  // counter used to identify least recently used buffers for garbage collection
  uint32_t gc_count = 0;
  uint32_t use_count = 0;
  // counter to assign unique ids to buffer blocks
  static uint64_t buffer_counter;
  // Metal events used to sync GPU/CPU operations on the shared-storage buffers
  MPSEventPtr event;

  BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
              HeapBlock* Heap = nullptr) :
              buffer(Buffer), size(Size), requested_size(RequestedSize),
              heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }

  static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
    return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
  }
  static size_t alignUp(size_t Size, size_t Alignment) {
    assert(((Alignment - 1) & Alignment) == 0);
    return ((Size + Alignment - 1) & ~(Alignment - 1));
  }
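  // Worked example for alignUp above (illustration only): with a power-of-two Alignment,
  // Size is rounded up to the next multiple of it, e.g. alignUp(1000, 256) == 1024 and
  // alignUp(4096, 4096) == 4096.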
  uint32_t retainCount() const { return [buffer retainCount]; }
};
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);

struct BufferPool;
struct AllocParams {
  AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
              search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
  size_t size() const { return search_key.size; }

  BufferBlock search_key;
  BufferPool* pool;
  BufferBlock* buffer_block = nullptr;
  size_t requested_size;
  // true if we exceed the low watermark limit. In this case
  // we apply strategies to relieve the pressure before allocation.
  bool has_memory_pressure = false;
  // true if we're allocating on a unified memory device
  bool has_unified_memory = true;
};

struct HeapBlock {
  id<MTLHeap> heap;
  struct { size_t total, available; } size;
  BufferPool* pool;
  unsigned int n_buffers = 0;
  id_t heap_id;
  // indicates if this heap is split to sub-allocate several buffers (otherwise it backs a single buffer)
  bool is_split;
  // counter to assign unique ids to heap blocks
  static uint64_t heap_counter;

  HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
            heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
            heap_id(Heap ? ++heap_counter : 0), is_split(true) { }

  static MTLResourceOptions getOptions(uint32_t usage) {
    // TODO: check the caching performance of write-combined mode
    MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache;

    if (usage & UsageFlags::MANAGED)
      options |= MTLResourceStorageModeManaged;
    else if (usage & UsageFlags::SHARED)
      options |= MTLResourceStorageModeShared;
    else
      options |= MTLResourceStorageModePrivate;

    options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;

    return options;
  }
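  // Examples of the resulting masks (derived directly from the mapping above):
  //   getOptions(UsageFlags::SHARED | UsageFlags::HAZARD) ==
  //     MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared | MTLResourceHazardTrackingModeTracked
  //   getOptions(UsageFlags::PRIVATE) ==   // no SHARED/MANAGED/HAZARD bits set
  //     MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModePrivate | MTLResourceHazardTrackingModeUntracked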

  static HeapBlock* createHeapBlock(AllocParams& params, id<MTLDevice> device, uint32_t usage) {
    HeapBlock *heapBlock = nullptr;
    bool is_split = true;
    const size_t size = params.size();
    MTLHeapDescriptor *d = [MTLHeapDescriptor new];
    if (d) {
      const size_t kXLargeHeap = params.has_unified_memory ? kXLargeHeapU : kXLargeHeapD;
      if (size <= kMaxSmallAlloc) {
        d.size = kSmallHeap;
      } else if (size < kMinLargeAlloc) {
        d.size = kLargeHeap;
      } else if (size < kXLargeHeap / 2 && !params.has_memory_pressure) {
        d.size = kXLargeHeap;
      } else {
        d.size = kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
        is_split = false;
      }
      d.storageMode = (usage & UsageFlags::SHARED) ? MTLStorageModeShared : MTLStorageModePrivate;
      d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
      // this automatically handles Metal buffer access synchronizations at the
      // cost of slightly lower performance.
      d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
      d.resourceOptions = getOptions(usage);
      d.type = MTLHeapTypeAutomatic;
      id<MTLHeap> heap = [device newHeapWithDescriptor: d];
      if (heap) {
        [heap setPurgeableState:MTLPurgeableStateNonVolatile];
        const size_t heap_size = heapAvailableSize(heap);
        heapBlock = new HeapBlock(heap_size, heap, params.pool);
        if (heapBlock) {
          heapBlock->is_split = is_split;
        }
      }
      [d release];
    }
    return heapBlock;
  }
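  // Illustration of the sizing policy above (values follow the constants at the top of this
  // file): a 512 KiB request gets an 8 MiB kSmallHeap; a 4 MiB request gets a 32 MiB
  // kLargeHeap; a 100 MiB request on a unified-memory device with no memory pressure gets a
  // 1 GiB kXLargeHeapU that remains splittable; a 700 MiB request gets a dedicated,
  // non-split heap rounded up to a multiple of kRoundLarge.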
  static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
    return (a->size.available != b->size.available) ? a->size.available < b->size.available :
                                                      (uintptr_t)a->heap < (uintptr_t)b->heap;
  }
  static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
    return [heap maxAvailableSizeWithAlignment:Alignment];
  }
  NSUInteger Size() {
    return [heap size];
  }
  id<MTLBuffer> newMTLBuffer(size_t length, uint32_t usage) {
    id<MTLBuffer> buf = [heap newBufferWithLength:length options:getOptions(usage)];
    if (buf) {
      updateAvailableSize();
      n_buffers++;
    }
    return buf;
  }
  // returns the retainCount before releasing the buffer
  uint32_t releaseMTLBuffer(id<MTLBuffer>& buffer) {
    const uint32_t retainCount = [buffer retainCount];
    [buffer release];
    buffer = nil;
    updateAvailableSize();
    n_buffers--;
    return retainCount;
  }
  // returns the retainCount before releasing the heap
  uint32_t releaseMTLHeap() {
    const uint32_t retainCount = [heap retainCount];
    TORCH_INTERNAL_ASSERT(!n_buffers); // assert if heap isn't empty
    [heap setPurgeableState:MTLPurgeableStateEmpty];
    [heap release];
    heap = nil;
    size.available = 0;
    return retainCount;
  }
  uint32_t retainCount() const { return [heap retainCount]; }
  void updateAvailableSize() { size.available = heapAvailableSize(heap); }
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);

struct BufferPool {
  enum class Kind {
    PRIVATE_SMALL,
    PRIVATE_LARGE,
    SHARED_SMALL,
    SHARED_LARGE,
    SCALAR,
  };

  BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
             device(Device), usage(Usage),
             heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }

  const id<MTLDevice> device;
  // usage flags to customize the pool for various purposes (see UsageFlags enum)
  const uint32_t usage;
  // total number of buffers in the pool
  uint32_t n_buffers = 0;
  // total size of allocations in this pool
  size_t allocated_size = 0;
  // total memory available in the pool
  size_t available_size = 0;
  // list of heaps ordered by their "available" (not total) memory size
  std::set<HeapBlock*, HeapComparison> heaps;
  // list of only "available" buffers in the pool (i.e., buffers not in-use)
  std::set<BufferBlock*, BufferComparison> available_buffers;
  // list of buffers in a "limbo" state: they have already been freed on the PyTorch side,
  // but could not be returned to the pool yet because they are still in use by command
  // buffers (retainCount > 1). In this state a buffer is neither ready to be recycled nor
  // available for reuse. These buffers are returned to the pool once the command buffers'
  // completionHandler callbacks are called.
  std::unordered_set<BufferBlock*> buffers_pending_free;
  // list of heaps pending size update
  std::unordered_set<HeapBlock*> heaps_pending_update;
};
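
// A minimal sketch (not the actual implementation, which lives in MPSAllocator.mm) of how
// the size-ordered 'available_buffers' set supports best-fit reuse: the search key carries
// only the requested allocation size, so lower_bound() yields the smallest cached buffer
// that is at least that large.
//
//   // hypothetical helper, for illustration only
//   static BufferBlock* find_best_fit(BufferPool& pool, AllocParams& params) {
//     auto it = pool.available_buffers.lower_bound(&params.search_key);
//     return (it != pool.available_buffers.end()) ? *it : nullptr;
//   }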

class MPSHeapAllocatorImpl {
public:
  explicit MPSHeapAllocatorImpl() :
    m_device(at::mps::MPSDevice::getInstance()->device()),
    m_max_buffer_size([m_device maxBufferLength]),
    m_stream(getDefaultMPSStream()),
    m_event_pool(getMPSEventPool()) {
    init_allocator();
  }
  ~MPSHeapAllocatorImpl() {
    emptyCache();
  }
  // interface exposed to at::Allocator
  id<MTLBuffer> malloc(size_t size, uint32_t usage);
  // frees a buffer and returns it to the buffer pool
  void free(void* ptr);
  // releases all the cached buffers and their associated heaps
  void emptyCache();
  // frees inactive buffers that are pending release
  void freeInactiveBuffers();
  // returns true if buffer was allocated from the shared pool
  bool isSharedBuffer(const void* ptr);
  // get the requested unaligned size of an MTLBuffer
  ssize_t getUnalignedBufferSize(const void* ptr);
  // set the shape of a base tensor from a view tensor
  void setBufferShape(const void* ptr, const IntArrayRef& shape);
  // retrieve the shape of a base tensor from a view tensor
  IntArrayRef getBufferShape(const void* ptr);
  // get the unique ID of the buffer
  id_t getBufferId(const void* ptr);
  // allocate a buffer from a specialized pool to import CPU scalars into GPU
  id<MTLBuffer> allocScalarBufferWithValue(void* value, size_t size);
  // returns a CPU mapping of the input buffer and its retainCount,
  // but only if the buffer has Shared storage mode and was allocated by the MPSAllocator
  std::pair<const void*, uint32_t> getSharedBufferPtr(const void* buffer);
  // records events for a list of MTLBuffers (a list is used to lock the mutex once);
  // returns true if any event was recorded (i.e., the passed buffers exist and have shared storage)
  bool recordEvents(c10::ArrayRef<const void*> buffers);
  // waits for the events signaling the completion of GPU execution
  // on the passed shared buffers (a list is used to lock the mutex once);
  // returns true if it actually waited on any event
  bool waitForEvents(c10::ArrayRef<const void*> buffers);
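  // A hedged usage sketch for recordEvents/waitForEvents above (the real call sites live in
  // the MPS backend; 'allocator' and 'ptr' are placeholders): before the CPU touches the
  // mapping of a shared-storage buffer, the backend records an event once the GPU work is
  // encoded and later waits on it, e.g.
  //   allocator.recordEvents({ptr});   // after encoding the GPU writes to the buffer
  //   allocator.waitForEvents({ptr});  // before reading the CPU mapping (see getSharedBufferPtr)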
  // indicates how far (in Megabytes) the current total allocations are from the
  // low watermark limit, which is used to detect if we're under memory pressure;
  // returns zero if we've reached the low watermark limit
  ssize_t getLowWatermarkValue();
  // (see m_low_watermark_ratio for description)
  void setLowWatermarkRatio(double ratio);
  // (see m_high_watermark_ratio for description)
  void setHighWatermarkRatio(double ratio);
  // (see m_low_watermark_limit for description)
  size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
  // (see m_max_total_allowed_size for description)
  size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
  // (see m_total_allocated_memory for description)
  size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
  // (see m_current_allocated_memory for description)
  size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
  // total GPU memory allocated in the process by the Metal driver, including
  // implicit allocations from the MPS/MPSGraph frameworks and MPSHeapAllocatorImpl
  size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
  // recommended maximum memory for Metal
  size_t getRecommendedMaxMemory() const { return max_device_size(); }
  // (see enum DebugVerbosity for description)
  uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
  // returns the device that we allocate from
  inline id<MTLDevice> Device() const { return m_device; }

  // TODO: make a common function to do size unit conversions in PyTorch.
  inline std::string format_size(uint64_t size) const;
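
  // A hedged end-to-end sketch of the interface above (external code should go through
  // MPSAllocatorInterface.h rather than this class directly; 'allocator' and 'nbytes' are
  // placeholders for illustration):
  //   id<MTLBuffer> buf = allocator.malloc(nbytes, UsageFlags::SHARED | UsageFlags::HAZARD);
  //   // ... encode GPU work against 'buf' ...
  //   if (allocator.isSharedBuffer((void*)buf)) {
  //     // CPU-side access to the shared mapping is possible here (see getSharedBufferPtr)
  //   }
  //   allocator.free((void*)buf);  // returns the buffer to its pool for reuse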

private:
  // (see m_high_watermark_ratio for description)
  constexpr static double default_high_watermark_ratio = 1.7;
  // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize.
  constexpr static double default_high_watermark_upper_bound = 2.0;
  // (see m_low_watermark_ratio for description)
  // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize
  constexpr static double default_low_watermark_ratio_unified  = 1.4;
  constexpr static double default_low_watermark_ratio_discrete = 1.0;

  const id<MTLDevice> m_device;
  std::recursive_mutex m_mutex;
  // allocated buffers by device pointer
  ska::flat_hash_map<const void*, BufferBlock*> m_allocated_buffers;
  // using a container for pools to simplify iterating them
  ska::flat_hash_map<BufferPool::Kind, std::unique_ptr<BufferPool>> m_pools;
  // total memory allocated by the HeapAllocator (including blocks in pools)
  size_t m_total_allocated_memory = 0;
  // currently active memory allocations in use (i.e., blocks not in pools)
  size_t m_current_allocated_memory = 0;
  // max buffer size allowed by Metal
  size_t m_max_buffer_size = 0;
  // maximum total size allowed to be allocated
  size_t m_max_total_allowed_size = 0;
  // high watermark ratio is a hard limit on the total allowed allocations:
  // 0. : disables the high watermark limit (may cause system failure if a system-wide OOM occurs)
  // 1. : recommended maximum allocation size (i.e., device.recommendedMaxWorkingSetSize)
  // >1.: allows limits beyond device.recommendedMaxWorkingSetSize
  // e.g., a value of 0.95 means we allocate up to 95% of the recommended maximum
  // allocation size; beyond that, allocations fail with an OOM error.
  double m_high_watermark_ratio;
  // low watermark ratio is a soft limit: we attempt to keep allocations below the low watermark
  // level via garbage collection or by committing command buffers more frequently (a.k.a. adaptive commit).
  // Its value ranges between 0 and m_high_watermark_ratio (0.0 disables adaptive commit and garbage collection).
  // e.g., a value of 0.9 means we 'attempt' to limit allocations to 90% of the recommended maximum
  // allocation size.
  double m_low_watermark_ratio;
  // low watermark size limit (in Bytes) at the time we initialize the allocator
  size_t m_low_watermark_limit;
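  // Worked example (a sketch; the exact computation happens in the allocator implementation,
  // assuming both limits scale with the device's recommendedMaxWorkingSetSize): if
  // recommendedMaxWorkingSetSize is 48 GB, a high watermark ratio of 1.7 caps
  // m_max_total_allowed_size at ~81.6 GB, while a low watermark ratio of 1.4 sets
  // m_low_watermark_limit to ~67.2 GB, beyond which garbage collection / adaptive commit
  // kicks in.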
  // use the "PYTORCH_DEBUG_MPS_ALLOCATOR" env-var to set debug verbosity
  uint32_t m_debug_verbosity;
  // default MPS stream
  MPSStream* m_stream;
  // we hold a reference to the MPSEventPool so it gets destroyed after the MPSAllocator
  std::shared_ptr<MPSEventPool> m_event_pool;

  void init_allocator();
  void init_buffer_pools();
  HeapBlock* get_free_heap(AllocParams& params);
  bool get_free_buffer(AllocParams& params);
  BufferBlock* get_allocated_buffer_block(const void* ptr);
  BufferBlock* alloc_buffer_block(size_t size, uint32_t usage);
  bool alloc_buffer(AllocParams& params);
  void free_buffer(BufferBlock* buffer_block);
  // returns true if the containing heap is also released
  bool release_buffer(BufferBlock* buffer_block, bool remove_empty_heap = true);
  void release_buffers(BufferPool& pool);
  bool release_available_cached_buffers(AllocParams& params);
  bool release_cached_buffers();
  // free unused cached blocks to reclaim GPU memory if memory pressure is high
  void garbage_collect_cached_buffers(AllocParams& params);
  // returns the buffer pool suited to the usage flags and the requested/allocated sizes
  BufferPool& get_pool(size_t requested_size, size_t aligned_size, uint32_t usage);
  // returns the aligned allocation size that is optimized
  // for the buffers to get reused frequently
  size_t get_allocation_size(size_t size, uint32_t usage) const;
  // maximum size of device memory available for allocation in the current process
  // Note: recommendedMaxWorkingSetSize is typically 75% of the total system memory.
  size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
  // there are implicit allocations from the MPS backend, so we need to query the 'device'
  // for the total allocated size instead of manually tracking it in MPSAllocator
  size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }

  bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
    for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
      MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
    }
    return true;
  }
};

} // namespace at::mps::HeapAllocator