1 #include <c10/cuda/CUDACachingAllocator.h>
2 #include <c10/cuda/CUDAGuard.h>
3 #include <mutex>
4 #include <unordered_map>
5 #include <utility>
6
7 #include <torch/csrc/cuda/CUDAPluggableAllocator.h>
8
9 namespace torch::cuda::CUDAPluggableAllocator {
10
CUDAPluggableAllocatorDeleterContext(std::function<FreeFuncType> free_fn,void * data,size_t size,int device,cudaStream_t stream)11 CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
12 std::function<FreeFuncType> free_fn,
13 void* data,
14 size_t size,
15 int device,
16 cudaStream_t stream)
17 : free_fn_(free_fn),
18 data_(data),
19 size_(size),
20 device_(device),
21 stream_(stream) {}
22
free()23 void CUDAPluggableAllocatorDeleterContext::free() {
24 free_fn_(data_, size_, device_, stream_);
25 delete this;
26 }
27
// Device count forwarded to the allocator's init callback when a custom
// allocator is created via createCustomAllocator().
int device_count = 0;

// Deleter trampoline installed into c10::DataPtr (defined at end of file).
void custom_raw_deleter(void* ptr);
31
_AllocationMetadata()32 _AllocationMetadata::_AllocationMetadata()
33 : size(0), device_idx(-1), stream{} {}
34
_AllocationMetadata(size_t size,c10::DeviceIndex device_idx,cudaStream_t stream)35 _AllocationMetadata::_AllocationMetadata(
36 size_t size,
37 c10::DeviceIndex device_idx,
38 cudaStream_t stream)
39 : size(size), device_idx(device_idx), stream(stream) {}
40
41 // This is a fast API to just register allocators
42 // based on function pointers (ie. external .so libraries)
43 // This avoids having to link against libtorch for C++ based custom allocators
44 // And also use this from python
CUDAPluggableAllocator(std::function<MallocFuncType> alloc_fn,std::function<FreeFuncType> free_fn)45 CUDAPluggableAllocator::CUDAPluggableAllocator(
46 std::function<MallocFuncType> alloc_fn,
47 std::function<FreeFuncType> free_fn)
48 : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}
49
CUDAPluggableAllocator(CUDAPluggableAllocator & other)50 CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
51 : alloc_fn_(other.alloc_fn_),
52 free_fn_(other.free_fn_),
53 init_fn_(other.init_fn_),
54 reset_fn_(other.reset_fn_),
55 memory_fraction_fn_(other.memory_fraction_fn_),
56 base_alloc_fn_(other.base_alloc_fn_),
57 record_stream_fn_(other.record_stream_fn_),
58 begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_),
59 end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_),
60 relase_pool_fn_(other.relase_pool_fn_) {}
61
set_init_fn(std::function<void (int)> init_fn)62 void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) {
63 init_fn_ = std::move(init_fn);
64 }
65
set_reset_fn(std::function<void ()> reset_fn)66 void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) {
67 reset_fn_ = std::move(reset_fn);
68 }
69
set_memory_fraction_fn(std::function<void (double,int)> memory_fraction_fn)70 void CUDAPluggableAllocator::set_memory_fraction_fn(
71 std::function<void(double, int)> memory_fraction_fn) {
72 memory_fraction_fn_ = std::move(memory_fraction_fn);
73 }
74
set_base_alloc_fn(std::function<void * (void *,size_t *)> base_alloc_fn)75 void CUDAPluggableAllocator::set_base_alloc_fn(
76 std::function<void*(void*, size_t*)> base_alloc_fn) {
77 base_alloc_fn_ = std::move(base_alloc_fn);
78 }
79
set_record_stream_fn(std::function<void (void * ptr,cudaStream_t stream)> record_stream_fn)80 void CUDAPluggableAllocator::set_record_stream_fn(
81 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) {
82 record_stream_fn_ = std::move(record_stream_fn);
83 }
84
set_begin_allocate_to_pool(std::function<void (int,c10::cuda::MempoolId_t,std::function<bool (cudaStream_t)>)> capture_begin_fn)85 void CUDAPluggableAllocator::set_begin_allocate_to_pool(
86 std::function<
87 void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
88 capture_begin_fn) {
89 begin_allocate_to_pool_fn_ = std::move(capture_begin_fn);
90 }
91
set_end_allocate_to_pool_fn(std::function<void (int,c10::cuda::MempoolId_t)> capture_about_to_end_fn)92 void CUDAPluggableAllocator::set_end_allocate_to_pool_fn(
93 std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn) {
94 end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn);
95 }
96
set_release_pool(std::function<void (int,c10::cuda::MempoolId_t)> capture_destroy_fn)97 void CUDAPluggableAllocator::set_release_pool(
98 std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) {
99 relase_pool_fn_ = std::move(capture_destroy_fn);
100 }
101
malloc(size_t size,c10::DeviceIndex device,cudaStream_t stream)102 void* CUDAPluggableAllocator::malloc(
103 size_t size,
104 c10::DeviceIndex device,
105 cudaStream_t stream) {
106 void* r = alloc_fn_(size, device, stream);
107 {
108 const std::lock_guard<std::mutex> lock(allocator_mutex_);
109 allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
110 }
111 return r;
112 }
113
allocate(size_t size)114 c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
115 c10::DeviceIndex device = -1;
116 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
117 cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
118 void* r = this->malloc(size, device, stream);
119 auto* ctx = new CUDAPluggableAllocatorDeleterContext(
120 free_fn_, r, size, device, stream);
121 c10::DataPtr data_ptr = {
122 r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
123 return data_ptr;
124 }
125
raw_deleter() const126 c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
127 return &custom_raw_deleter;
128 }
129
raw_alloc(size_t nbytes)130 void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
131 c10::DeviceIndex device = -1;
132 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
133 cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
134 return malloc(nbytes, device, stream);
135 }
136
raw_alloc_with_stream(size_t nbytes,cudaStream_t stream)137 void* CUDAPluggableAllocator::raw_alloc_with_stream(
138 size_t nbytes,
139 cudaStream_t stream) {
140 c10::DeviceIndex device = -1;
141 C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
142 return malloc(nbytes, device, stream);
143 }
144
raw_delete(void * ptr)145 void CUDAPluggableAllocator::raw_delete(void* ptr) {
146 cudaStream_t stream{};
147 c10::DeviceIndex device_idx = -1;
148 size_t size = 0;
149 {
150 const std::lock_guard<std::mutex> lock(allocator_mutex_);
151 TORCH_CHECK(
152 allocation_metadata_.count(ptr),
153 "Trying to free a pointer not allocated here");
154 _AllocationMetadata& metadata = allocation_metadata_[ptr];
155 size = metadata.size;
156 device_idx = metadata.device_idx;
157 stream = metadata.stream;
158 allocation_metadata_.erase(ptr);
159 }
160 free_fn_(ptr, size, device_idx, stream);
161 }
162
init(int device_count)163 void CUDAPluggableAllocator::init(int device_count) {
164 if (init_fn_) {
165 init_fn_(device_count);
166 }
167 initialized_ = true;
168 }
169
initialized()170 bool CUDAPluggableAllocator::initialized() {
171 return initialized_;
172 }
173
setMemoryFraction(double fraction,c10::DeviceIndex device)174 void CUDAPluggableAllocator::setMemoryFraction(
175 double fraction,
176 c10::DeviceIndex device) {
177 if (memory_fraction_fn_) {
178 memory_fraction_fn_(fraction, device);
179 }
180 }
181
emptyCache()182 void CUDAPluggableAllocator::emptyCache() {
183 if (reset_fn_) {
184 return reset_fn_();
185 }
186 }
187
cacheInfo(c10::DeviceIndex device,size_t * largestBlock)188 void CUDAPluggableAllocator::cacheInfo(
189 c10::DeviceIndex device,
190 size_t* largestBlock) {
191 TORCH_CHECK(
192 false,
193 "CUDAPluggableAllocator does not yet support cacheInfo. "
194 "If you need it, please file an issue describing your use case.");
195 }
196
getBaseAllocation(void * ptr,size_t * size)197 void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
198 if (base_alloc_fn_) {
199 return base_alloc_fn_(ptr, size);
200 } else {
201 return ptr;
202 }
203 }
204
recordStream(const c10::DataPtr & ptr,streamType stream)205 void CUDAPluggableAllocator::recordStream(
206 const c10::DataPtr& ptr,
207 streamType stream) {
208 if (record_stream_fn_) {
209 record_stream_fn_(ptr.get(), stream);
210 }
211 }
212
getDeviceStats(c10::DeviceIndex device)213 c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats(
214 c10::DeviceIndex device) {
215 TORCH_CHECK(
216 false,
217 "CUDAPluggableAllocator does not yet support getDeviceStats. "
218 "If you need it, please file an issue describing your use case.");
219 }
220
resetAccumulatedStats(c10::DeviceIndex device)221 void CUDAPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
222 TORCH_CHECK(
223 false,
224 "CUDAPluggableAllocator does not yet support resetAccumulatedStats. "
225 "If you need it, please file an issue describing your use case.");
226 }
227
resetPeakStats(c10::DeviceIndex device)228 void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
229 TORCH_CHECK(
230 false,
231 "CUDAPluggableAllocator does not yet support resetPeakStats. "
232 "If you need it, please file an issue describing your use case.");
233 }
234
235 c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
snapshot()236 snapshot() {
237 TORCH_CHECK(
238 false,
239 "CUDAPluggableAllocator does not yet support snapshot. "
240 "If you need it, please file an issue describing your use case.");
241 }
242
243 c10::cuda::CUDACachingAllocator::ShareableHandle CUDAPluggableAllocator::
shareIpcHandle(void * ptr)244 shareIpcHandle(void* ptr) {
245 TORCH_CHECK(
246 false,
247 "CUDAPluggableAllocator does not yet support shareIPcHandle. "
248 "If you need it, please file an issue describing your use case.");
249 }
250
getIpcDevPtr(std::string handle)251 std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) {
252 TORCH_CHECK(
253 false,
254 "CUDAPluggableAllocator does not yet support getIpcDevPtr. "
255 "If you need it, please file an issue describing your use case.");
256 }
257
258 // CUDAGraph interactions
beginAllocateToPool(c10::DeviceIndex device,c10::cuda::MempoolId_t mempool_id,std::function<bool (cudaStream_t)> filter)259 void CUDAPluggableAllocator::beginAllocateToPool(
260 c10::DeviceIndex device,
261 c10::cuda::MempoolId_t mempool_id,
262 std::function<bool(cudaStream_t)> filter) {
263 if (begin_allocate_to_pool_fn_) {
264 begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter));
265 }
266 }
267
endAllocateToPool(c10::DeviceIndex device,c10::cuda::MempoolId_t mempool_id)268 void CUDAPluggableAllocator::endAllocateToPool(
269 c10::DeviceIndex device,
270 c10::cuda::MempoolId_t mempool_id) {
271 if (end_allocate_to_pool_fn_) {
272 end_allocate_to_pool_fn_(device, mempool_id);
273 }
274 }
275
releasePool(c10::DeviceIndex device,c10::cuda::MempoolId_t mempool_id)276 void CUDAPluggableAllocator::releasePool(
277 c10::DeviceIndex device,
278 c10::cuda::MempoolId_t mempool_id) {
279 if (relase_pool_fn_) {
280 relase_pool_fn_(device, mempool_id);
281 }
282 }
283
recordHistory(bool enabled,c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,size_t alloc_trace_max_entries,c10::cuda::CUDACachingAllocator::RecordContext when)284 void CUDAPluggableAllocator::recordHistory(
285 bool enabled,
286 c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
287 size_t alloc_trace_max_entries,
288 c10::cuda::CUDACachingAllocator::RecordContext when) {
289 TORCH_CHECK(
290 false,
291 "CUDAPluggableAllocator does not yet support recordHistory. "
292 "If you need it, please file an issue describing your use case.");
293 }
294
attachOutOfMemoryObserver(c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer)295 void CUDAPluggableAllocator::attachOutOfMemoryObserver(
296 c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) {
297 TORCH_CHECK(
298 false,
299 "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. "
300 "If you need it, please file an issue describing your use case.");
301 }
302
attachAllocatorTraceTracker(c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker)303 void CUDAPluggableAllocator::attachAllocatorTraceTracker(
304 c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) {
305 TORCH_CHECK(
306 false,
307 "CUDAPluggableAllocator does not support attachAllocatorTraceTracker. "
308 "attachAllocatorTraceTracker is only used inside Pytorch.");
309 }
310
311 std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState>
getCheckpointState(c10::DeviceIndex device,at::cuda::MempoolId_t id)312 CUDAPluggableAllocator::getCheckpointState(
313 c10::DeviceIndex device,
314 at::cuda::MempoolId_t id) {
315 TORCH_CHECK(
316 false,
317 "CUDAPluggableAllocator does not yet support getCheckpointState. "
318 "If you need it, please file an issue describing your use case.");
319 }
320
321 c10::cuda::CUDACachingAllocator::CheckpointDelta CUDAPluggableAllocator::
setCheckpointPoolState(c10::DeviceIndex device,std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps)322 setCheckpointPoolState(
323 c10::DeviceIndex device,
324 std::shared_ptr<c10::cuda::CUDACachingAllocator::AllocatorState> pps) {
325 TORCH_CHECK(
326 false,
327 "CUDAPluggableAllocator does not yet support setCheckpointPoolState. "
328 "If you need it, please file an issue describing your use case.");
329 }
330
enablePeerAccess(c10::DeviceIndex dev,c10::DeviceIndex dev_to_access)331 void CUDAPluggableAllocator::enablePeerAccess(
332 c10::DeviceIndex dev,
333 c10::DeviceIndex dev_to_access) {
334 c10::cuda::CUDAGuard device_guard(dev);
335 cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
336 if (err == cudaErrorPeerAccessAlreadyEnabled) {
337 // ignore and clear the error if access was already enabled
338 (void)cudaGetLastError();
339 } else {
340 C10_CUDA_CHECK(err);
341 }
342 }
343
memcpyAsync(void * dst,int dstDevice,const void * src,int srcDevice,size_t count,cudaStream_t stream,bool p2p_enabled)344 cudaError_t CUDAPluggableAllocator::memcpyAsync(
345 void* dst,
346 int dstDevice,
347 const void* src,
348 int srcDevice,
349 size_t count,
350 cudaStream_t stream,
351 bool p2p_enabled) {
352 return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
353 }
354
name()355 std::string CUDAPluggableAllocator::name() {
356 return "pluggable";
357 }
358
copy_data(void * dest,const void * src,std::size_t count) const359 void CUDAPluggableAllocator::copy_data(
360 void* dest,
361 const void* src,
362 std::size_t count) const {
363 C10_CUDA_CHECK(
364 cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
365 }
366
367 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
368 current_custom_allocator;
369
370 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator()371 getCurrentAllocator() {
372 return current_custom_allocator;
373 }
374
375 // TODO: add more functions in the argument
376 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
createCustomAllocator(std::function<MallocFuncType> alloc_fn,std::function<FreeFuncType> free_fn)377 createCustomAllocator(
378 std::function<MallocFuncType> alloc_fn,
379 std::function<FreeFuncType> free_fn) {
380 std::shared_ptr<CUDAPluggableAllocator> allocator(
381 new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn)));
382 allocator->init(device_count);
383 return allocator;
384 }
385
changeCurrentAllocator(const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> & allocator)386 void changeCurrentAllocator(
387 const std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>&
388 allocator) {
389 TORCH_CHECK(
390 !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(),
391 "Can't swap an already initialized allocator");
392 c10::cuda::CUDACachingAllocator::allocator.store(allocator.get());
393 current_custom_allocator = allocator;
394 }
395
custom_raw_deleter(void * ctx)396 void custom_raw_deleter(void* ctx) {
397 reinterpret_cast<CUDAPluggableAllocatorDeleterContext*>(ctx)->free();
398 }
399
400 } // namespace torch::cuda::CUDAPluggableAllocator
401