1 #pragma once 2 3 #include <c10/core/Allocator.h> 4 #include <c10/cuda/CUDAGraphsC10Utils.h> 5 #include <c10/cuda/CUDAMacros.h> 6 #include <c10/cuda/CUDAStream.h> 7 8 #include <c10/cuda/CUDACachingAllocator.h> 9 10 #include <mutex> 11 12 namespace torch::cuda::CUDAPluggableAllocator { 13 14 using MallocFuncType = void*(size_t, int, cudaStream_t); 15 using FreeFuncType = void(void*, size_t, int, cudaStream_t); 16 17 // A CUDAPluggableAllocatorDeleterContext object is used as the `ctx` 18 // argument for DataPtr. We need context because a user can use 19 // multiple allocators in the same PyTorch program, and 20 // the allocators can have different free functions, such as: 21 // free, cudaFree, cudaFreeAsync, ncclMemFree etc. 22 struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { 23 explicit CUDAPluggableAllocatorDeleterContext( 24 std::function<FreeFuncType> free_fn, 25 void* data, 26 size_t size, 27 int device, 28 cudaStream_t stream); 29 30 void free(); 31 32 private: 33 std::function<FreeFuncType> free_fn_; 34 void* data_; 35 size_t size_; 36 int device_; 37 cudaStream_t stream_; 38 }; 39 40 #if defined(TORCH_HIP_VERSION) 41 using streamType = c10::hip::HIPStream; 42 #else 43 using streamType = c10::cuda::CUDAStream; 44 #endif 45 46 TORCH_CUDA_CPP_API std::shared_ptr< 47 c10::cuda::CUDACachingAllocator::CUDAAllocator> 48 getCurrentAllocator(); 49 TORCH_CUDA_CPP_API std::shared_ptr< 50 c10::cuda::CUDACachingAllocator::CUDAAllocator> 51 createCustomAllocator( 52 std::function<MallocFuncType> alloc_fn, 53 std::function<FreeFuncType> free_fn); 54 TORCH_CUDA_CPP_API void changeCurrentAllocator( 55 const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>& 56 allocator); 57 58 struct _AllocationMetadata { 59 _AllocationMetadata(); 60 _AllocationMetadata( 61 size_t size, 62 c10::DeviceIndex device_idx, 63 cudaStream_t stream); 64 size_t size; 65 c10::DeviceIndex device_idx; 66 cudaStream_t stream; 67 }; 68 69 struct TORCH_CUDA_CPP_API CUDAPluggableAllocator 70 : public c10::cuda::CUDACachingAllocator::CUDAAllocator { 71 CUDAPluggableAllocator( 72 std::function<MallocFuncType> alloc_fn, 73 std::function<FreeFuncType> free_fn); 74 75 CUDAPluggableAllocator(CUDAPluggableAllocator& other); 76 CUDAPluggableAllocator& operator=(CUDAPluggableAllocator& other) = delete; 77 78 void set_init_fn(std::function<void(int)> init_fn); 79 80 void set_reset_fn(std::function<void()> reset_fn); 81 82 void set_memory_fraction_fn( 83 std::function<void(double, int)> memory_fraction_fn); 84 85 void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn); 86 87 void set_record_stream_fn( 88 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn); 89 90 void set_begin_allocate_to_pool( 91 std::function< 92 void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)> 93 capture_begin_fn); 94 95 void set_end_allocate_to_pool_fn( 96 std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn); 97 98 void set_release_pool( 99 std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn); 100 101 void* malloc(size_t size, c10::DeviceIndex device, cudaStream_t stream); 102 103 c10::DataPtr allocate(size_t size) override; 104 c10::DeleterFnPtr raw_deleter() const override; 105 106 void* raw_alloc(size_t nbytes) override; 107 void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override; 108 void raw_delete(void* ptr) override; 109 void init(int device_count) override; 110 bool initialized() override; 111 void setMemoryFraction(double fraction, c10::DeviceIndex device) override; 112 void emptyCache() override; 113 void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override; 114 void* getBaseAllocation(void* ptr, size_t* size) override; 115 116 void recordStream(const c10::DataPtr&, streamType stream) override; 117 118 c10::CachingDeviceAllocator::DeviceStats getDeviceStats( 119 c10::DeviceIndex device) override; 120 void resetAccumulatedStats(c10::DeviceIndex device) override; 121 void resetPeakStats(c10::DeviceIndex device) override; 122 c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override; 123 void beginAllocateToPool( 124 c10::DeviceIndex device, 125 c10::cuda::MempoolId_t mempool_id, 126 std::function<bool(cudaStream_t)>) override; 127 void endAllocateToPool( 128 c10::DeviceIndex device, 129 c10::cuda::MempoolId_t mempool_id) override; 130 void releasePool(c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) 131 override; 132 std::shared_ptr<void> getIpcDevPtr(std::string handle) override; 133 c10::cuda::CUDACachingAllocator::ShareableHandle shareIpcHandle( 134 void*) override; 135 void recordHistory( 136 bool enabled, 137 c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, 138 size_t alloc_trace_max_entries, 139 c10::cuda::CUDACachingAllocator::RecordContext when) override; 140 void attachOutOfMemoryObserver( 141 c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; 142 void attachAllocatorTraceTracker( 143 c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override; 144 std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> 145 getCheckpointState(c10::DeviceIndex device, at::cuda::MempoolId_t id) 146 override; 147 c10::cuda::CUDACachingAllocator::CheckpointDelta setCheckpointPoolState( 148 c10::DeviceIndex device, 149 std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps) 150 override; 151 void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) 152 override; 153 cudaError_t memcpyAsync( 154 void* dst, 155 int dstDevice, 156 const void* src, 157 int srcDevice, 158 size_t count, 159 cudaStream_t stream, 160 bool p2p_enabled) override; 161 std::string name() override; 162 void copy_data(void* dest, const void* src, std::size_t count) const final; 163 164 protected: 165 std::function<MallocFuncType> alloc_fn_; 166 std::function<FreeFuncType> free_fn_; 167 std::function<void(int)> init_fn_; 168 std::function<void()> reset_fn_; 169 std::function<void(double, int)> memory_fraction_fn_; 170 std::function<void*(void*, size_t*)> base_alloc_fn_; 171 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_; 172 std::function< 173 void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)> 174 begin_allocate_to_pool_fn_; 175 std::function<void(int, c10::cuda::MempoolId_t)> end_allocate_to_pool_fn_; 176 std::function<void(int, c10::cuda::MempoolId_t)> relase_pool_fn_; 177 std::mutex allocator_mutex_; 178 // We do the bookeeping here in order to simplify custom allocators 179 std::unordered_map<void*, _AllocationMetadata> allocation_metadata_; 180 181 bool initialized_ = false; 182 }; 183 } // namespace torch::cuda::CUDAPluggableAllocator 184