#pragma once

#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>

#include <c10/cuda/CUDACachingAllocator.h>

#include <mutex>

namespace torch::cuda::CUDAPluggableAllocator {

using MallocFuncType = void*(size_t, int, cudaStream_t);
using FreeFuncType = void(void*, size_t, int, cudaStream_t);
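
// A minimal sketch (not part of this header) of a pair of callbacks matching
// the two signatures above. The names `my_malloc`/`my_free` are illustrative
// only; this trivial version ignores the device and stream arguments:
//
//   void* my_malloc(size_t size, int device, cudaStream_t stream) {
//     void* ptr = nullptr;
//     cudaMalloc(&ptr, size);
//     return ptr;
//   }
//
//   void my_free(void* ptr, size_t size, int device, cudaStream_t stream) {
//     cudaFree(ptr);
//   }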

// A CUDAPluggableAllocatorDeleterContext object is used as the `ctx`
// argument for DataPtr. We need a context because a user can use
// multiple allocators in the same PyTorch program, and
// the allocators can have different free functions, such as
// free, cudaFree, cudaFreeAsync, or ncclMemFree.
struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext {
  explicit CUDAPluggableAllocatorDeleterContext(
      std::function<FreeFuncType> free_fn,
      void* data,
      size_t size,
      int device,
      cudaStream_t stream);

  void free();

 private:
  std::function<FreeFuncType> free_fn_;
  void* data_;
  size_t size_;
  int device_;
  cudaStream_t stream_;
};
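
// Sketch of how the context above is expected to be consumed by a DataPtr
// deleter (the helper name `custom_raw_deleter` is an assumption for
// illustration; it is not declared in this header):
//
//   void custom_raw_deleter(void* ctx) {
//     reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
//   }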

#if defined(TORCH_HIP_VERSION)
using streamType = c10::hip::HIPStream;
#else
using streamType = c10::cuda::CUDAStream;
#endif

TORCH_CUDA_CPP_API std::shared_ptr<
    c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator();
TORCH_CUDA_CPP_API std::shared_ptr<
    c10::cuda::CUDACachingAllocator::CUDAAllocator>
createCustomAllocator(
    std::function<MallocFuncType> alloc_fn,
    std::function<FreeFuncType> free_fn);
TORCH_CUDA_CPP_API void changeCurrentAllocator(
    const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
        allocator);
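
// Typical usage (a sketch, assuming callbacks like the `my_malloc`/`my_free`
// examples near MallocFuncType/FreeFuncType above). changeCurrentAllocator is
// meant to be called before the existing allocator has been used to allocate
// any memory:
//
//   auto allocator =
//       torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
//           my_malloc, my_free);
//   torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator);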

struct _AllocationMetadata {
  _AllocationMetadata();
  _AllocationMetadata(
      size_t size,
      c10::DeviceIndex device_idx,
      cudaStream_t stream);
  size_t size;
  c10::DeviceIndex device_idx;
  cudaStream_t stream;
};

struct TORCH_CUDA_CPP_API CUDAPluggableAllocator
    : public c10::cuda::CUDACachingAllocator::CUDAAllocator {
  CUDAPluggableAllocator(
      std::function<MallocFuncType> alloc_fn,
      std::function<FreeFuncType> free_fn);

  CUDAPluggableAllocator(CUDAPluggableAllocator& other);
  CUDAPluggableAllocator& operator=(CUDAPluggableAllocator& other) = delete;

  void set_init_fn(std::function<void(int)> init_fn);

  void set_reset_fn(std::function<void()> reset_fn);

  void set_memory_fraction_fn(
      std::function<void(double, int)> memory_fraction_fn);

  void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn);

  void set_record_stream_fn(
      std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn);

  void set_begin_allocate_to_pool(
      std::function<
          void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
          capture_begin_fn);

  void set_end_allocate_to_pool_fn(
      std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn);

  void set_release_pool(
      std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn);

  void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream);

  c10::DataPtr allocate(size_t size) override;
  c10::DeleterFnPtr raw_deleter() const override;

  void* raw_alloc(size_t nbytes) override;
  void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override;
  void raw_delete(void* ptr) override;
  void init(int device_count) override;
  bool initialized() override;
  void setMemoryFraction(double fraction, c10::DeviceIndex device) override;
  void emptyCache() override;
  void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override;
  void* getBaseAllocation(void* ptr, size_t* size) override;

  void recordStream(const c10::DataPtr&, streamType stream) override;

  c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
      c10::DeviceIndex device) override;
  void resetAccumulatedStats(c10::DeviceIndex device) override;
  void resetPeakStats(c10::DeviceIndex device) override;
  c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
  void beginAllocateToPool(
      c10::DeviceIndex device,
      c10::cuda::MempoolId_t mempool_id,
      std::function<bool(cudaStream_t)>) override;
  void endAllocateToPool(
      c10::DeviceIndex device,
      c10::cuda::MempoolId_t mempool_id) override;
  void releasePool(c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id)
      override;
  std::shared_ptr<void> getIpcDevPtr(std::string handle) override;
  c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle(
      void*) override;
  void recordHistory(
      bool enabled,
      c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
      size_t alloc_trace_max_entries,
      c10::cuda::CUDACachingAllocator::RecordContext when) override;
  void attachOutOfMemoryObserver(
      c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override;
  void attachAllocatorTraceTracker(
      c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override;
  std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState>
  getCheckpointState(c10::DeviceIndex device, at::cuda::MempoolId_t id)
      override;
  c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState(
      c10::DeviceIndex device,
      std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps)
      override;
  void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
      override;
  cudaError_t memcpyAsync(
      void* dst,
      int dstDevice,
      const void* src,
      int srcDevice,
      size_t count,
      cudaStream_t stream,
      bool p2p_enabled) override;
  std::string name() override;
  void copy_data(void* dest, const void* src, std::size_t count) const final;

 protected:
  std::function<MallocFuncType> alloc_fn_;
  std::function<FreeFuncType> free_fn_;
  std::function<void(int)> init_fn_;
  std::function<void()> reset_fn_;
  std::function<void(double, int)> memory_fraction_fn_;
  std::function<void*(void*, size_t*)> base_alloc_fn_;
  std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_;
  std::function<
      void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
      begin_allocate_to_pool_fn_;
  std::function<void(int, c10::cuda::MempoolId_t)> end_allocate_to_pool_fn_;
  std::function<void(int, c10::cuda::MempoolId_t)> relase_pool_fn_;
  std::mutex allocator_mutex_;
  // We do the bookkeeping here in order to simplify custom allocators
  std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;

  bool initialized_ = false;
};
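
// A sketch of installing a custom allocator and attaching the optional hooks
// above. The lambda bodies and the downcast via std::dynamic_pointer_cast are
// illustrative assumptions; only the declared setters themselves come from
// this header:
//
//   namespace cpa = torch::cuda::CUDAPluggableAllocator;
//   auto base = cpa::createCustomAllocator(my_malloc, my_free);
//   if (auto pluggable =
//           std::dynamic_pointer_cast<cpa::CUDAPluggableAllocator>(base)) {
//     pluggable->set_init_fn([](int device_count) { /* one-time setup */ });
//     pluggable->set_memory_fraction_fn(
//         [](double fraction, int device) { /* cap per-device usage */ });
//   }
//   cpa::changeCurrentAllocator(base);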
} // namespace torch::cuda::CUDAPluggableAllocator