#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>

#include <unordered_set>
#include <vector>

namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync {

using namespace c10::CachingDeviceAllocator;

#if CUDA_VERSION >= 11040
// CUDA device allocator that uses cudaMallocAsync to implement
// the same interface as CUDACachingAllocator.cpp.

// Designed to be safe for CUDA graph capture.
// Interactions with CUDA graph capture are mediated by
// notifyCaptureBegin
// notifyCaptureAboutToEnd
// notifyCaptureEnded
// notifyCaptureDestroy

// Implementation details, not declared in CUDACachingAllocator.h
namespace {

// General helpers

struct UsageStream {
  cudaStream_t stream;
  c10::DeviceIndex device;
  UsageStream() = default;
  UsageStream(cudaStream_t s, c10::DeviceIndex d) : stream(s), device(d) {}
  UsageStream(const UsageStream& us) = default;
  UsageStream(UsageStream&& us) noexcept = default;
  UsageStream& operator=(const UsageStream& other) = default;
  UsageStream& operator=(UsageStream&& other) noexcept = default;
};

bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}

struct UsageStreamHash {
  size_t operator()(const UsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) + size_t(us.device);
  }
};

struct PtrUsage {
  // recorded_streams holds side usage streams added by record_stream calls.
  // In other words, it does NOT include the original creation stream.
  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream creation_stream{};
  uint64_t size;
  bool captured;
  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};
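
// Illustrative example (not exhaustive): an allocation made on stream A and
// later used on stream B via record_stream ends up with
//   creation_stream == {A, dev}  and  recorded_streams == { {B, dev} },
// so free_impl() knows it must sync B before returning the memory to the pool.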

int device_count = 0;
// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
// because they'll only be flipped by functions that have locked the mutex.
std::vector<bool> devs_initialized_flags;
std::vector<UsageStream> dummy_unifying_free_streams;

// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
std::mutex general_mutex;

/**
 * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During CUDA graph capture, it's illegal to call cudaFreeAsync
 * on a pointer that came from a non-captured cudaMallocAsync.
 * Unfortunately, Python being what it is, it's impossible to be
 * sure no uncaptured tensor will ever have its destructor called
 * in a capturing region.
 * We avoid errors by
 * 1. remembering if allocated pointers were captured or uncaptured
 * 2. during capture, if we detect an attempt to free an uncaptured
 *    allocation on a capturing stream, don't free it immediately,
 *    just remember it and defer its cudaFreeAsync call to after
 *    the end of capture (specifically, to notifyCaptureEnded).
 */
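
// A rough sketch of the deferral path implemented below (using the data
// structures declared next): during capture, freeAsync() on an uncaptured
// pointer just records it,
//   ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
// and the first mallocAsync() call that runs with no capture underway drains
// that list through free_impl() before servicing the new allocation.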

using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

// These two help setMemoryFraction limit the amount of memory
// used by PyTorch in particular (as opposed to other libraries
// in the same process that might be sharing the same cudaMemPool_t).
std::vector<size_t> pytorch_used_bytes;
std::vector<size_t> pytorch_memory_limits;

// Graph-specific helpers

/**
 * Note [Avoid dangling free streams during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During capture, all stream dependencies must branch out from
 * the stream on which capture began and rejoin this initial stream
 * before capture ends.
 * The user rigs desired forking and joining with event waits.
 * But it's hard to be sure when tensor destructors get called relative
 * to the final joins.
 * For example, suppose a user
 *   forks work stream B from initial capture stream A
 *   creates a tensor T in B
 *   joins by syncing A with B
 *   ends capture.
 * All well and good, right? Maybe not: maybe T went out of scope
 * and its destructor got called AFTER the rejoin, leaving the graph with
 * "unjoined work": a dangling cudaFreeAsync node in stream B.
 * Ensuring that all tensor destructors for all side stream tensors
 * are called before side streams rejoin the main stream is
 * difficult. The user might have to add a bunch of explicit
 * "del"s at the right spots in code that was fine for ordinary
 * eager execution.
 * Fortunately, we can spare the user this burden:
 * during capture, we remember _all_ free streams,
 * and manually rejoin them with the capture stream during
 * notifyCaptureAboutToEnd.
 * This approach is heavy-handed, but hopefully capture only needs to
 * happen once, so we don't mind being heavy-handed.
 *
 * TODO: If, someday, we augment the graph bindings to support recapture
 * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
 * (eg, as a way to accommodate dynamic params) we should think more
 * carefully about the CPU overhead of remembering and rejoining
 * all free streams during capture. Maybe it's not a big deal.
 */
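
// The rejoin itself is just the raw-event pattern used in endAllocateToPool()
// below; roughly (illustrative, error checks omitted):
//   cudaEventRecord(event, free_stream.stream);
//   cudaStreamWaitEvent(capture_stream.stream(), event);
// applied to every stream collected in capture_free_streams.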
std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
bool capture_underway = false;

// Implementation functions

// Assumes the caller holds general_mutex
inline void lazy_init_device(c10::DeviceIndex device) {
  if (!devs_initialized_flags[device]) {
    CUDAGuard g(device);

    // See "Retaining memory in the pool" here:
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
    cudaMemPool_t mempool = nullptr;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReleaseThreshold, &threshold));

    // I think all these are on by default, but I want to enable them
    // explicitly to ensure awareness.
    int enable = 1;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));

    // Grabs a stream from the current device to use as the "unifier" free
    // stream for allocations that end up used on multiple streams.
    const auto dufs = getStreamFromPool();
    dummy_unifying_free_streams[device] =
        UsageStream(dufs.stream(), dufs.device_index());

    pytorch_used_bytes[device] = 0;
    pytorch_memory_limits[device] = UINT64_MAX;

    devs_initialized_flags[device] = true;
  }
}

inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
  cudaEvent_t event = nullptr;
  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
  C10_CUDA_CHECK(cudaEventDestroy(event));
}

// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
  // Possible micro-optimization: If we did a value-copy here, we could move
  // ptr_info.erase(it) up here and drop the lock immediately.
  const auto& recorded_streams = it->second.recorded_streams;
  const auto& creation_stream = it->second.creation_stream;

  // If the usage stream is a null (default) stream,
  // cudaFreeAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(creation_stream.device);

  if (recorded_streams.empty()) {
    // ptr was only used on one stream, which must have been
    // the original allocation stream.
    // Frees ptr in the original allocation stream.

    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.insert(creation_stream);
    }
  } else {
    // ptr was used on many streams. We don't know which was the most recent.
    // There could even have been multiple most recent usage streams acting
    // on different regions of the memory.
    // But cudaFreeAsync only accepts a single most recent usage stream.
    // We can still safely free ptr with a trick:
    // Use a dummy "unifying stream", sync the unifying stream with all of
    // ptr's usage streams, and pass the dummy stream to cudaFreeAsync.

    // Retrieves the dummy "unifier" stream from the device
    // on which the pointer was originally allocated.
    auto dummy_unifying_free_stream =
        dummy_unifying_free_streams[creation_stream.device];
    TORCH_INTERNAL_ASSERT(
        dummy_unifying_free_stream.device == creation_stream.device);

    // we're already on creation_stream.device, no need to re-guard
    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);

    // The number of usage streams is typically small (low single digits)
    for (const auto& recorded_stream : recorded_streams) {
      // Logic here accommodates the chance some of the usage streams were on
      // other devices, which is possible if some usage kernels accessed the
      // memory via p2p.

      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g_usage(recorded_stream.device);

      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
    }

    // Frees ptr in the dummy "unifier" stream.
    C10_CUDA_CHECK(cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
    // At this point, unless dummy_unifying_free_stream happens to alias some
    // future user stream, the allocation is only available for "opportunistic"
    // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
    // point that all events recorded on all usage streams have resolved from
    // the CPU's perspective. In theory, we could remove the need for the driver
    // to do this tracking by e.g. replacing
    // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
    // with
    // cudaStreamWaitEvent(creation_stream.stream, event);
    // then cudaFreeAsyncing straight back into creation_stream.stream,
    // but this forces a potentially false dependency of creation_stream.stream
    // on all the recorded_streams.

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.emplace(
          dummy_unifying_free_stream.stream, dummy_unifying_free_stream.device);
    }
  }

  pytorch_used_bytes[creation_stream.device] -= it->second.size;

  ptr_info.erase(it);
}

void freeAsync(void* ptr) {
  std::lock_guard<std::mutex> lk(general_mutex);

  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);
  auto it = ptr_info.find(ptr);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  if (C10_UNLIKELY(capture_underway)) {
    if (!it->second.captured) {
      TORCH_WARN_ONCE(
          "freeAsync() was called on an uncaptured allocation during graph capture "
          "(address = ",
          ptr,
          "). This may be benign, for example, a Python tensor in the capture "
          "might happen to shadow (use the same name as) an unrelated temporary "
          "tensor from somewhere before capture, pushing the earlier tensor "
          "out of scope. "
          "However, if the tensor we're freeing here IS used by the capture, "
          "freeing it is an error, and may cause illegal memory accesses or "
          "memory corruption during graph replay.");
      // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
      // Remembers the raw pointer, not the iterator.
      // This forces notifyCaptureEnded to do another lookup,
      // but avoids the risk the iterator might be invalidated
      // between now and then.
      ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
      return;
    }
  } else if (C10_UNLIKELY(it->second.captured)) {
    TORCH_WARN(
        "Attempting uncaptured free of a captured allocation with address ",
        ptr,
        "\nThis is technically allowed, but may indicate you are losing "
        "the last user-visible tensor through which the allocation can "
        "be accessed, so you'll have no way to view the data after "
        "future replays of the owning graph.");
  }

  free_impl(it);
}

// Symmetric with NativeCachingAllocator::malloc for now,
// although I don't think we absolutely need the symmetry.
void mallocAsync(
    void** devPtr,
    c10::DeviceIndex device,
    size_t size,
    cudaStream_t stream) {
  TORCH_INTERNAL_ASSERT(
      0 <= device && device < device_count,
      "Invalid device index ",
      device,
      ": did you call init?");

  // If stream is a null (default) stream,
  // cudaMallocAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  if (!capture_underway &&
      !ungraphed_ptrs_defer_free_until_no_capture.empty()) {
    // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
    for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
      auto it = ptr_info.find(ptr);
      TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
      free_impl(it);
    }

    ungraphed_ptrs_defer_free_until_no_capture.clear();
  }

  lazy_init_device(device);

  // Defensively checks for preexisting CUDA error state.
  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);

  // TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
  // perhaps by letting lazy_init_device use separate once_flags or an internal
  // static initializer?
  if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
    err = cudaErrorMemoryAllocation;
  } else {
    err = cudaMallocAsync(devPtr, size, stream);
  }

  if (err == cudaErrorMemoryAllocation) {
    // Clears CUDA's internal error state so the user, if desired, can catch the
    // OOM exception, free some stuff on the script side, and retry the
    // allocation. This aligns with the behavior of alloc_block in
    // CUDACachingAllocator.cpp.
    (void)cudaGetLastError(); // clear CUDA error
    size_t device_free = 0;
    size_t device_total = 0;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        false,
        "Allocation on device ",
        device,
        " would exceed allowed memory. (out of memory)",
        "\nCurrently allocated : ",
        format_size(pytorch_used_bytes[device]),
        "\nRequested : ",
        format_size(size),
        "\nDevice limit : ",
        format_size(device_total),
        "\nFree (according to CUDA): ",
        format_size(device_free),
        "\nPyTorch limit (set by user-supplied memory fraction)"
        "\n : ",
        format_size(pytorch_memory_limits[device]));
  } else {
    C10_CUDA_CHECK(err);
  }

  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
  TORCH_INTERNAL_ASSERT(
      inserted.second,
      "address returned by cudaMallocAsync already exists "
      "in ptr_info");

  inserted.first->second.creation_stream = {stream, device};

  pytorch_used_bytes[device] += size;
}

} // anonymous namespace

void local_raw_delete(void* ptr);

// Same pattern as CUDACachingAllocator.cpp.
struct CudaMallocAsyncAllocator : public CUDAAllocator {
  DataPtr allocate(size_t size) override {
    constexpr size_t one_exa_bytes = 1152921504606846976ULL;
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        size < one_exa_bytes,
        "CUDA out of memory. Tried to allocate more than 1EB memory.");
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    void* r = nullptr;
    if (size != 0) {
      mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
    }
    return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
  }
  DeleterFnPtr raw_deleter() const override {
    return &local_raw_delete;
  }

  // This function should not issue any context-creating calls,
  // just set up for later calls to init per-device pools based
  // on the current device each later call sees.
  void init(int dev_count) override {
    static bool called = [](int dev_count) {
      // Are there external guarantees init will be called before
      // any of the allocator's other functions?
      // std::lock_guard<std::mutex> lk(general_mutex);
      device_count = dev_count;
      devs_initialized_flags.resize(dev_count, false);
      dummy_unifying_free_streams.resize(dev_count);
      pytorch_used_bytes.resize(dev_count);
      pytorch_memory_limits.resize(dev_count);
      return true;
    }(dev_count);
    (void)called;
  }

  bool initialized() override {
    return !devs_initialized_flags.empty();
  }

  static inline void assertValidDevice(c10::DeviceIndex device) {
    TORCH_CHECK(
        0 <= device && device < device_count, "Invalid device argument.");
  }

  void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
    TORCH_INTERNAL_ASSERT(
        0 <= fraction && fraction <= 1,
        "invalid fraction:",
        fraction,
        ". Please set within (0, 1).");

    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    // Should setMemoryFraction be allowed to trigger a full device context and
    // pool-creating lazy_init_device, or should we simply assert this device is
    // already initialized, ie
    // TORCH_CHECK(devs_initialized_flags[device], ...)?
    lazy_init_device(device);

    size_t device_free = 0;
    size_t device_total = 0;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    pytorch_memory_limits[device] =
        static_cast<uint64_t>(fraction * static_cast<double>(device_total));
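    // Illustrative example: with fraction = 0.5 on a device whose
    // cudaMemGetInfo reports device_total = 16 GiB, the limit becomes 8 GiB,
    // and mallocAsync treats any request that would push
    // pytorch_used_bytes[device] past 8 GiB as an out-of-memory error.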

    // Alternative: Instead of a manual hard limit, we could use
    // cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
    // &threshold); This is a soft hint: The driver allows the pool's reserved
    // memory to spike above threshold in regions of high cudaMallocAsync
    // demand, but opportunistically trims reserved memory back to threshold
    // when the memory in use is < threshold. I don't like this because it
    // introduces performance nondeterminism.
  }

  void emptyCache() override {
    std::lock_guard<std::mutex> lk(general_mutex);

    for (int dev = 0; dev < device_count; dev++) {
      if (devs_initialized_flags[dev]) {
        CUDAGuard g(static_cast<c10::DeviceIndex>(dev));

        cudaMemPool_t mempool = nullptr;
        cudaDeviceGetDefaultMemPool(&mempool, dev);
        cudaDeviceSynchronize();
        cudaMemPoolTrimTo(mempool, 0);
      }
    }
  }

  void cacheInfo(c10::DeviceIndex device, size_t* maxWorkspaceGuess) override {
    // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
    // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable
    // maximum workspace size to use for an upcoming cudnnFind call.
    //
    // The native allocator's cacheInfo chooses to return the size of its
    // largest unused block (which is the largest allocation the native
    // allocator can service immediately and asynchronously without a
    // cudaMalloc).
    //
    // Here, we use a different heuristic: figure out the max usable workspace
    // size with a bit of educated trial and error. It's ok to be
    // perf-inefficient because cacheInfo is a prelude to cudnnFind.
    //
    // The algo cache then stores the best-performing algo with workspace <=
    // maxWorkspaceGuess. Later calls with the same param set hit in cache and
    // try to allocate the same workspace. If, in one of those future calls,
    // workspace allocation fails (ie because less ambient memory is available),
    // the bindings rerun cudnnFind, including calling cacheInfo again
    // beforehand to estimate a new (smaller) largest-available workspace. Over
    // a few such calls, the cache should settle to the algo with a workspace
    // size that's small enough to succeed every time (for that param set).
    //
    // So the strategy here is to return a rough, largeish guess and let the
    // bindings retry to trim as needed over time.
    //
    // The only caveat is, even if a workspace is allocated without OOM errors
    // now and in future calls, it's hard to be sure those later error-free
    // cudaMallocAsyncs are fast and come straight from the pool (ie,
    // cudaMallocAsync didn't need to reserve more memory from the system).
    // Hopefully, after repeated workspace requests, the pool's reserved memory
    // also stabilizes to a point where they all come straight from the pool.
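    //
    // For instance (illustrative numbers only): if the initial guess is 8 GiB
    // but only ~2.5 GiB can actually be allocated right now, the loop below
    // tries 8 GiB (OOM), 4 GiB (OOM), then succeeds at 2 GiB and reports
    // 2 GiB as maxWorkspaceGuess.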
    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    lazy_init_device(device);

    size_t free_upper_bound = 0;
    size_t device_total = 0;
    C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
    TORCH_INTERNAL_ASSERT(
        free_upper_bound + pytorch_used_bytes[device] <= device_total);
    size_t guess = std::min(
        free_upper_bound,
        pytorch_memory_limits[device] - pytorch_used_bytes[device]);
    auto stream = c10::cuda::getCurrentCUDAStream();
    void* dummy = nullptr;

    // Defensively checks for preexisting CUDA error state.
    auto err = cudaGetLastError();
    C10_CUDA_CHECK(err);

    while (true) {
      // Duplicates some logic from mallocAsync to work with the error state
      // directly instead of repeatedly catching an exception thrown by
      // mallocAsync.
      if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
        err = cudaErrorMemoryAllocation;
      } else {
        err = cudaMallocAsync(&dummy, guess, stream);
      }

      if (err == cudaSuccess) {
        cudaFreeAsync(dummy, stream);
        *maxWorkspaceGuess = guess;
        return;
      } else if (err == cudaErrorMemoryAllocation) {
        (void)cudaGetLastError(); // clear CUDA error
        guess >>= 1; // quick and dirty: try half the size next iteration
      } else {
        C10_CUDA_CHECK(err);
      }
    }
  }

  void* getBaseAllocation(void* ptr, size_t* size) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    auto it = ptr_info.find(ptr);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    if (size) {
      *size = it->second.size;
    }

    return ptr;
  }

  void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
    std::lock_guard<std::mutex> lk(general_mutex);
    auto ptr_val = ptr.get();
    // An empty tensor's storage().data() might be a null ptr. As there are no
    // blocks associated with those tensors, it is fine to do nothing here.
    if (!ptr_val) {
      return;
    }

    // The pointer should exist in the map already.
    auto it = ptr_info.find(ptr_val);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    UsageStream to_record{stream.stream(), stream.device_index()};
    if (to_record == it->second.creation_stream) {
      TORCH_WARN_ONCE(
          "Called record_stream on tensor whose original creation stream "
          "matches the recorded stream. This is unnecessary and has no effect.");
    } else {
      it->second.recorded_streams.insert(to_record);
    }
  }

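  // Typical route into recordStream() above (illustrative, not exhaustive):
  // user code calls Tensor.record_stream(side_stream) on a tensor whose data
  // is about to be consumed on a stream other than its creation stream; that
  // eventually lands here with the tensor's DataPtr, so free_impl() later
  // syncs the side stream before handing the memory back to the pool.
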
  ShareableHandle shareIpcHandle(void* handle) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support shareIpcHandle. "
        "If you need it, please file an issue describing your use case.");
  }

  std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support getIpcDevPtr. "
        "If you need it, please file an issue describing your use case.");
  }

  void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,
      size_t alloc_trace_max_entries,
      RecordContext when) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support recordHistory. "
        "If you need it, please file an issue describing your use case.");
  }

  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
        "If you need it, please file an issue describing your use case.");
  }

  void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support attachAllocatorTraceTracker. "
        "If you need it, please file an issue describing your use case.");
  }

  std::shared_ptr<AllocatorState> getCheckpointState(
      c10::DeviceIndex device,
      MempoolId_t id) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support getCheckpointState. "
        "If you need it, please file an issue describing your use case.");
  }

  CheckpointDelta setCheckpointPoolState(
      c10::DeviceIndex device,
      std::shared_ptr<AllocatorState> pps) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support setCheckpointPoolState. "
        "If you need it, please file an issue describing your use case.");
  }

  // Collects stats for device.
  // If device hasn't been used yet, returns 0s without creating a context.
  DeviceStats getDeviceStats(c10::DeviceIndex device) override {
    assertValidDevice(device);

    // Memory currently reserved by the mempool
    uint64_t reserved_mem_current = 0;
    // High-water mark of memory reserved by the mempool since last reset
    uint64_t reserved_mem_peak = 0;
    // Memory currently in use by the mempool
    uint64_t used_mem_current = 0;
    // High-water mark of memory in use by the mempool since last reset
    uint64_t used_mem_peak = 0;

    std::lock_guard<std::mutex> lk(general_mutex);

    if (devs_initialized_flags[device]) {
      CUDAGuard g(device);

      cudaMemPool_t mempool = nullptr;
      C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
    }

    // Many stat types are specific to the native allocator. We leave these
    // untouched. Their "struct Stat"s will contain zeroed values.
    DeviceStats stats;

    // In the native allocator:
    // allocated_bytes is the total bytes of blocks that have been malloc()ed
    // and not yet free()d.
    // active_bytes is the total bytes of blocks that have been malloc()ed but
    // not yet released back into a free pool. In other words, it includes all
    // allocated_bytes, as well as the bytes of "limbo state" blocks that have
    // already been free()ed but not yet free_block()ed back into a pool due to
    // outstanding stream_uses.
    //
    // Here, in the cudaMallocAsync allocator:
    // We simply ask the driver's opinion about active memory.
    // We don't bother distinguishing between allocated_bytes and active_bytes.
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        static_cast<int64_t>(used_mem_current);
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        static_cast<int64_t>(used_mem_peak);
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        static_cast<int64_t>(used_mem_current);
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        static_cast<int64_t>(used_mem_peak);
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        static_cast<int64_t>(reserved_mem_current);
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        static_cast<int64_t>(reserved_mem_peak);

    return stats;
  }

  void resetAccumulatedStats(c10::DeviceIndex device) override {
    assertValidDevice(device);
    TORCH_WARN_ONCE(
        "For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
  }

  void resetPeakStats(c10::DeviceIndex device) override {
    assertValidDevice(device);

    CUDAGuard g(device);
    cudaMemPool_t mempool = nullptr;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    // Using zero as the reset value is the method recommended by the CUDA
    // driver team. Vivek Kini says:
    //   "Resetting to zero (which is the only valid value when setting
    //    ReservedMemHigh) resets it to ReservedMemCurrent inside the driver
    //    (same goes for UsedMemHigh/UsedMemCurrent)"
    uint64_t zero = 0;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReservedMemHigh, &zero));
    C10_CUDA_CHECK(
        cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
  }

  SnapshotInfo snapshot() override {
    TORCH_CHECK(
        false,
        "Calling snapshot with backend:cudaMallocAsync is not meaningful. "
        "(For backend:native, snapshot returns a detailed summary of all "
        "blocks tracked by the allocator, but the cudaMallocAsync backend "
        "does not track individual blocks.)");
    // Alternative: TORCH_WARN
    return {};
  }

  // CUDAGraph interactions
  void beginAllocateToPool(
      c10::DeviceIndex device,
      MempoolId_t mempool_id,
      std::function<bool(cudaStream_t)>) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_INTERNAL_ASSERT(capture_free_streams.empty());
    TORCH_CHECK(
        !capture_underway,
        "Only one capture at a time is allowed in a process.");
    capture_underway = true;
  }

  void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id)
      override {
    assertValidDevice(device);

    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureAboutToEnd called, "
        "but CudaMallocAsync::capture_underway is false.");

    auto capture_stream = cuda::getCurrentCUDAStream(device);

    // See Note [Avoid dangling free streams during CUDA graph capture]
    for (const auto& free_stream : capture_free_streams) {
      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g(free_stream.device);

      // CUDACachingAllocator.cpp uses raw cuda events, as do we.
      cudaEvent_t event = nullptr;
      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
      C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
      C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
      C10_CUDA_CHECK(cudaEventDestroy(event));
    }

    capture_free_streams.clear();
    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureEnded called, "
        "but CudaMallocAsync::capture_underway is false.");
    capture_underway = false;
  }

  void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
    // Q: Do we need to do anything special here, like clear long-lived
    //    pointers created during the original capture (for example,
    //    tensors intended as the graph's I/O surface) that might still
    //    be resident in ptr_info?
    // A: I don't think so.
    //    Those allocations survived capture because the user held
    //    explicit tensor references to them.
    //    Those tensors' destructors will call freeAsync() on each pointer
    //    when the user is done with them.
    //    The freeAsync()s will probably incur
    //    TORCH_WARN("Attempting uncaptured free of a captured allocation...")
    //    but stale ptrs will not permanently leak into ptr_info.
  }

  void* raw_alloc(size_t nbytes) override {
    if (nbytes == 0) {
      return nullptr;
    }
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
    return r;
  }

  void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
    if (nbytes == 0) {
      return nullptr;
    }
    c10::DeviceIndex device = 0;
    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, stream);
    return r;
  }
  void raw_delete(void* ptr) override {
    freeAsync(ptr);
  }
  void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
      override {
    // Double-checks allocator backend hasn't changed, which would definitely be
    // an error. cudaMallocAsync pools are unaffected by
    // cudaDeviceEnablePeerAccess. We need pool-specific enablement. See
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-2/
    c10::cuda::CUDAGuard device_guard(dev);
    cudaMemPool_t mempool = nullptr;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev_to_access));
    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    // NOLINTNEXTLINE(bugprone-signed-char-misuse)
    desc.location.id = dev;
    desc.flags = cudaMemAccessFlagsProtReadWrite;
    C10_CUDA_CHECK(cudaMemPoolSetAccess(mempool, &desc, 1 /* numDescs */));
  }
  cudaError_t memcpyAsync(
      void* dst,
      int dstDevice,
      const void* src,
      int srcDevice,
      size_t count,
      cudaStream_t stream,
      bool p2p_enabled) override {
    if (p2p_enabled || dstDevice == srcDevice) {
      return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
    } else {
      return cudaMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream);
    }
  }
  std::string name() override {
    return "cudaMallocAsync";
  }
  void copy_data(void* dest, const void* src, std::size_t count) const final {
    C10_CUDA_CHECK(
        cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
  }
};

CudaMallocAsyncAllocator device_allocator;

void local_raw_delete(void* ptr) {
  freeAsync(ptr);
}
CUDAAllocator* allocator() {
  return &device_allocator;
}

#else
CUDAAllocator* allocator() {
  TORCH_CHECK(false, "Cannot use cudaMallocAsyncAllocator with cuda < 11.4.");
  return nullptr;
}

#endif

} // namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync