1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/CachingDeviceAllocator.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAGraphsC10Utils.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAStream.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ApproximateClock.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Registry.h>
10*da0073e9SAndroid Build Coastguard Worker
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
20*da0073e9SAndroid Build Coastguard Worker
namespace c10 {

// The caching allocator executes every registered callback when it is unable
// to find a block inside the already allocated area. Execute() returns a
// bool; presumably `true` means the callback managed to release memory —
// NOTE(review): confirm against the allocator implementation.
class C10_CUDA_API FreeMemoryCallback {
 public:
  virtual ~FreeMemoryCallback() = default;
  virtual bool Execute() = 0;
};

// Registry holding the FreeMemoryCallback implementations described above.
C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
// Registers a FreeMemoryCallback subclass under `name` in
// FreeCudaMemoryCallbacksRegistry. Note that the expansion already ends in a
// semicolon.
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
  C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
} // namespace c10
35*da0073e9SAndroid Build Coastguard Worker //
// TODO: Turn this into an honest-to-goodness class. I briefly attempted to do
// this, but it was a bit irritating to figure out how to also correctly
// apply the pimpl pattern so I didn't have to leak any internal implementation
// details in the header (CUDACachingAllocator could be made a pimpl, but
// you also need to appropriately define a class which is a subclass
// of Allocator. Not impossible, but it required a bit more surgery than
// I wanted to do at the time.)
43*da0073e9SAndroid Build Coastguard Worker //
44*da0073e9SAndroid Build Coastguard Worker // Why is this using a namespace rather than old-style THCCachingAllocator_
45*da0073e9SAndroid Build Coastguard Worker // prefix? Mostly because it made the HIPify rules easier to write; _ is
46*da0073e9SAndroid Build Coastguard Worker // not counted as a word boundary, so you would otherwise have to list each
47*da0073e9SAndroid Build Coastguard Worker // of these functions.
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::CUDACachingAllocator {
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker // Preserved only for BC reasons
52*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(misc-unused-using-decls)
53*da0073e9SAndroid Build Coastguard Worker using c10::CachingDeviceAllocator::DeviceStats;
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker extern const size_t kLargeBuffer;
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker // Struct containing info of an allocation block (i.e. a fractional part of a
60*da0073e9SAndroid Build Coastguard Worker // cudaMalloc)..
61*da0073e9SAndroid Build Coastguard Worker struct BlockInfo {
62*da0073e9SAndroid Build Coastguard Worker size_t size = 0;
63*da0073e9SAndroid Build Coastguard Worker size_t requested_size = 0;
64*da0073e9SAndroid Build Coastguard Worker int32_t gc_counter = 0;
65*da0073e9SAndroid Build Coastguard Worker bool allocated = false;
66*da0073e9SAndroid Build Coastguard Worker bool active = false;
67*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext>
68*da0073e9SAndroid Build Coastguard Worker context_when_allocated; // per-watcher context
69*da0073e9SAndroid Build Coastguard Worker };
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker // Struct containing info of a memory segment (i.e. one contiguous cudaMalloc).
72*da0073e9SAndroid Build Coastguard Worker struct SegmentInfo {
73*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device = 0;
74*da0073e9SAndroid Build Coastguard Worker size_t address = 0;
75*da0073e9SAndroid Build Coastguard Worker size_t total_size = 0;
76*da0073e9SAndroid Build Coastguard Worker size_t requested_size = 0; // unrounded, actually requested size
77*da0073e9SAndroid Build Coastguard Worker size_t allocated_size = 0;
78*da0073e9SAndroid Build Coastguard Worker size_t active_size = 0;
79*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream = nullptr;
80*da0073e9SAndroid Build Coastguard Worker bool is_large = false;
81*da0073e9SAndroid Build Coastguard Worker bool is_expandable = false;
82*da0073e9SAndroid Build Coastguard Worker MempoolId_t owner_private_pool_id = {0, 0};
83*da0073e9SAndroid Build Coastguard Worker std::vector<BlockInfo> blocks;
84*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context_when_allocated;
85*da0073e9SAndroid Build Coastguard Worker };
86*da0073e9SAndroid Build Coastguard Worker
// Opaque, backend-defined checkpoint state for a memory pool. It is returned by
// CUDAAllocator::getCheckpointState() and handed back to
// setCheckpointPoolState(). The virtual destructor lets backends derive their
// own state type and have it destroyed through a base pointer.
struct AllocatorState {
  virtual ~AllocatorState() = default;
};
90*da0073e9SAndroid Build Coastguard Worker
// Timestamp storage shared by TraceEntry and AnnotationEntry. Their
// constructors store an approximate-clock value in `approx_t_`. The `t_`
// member presumably holds the converted wall-clock time later (e.g. when a
// snapshot is produced) — NOTE(review): confirm. `t_` is the first member, so
// `trace_time_ x{}` value-initializes `t_`.
union trace_time_ {
  time_t t_;
  approx_time_t approx_t_;
};
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker struct TraceEntry {
97*da0073e9SAndroid Build Coastguard Worker enum Action {
98*da0073e9SAndroid Build Coastguard Worker ALLOC, // API made to the caching allocator for new memory
99*da0073e9SAndroid Build Coastguard Worker FREE_REQUESTED, // API call made to the caching allocator to free memory
100*da0073e9SAndroid Build Coastguard Worker FREE_COMPLETED, // The allocator might have to delay a free because
101*da0073e9SAndroid Build Coastguard Worker // it is still in use on another stream via record_stream
102*da0073e9SAndroid Build Coastguard Worker // This event is generated when a free actually completes.
103*da0073e9SAndroid Build Coastguard Worker SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS
104*da0073e9SAndroid Build Coastguard Worker SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to
105*da0073e9SAndroid Build Coastguard Worker // defragment or empty_caches)
106*da0073e9SAndroid Build Coastguard Worker SEGMENT_MAP, // a call to cuMemMap (used with expandable_segments)
107*da0073e9SAndroid Build Coastguard Worker SEGMENT_UNMAP, // unmap part of a segment (used with expandable segments)
108*da0073e9SAndroid Build Coastguard Worker SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
109*da0073e9SAndroid Build Coastguard Worker // events
110*da0073e9SAndroid Build Coastguard Worker OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
111*da0073e9SAndroid Build Coastguard Worker // bytes reported by cuda)
112*da0073e9SAndroid Build Coastguard Worker };
113*da0073e9SAndroid Build Coastguard Worker TraceEntry(
114*da0073e9SAndroid Build Coastguard Worker Action action,
115*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
116*da0073e9SAndroid Build Coastguard Worker size_t addr,
117*da0073e9SAndroid Build Coastguard Worker size_t size,
118*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
119*da0073e9SAndroid Build Coastguard Worker approx_time_t time,
120*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context = nullptr)
action_TraceEntry121*da0073e9SAndroid Build Coastguard Worker : action_(action),
122*da0073e9SAndroid Build Coastguard Worker device_(device),
123*da0073e9SAndroid Build Coastguard Worker addr_(addr),
124*da0073e9SAndroid Build Coastguard Worker context_(std::move(context)),
125*da0073e9SAndroid Build Coastguard Worker stream_(stream),
126*da0073e9SAndroid Build Coastguard Worker size_(size) {
127*da0073e9SAndroid Build Coastguard Worker time_.approx_t_ = time;
128*da0073e9SAndroid Build Coastguard Worker }
129*da0073e9SAndroid Build Coastguard Worker Action action_;
130*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device_;
131*da0073e9SAndroid Build Coastguard Worker size_t addr_; // for OOM, this is the amount of free bytes reported by cuda
132*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<GatheredContext> context_;
133*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream_{};
134*da0073e9SAndroid Build Coastguard Worker size_t size_;
135*da0073e9SAndroid Build Coastguard Worker trace_time_ time_{};
136*da0073e9SAndroid Build Coastguard Worker };
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker // Calls made by record_function will save annotations
139*da0073e9SAndroid Build Coastguard Worker struct AnnotationEntry {
AnnotationEntryAnnotationEntry140*da0073e9SAndroid Build Coastguard Worker AnnotationEntry(c10::DeviceIndex device, approx_time_t time)
141*da0073e9SAndroid Build Coastguard Worker : device_(device) {
142*da0073e9SAndroid Build Coastguard Worker time_.approx_t_ = time;
143*da0073e9SAndroid Build Coastguard Worker }
144*da0073e9SAndroid Build Coastguard Worker
recordUserMetadataAnnotationEntry145*da0073e9SAndroid Build Coastguard Worker void recordUserMetadata(const std::string& name, std::string value) {
146*da0073e9SAndroid Build Coastguard Worker metadata_[name] = std::move(value);
147*da0073e9SAndroid Build Coastguard Worker }
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device_;
150*da0073e9SAndroid Build Coastguard Worker trace_time_ time_{};
151*da0073e9SAndroid Build Coastguard Worker std::unordered_map<std::string, std::string> metadata_;
152*da0073e9SAndroid Build Coastguard Worker };
153*da0073e9SAndroid Build Coastguard Worker
// Allocator configuration reported alongside a memory snapshot
// (SnapshotInfo::config_metadata). The fields are presumably filled in by the
// backend's snapshot() — NOTE(review): confirm.
//
// Every scalar member has a default member initializer. Without it, a
// default-initialized instance (`AllocatorConfigInfo info;`, or a
// default-constructed SnapshotInfo) would hold indeterminate values, and
// reading those is undefined behavior. The chosen defaults are exactly what
// value-initialization (`AllocatorConfigInfo{}`) already produced. The struct
// remains an aggregate, so brace initialization keeps working.
struct AllocatorConfigInfo {
  double garbage_collection_threshold = 0.0;
  size_t max_split_size = 0;
  size_t pinned_num_register_threads = 0;
  bool expandable_segments = false;
  bool release_lock_on_malloc = false;
  bool pinned_use_host_register = false;
  std::string last_allocator_settings;
  std::vector<size_t> roundup_power2_divisions;
};
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker struct SnapshotInfo {
166*da0073e9SAndroid Build Coastguard Worker std::vector<SegmentInfo> segments;
167*da0073e9SAndroid Build Coastguard Worker std::vector<std::vector<TraceEntry>> device_traces;
168*da0073e9SAndroid Build Coastguard Worker std::vector<AnnotationEntry> external_annotations;
169*da0073e9SAndroid Build Coastguard Worker AllocatorConfigInfo config_metadata;
170*da0073e9SAndroid Build Coastguard Worker };
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker // returns the pointers freed in the pool
173*da0073e9SAndroid Build Coastguard Worker // and the pointers allocated. Note: a pointer
174*da0073e9SAndroid Build Coastguard Worker // may appear in both freed and allocated
175*da0073e9SAndroid Build Coastguard Worker struct CheckpointDelta {
176*da0073e9SAndroid Build Coastguard Worker std::vector<void*> ptrs_freed;
177*da0073e9SAndroid Build Coastguard Worker std::vector<at::DataPtr> dataptrs_allocd;
178*da0073e9SAndroid Build Coastguard Worker };
179*da0073e9SAndroid Build Coastguard Worker
// How much stack context recordHistory() keeps. Each level records everything
// the previous level does, plus more.
enum class RecordContext {
  NEVER = 0, // never record stacks
  STATE = 1, // only keep stacks for active allocations
  ALLOC = 2, // additionally keep stacks for allocations in the trace history
  ALL = 3, // additionally record stacks for when something is freed
};
186*da0073e9SAndroid Build Coastguard Worker
// Callback registered through attachOutOfMemoryObserver(). Presumably it is
// invoked when an allocation fails with OOM; its arguments are the device
// index, the bytes allocated, and the device's total and free memory.
using OutOfMemoryObserver = std::function<void(
    int64_t /*device*/,
    size_t /*allocated*/,
    size_t /*device_total*/,
    size_t /*device_free*/)>;
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker using AllocatorTraceTracker = std::function<void(const TraceEntry&)>;
194*da0073e9SAndroid Build Coastguard Worker
// Result of CUDAAllocator::shareIpcHandle(): a serialized handle string plus
// an offset (presumably the byte offset of the shared pointer within the
// underlying allocation — confirm).
struct ShareableHandle {
  // Defaulted to 0 so a default-constructed handle never carries an
  // indeterminate offset. The struct stays an aggregate.
  ptrdiff_t offset = 0;
  std::string handle;
};
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker class CUDAAllocator : public Allocator {
201*da0073e9SAndroid Build Coastguard Worker public:
202*da0073e9SAndroid Build Coastguard Worker virtual void* raw_alloc(size_t nbytes) = 0;
203*da0073e9SAndroid Build Coastguard Worker virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0;
204*da0073e9SAndroid Build Coastguard Worker virtual void raw_delete(void* ptr) = 0;
205*da0073e9SAndroid Build Coastguard Worker virtual void init(int device_count) = 0;
206*da0073e9SAndroid Build Coastguard Worker virtual bool initialized() = 0;
207*da0073e9SAndroid Build Coastguard Worker virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
208*da0073e9SAndroid Build Coastguard Worker virtual void emptyCache() = 0;
209*da0073e9SAndroid Build Coastguard Worker virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
210*da0073e9SAndroid Build Coastguard Worker virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
211*da0073e9SAndroid Build Coastguard Worker virtual void recordStream(const DataPtr&, CUDAStream stream) = 0;
212*da0073e9SAndroid Build Coastguard Worker virtual c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
213*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device) = 0;
214*da0073e9SAndroid Build Coastguard Worker virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
215*da0073e9SAndroid Build Coastguard Worker virtual void resetPeakStats(c10::DeviceIndex device) = 0;
216*da0073e9SAndroid Build Coastguard Worker virtual SnapshotInfo snapshot() = 0;
217*da0073e9SAndroid Build Coastguard Worker virtual void beginAllocateToPool(
218*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
219*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
220*da0073e9SAndroid Build Coastguard Worker std::function<bool(cudaStream_t)> filter) = 0;
221*da0073e9SAndroid Build Coastguard Worker virtual void endAllocateToPool(
222*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
223*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id) = 0;
224*da0073e9SAndroid Build Coastguard Worker virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0;
225*da0073e9SAndroid Build Coastguard Worker // returns true if the allocated blocks are equal to expected live allocations
checkPoolLiveAllocations(c10::DeviceIndex device,MempoolId_t mempool_id,const std::unordered_set<void * > & expected_live_allocations)226*da0073e9SAndroid Build Coastguard Worker virtual bool checkPoolLiveAllocations(
227*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
228*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
229*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<void*>& expected_live_allocations) {
230*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
231*da0073e9SAndroid Build Coastguard Worker false,
232*da0073e9SAndroid Build Coastguard Worker name(),
233*da0073e9SAndroid Build Coastguard Worker " does not yet support checkPoolLiveAllocations. "
234*da0073e9SAndroid Build Coastguard Worker "If you need it, please file an issue describing your use case.");
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker virtual ShareableHandle shareIpcHandle(void* ptr) = 0;
237*da0073e9SAndroid Build Coastguard Worker virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) = 0;
isHistoryEnabled()238*da0073e9SAndroid Build Coastguard Worker virtual bool isHistoryEnabled() {
239*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(
240*da0073e9SAndroid Build Coastguard Worker false,
241*da0073e9SAndroid Build Coastguard Worker name(),
242*da0073e9SAndroid Build Coastguard Worker " does not yet support recordHistory. "
243*da0073e9SAndroid Build Coastguard Worker "If you need it, please file an issue describing your use case.");
244*da0073e9SAndroid Build Coastguard Worker }
245*da0073e9SAndroid Build Coastguard Worker virtual void recordHistory(
246*da0073e9SAndroid Build Coastguard Worker bool enabled,
247*da0073e9SAndroid Build Coastguard Worker CreateContextFn context_recorder,
248*da0073e9SAndroid Build Coastguard Worker size_t alloc_trace_max_entries,
249*da0073e9SAndroid Build Coastguard Worker RecordContext when) = 0;
recordAnnotation(const std::vector<std::pair<std::string,std::string>> & md)250*da0073e9SAndroid Build Coastguard Worker virtual void recordAnnotation(
251*da0073e9SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, std::string>>& md){};
252*da0073e9SAndroid Build Coastguard Worker virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker // Attached AllocatorTraceTracker callbacks will be called while the
255*da0073e9SAndroid Build Coastguard Worker // per-device allocator lock is held. Any additional locks taken from within
256*da0073e9SAndroid Build Coastguard Worker // the callback must be proven to always have the lock order that never
257*da0073e9SAndroid Build Coastguard Worker // triggers a deadlock. In particular, Python's GIL may be held when
258*da0073e9SAndroid Build Coastguard Worker // calling the allocator so it is unsafe to try to acquire the GIL in this
259*da0073e9SAndroid Build Coastguard Worker // callback.
260*da0073e9SAndroid Build Coastguard Worker virtual void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) = 0;
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker virtual void enablePeerAccess(
263*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex dev,
264*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex dev_to_access) = 0;
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker // memory not allocated from cudaMalloc cannot be copied
267*da0073e9SAndroid Build Coastguard Worker // across devices using cudaMemcpyAsync if peer to peer access is disabled.
268*da0073e9SAndroid Build Coastguard Worker // instead it requires cudaMemcpyAsyncPeer
269*da0073e9SAndroid Build Coastguard Worker // with P2P Enabled, all combinations work
270*da0073e9SAndroid Build Coastguard Worker // with P2P Disabled:
271*da0073e9SAndroid Build Coastguard Worker // cudaMalloc cudaMallocAsync/cuMemMap
272*da0073e9SAndroid Build Coastguard Worker // cudaMemcpyAsyncPeer works works
273*da0073e9SAndroid Build Coastguard Worker // cudaMemcpyAsync works error
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker // This function performs chooses to use the Peer version of
276*da0073e9SAndroid Build Coastguard Worker // memcpy if required based on where the allocated put dst/src.
277*da0073e9SAndroid Build Coastguard Worker virtual cudaError_t memcpyAsync(
278*da0073e9SAndroid Build Coastguard Worker void* dst,
279*da0073e9SAndroid Build Coastguard Worker int dstDevice,
280*da0073e9SAndroid Build Coastguard Worker const void* src,
281*da0073e9SAndroid Build Coastguard Worker int srcDevice,
282*da0073e9SAndroid Build Coastguard Worker size_t count,
283*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
284*da0073e9SAndroid Build Coastguard Worker bool p2p_enabled) = 0;
285*da0073e9SAndroid Build Coastguard Worker virtual std::shared_ptr<AllocatorState> getCheckpointState(
286*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
287*da0073e9SAndroid Build Coastguard Worker MempoolId_t id) = 0;
288*da0073e9SAndroid Build Coastguard Worker virtual CheckpointDelta setCheckpointPoolState(
289*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
290*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<AllocatorState> pps) = 0;
291*da0073e9SAndroid Build Coastguard Worker virtual std::string name() = 0;
292*da0073e9SAndroid Build Coastguard Worker };
293*da0073e9SAndroid Build Coastguard Worker
// Allocator object, statically initialized.
// See BackendInitializer in CUDACachingAllocator.cpp.
// Atomic loads on x86 are just normal loads (atomic stores are different),
// so reading this value is no different than loading a pointer.
// Read through get() below.
C10_CUDA_API extern std::atomic<CUDAAllocator*> allocator;
300*da0073e9SAndroid Build Coastguard Worker
get()301*da0073e9SAndroid Build Coastguard Worker inline CUDAAllocator* get() {
302*da0073e9SAndroid Build Coastguard Worker return allocator.load();
303*da0073e9SAndroid Build Coastguard Worker }
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker // Called directly by clients.
raw_alloc(size_t nbytes)306*da0073e9SAndroid Build Coastguard Worker inline void* raw_alloc(size_t nbytes) {
307*da0073e9SAndroid Build Coastguard Worker return get()->raw_alloc(nbytes);
308*da0073e9SAndroid Build Coastguard Worker }
309*da0073e9SAndroid Build Coastguard Worker
raw_alloc_with_stream(size_t nbytes,cudaStream_t stream)310*da0073e9SAndroid Build Coastguard Worker inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
311*da0073e9SAndroid Build Coastguard Worker return get()->raw_alloc_with_stream(nbytes, stream);
312*da0073e9SAndroid Build Coastguard Worker }
313*da0073e9SAndroid Build Coastguard Worker
raw_delete(void * ptr)314*da0073e9SAndroid Build Coastguard Worker inline void raw_delete(void* ptr) {
315*da0073e9SAndroid Build Coastguard Worker return get()->raw_delete(ptr);
316*da0073e9SAndroid Build Coastguard Worker }
317*da0073e9SAndroid Build Coastguard Worker
init(int device_count)318*da0073e9SAndroid Build Coastguard Worker inline void init(int device_count) {
319*da0073e9SAndroid Build Coastguard Worker return get()->init(device_count);
320*da0073e9SAndroid Build Coastguard Worker }
321*da0073e9SAndroid Build Coastguard Worker
setMemoryFraction(double fraction,c10::DeviceIndex device)322*da0073e9SAndroid Build Coastguard Worker inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
323*da0073e9SAndroid Build Coastguard Worker return get()->setMemoryFraction(fraction, device);
324*da0073e9SAndroid Build Coastguard Worker }
325*da0073e9SAndroid Build Coastguard Worker
emptyCache()326*da0073e9SAndroid Build Coastguard Worker inline void emptyCache() {
327*da0073e9SAndroid Build Coastguard Worker return get()->emptyCache();
328*da0073e9SAndroid Build Coastguard Worker }
329*da0073e9SAndroid Build Coastguard Worker
cacheInfo(c10::DeviceIndex device,size_t * largestBlock)330*da0073e9SAndroid Build Coastguard Worker inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
331*da0073e9SAndroid Build Coastguard Worker return get()->cacheInfo(device, largestBlock);
332*da0073e9SAndroid Build Coastguard Worker }
333*da0073e9SAndroid Build Coastguard Worker
getBaseAllocation(void * ptr,size_t * size)334*da0073e9SAndroid Build Coastguard Worker inline void* getBaseAllocation(void* ptr, size_t* size) {
335*da0073e9SAndroid Build Coastguard Worker return get()->getBaseAllocation(ptr, size);
336*da0073e9SAndroid Build Coastguard Worker }
337*da0073e9SAndroid Build Coastguard Worker
recordStream(const DataPtr & dataPtr,CUDAStream stream)338*da0073e9SAndroid Build Coastguard Worker inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) {
339*da0073e9SAndroid Build Coastguard Worker return get()->recordStream(dataPtr, stream);
340*da0073e9SAndroid Build Coastguard Worker }
341*da0073e9SAndroid Build Coastguard Worker
getDeviceStats(c10::DeviceIndex device)342*da0073e9SAndroid Build Coastguard Worker inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
343*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device) {
344*da0073e9SAndroid Build Coastguard Worker return get()->getDeviceStats(device);
345*da0073e9SAndroid Build Coastguard Worker }
346*da0073e9SAndroid Build Coastguard Worker
resetAccumulatedStats(c10::DeviceIndex device)347*da0073e9SAndroid Build Coastguard Worker inline void resetAccumulatedStats(c10::DeviceIndex device) {
348*da0073e9SAndroid Build Coastguard Worker return get()->resetAccumulatedStats(device);
349*da0073e9SAndroid Build Coastguard Worker }
350*da0073e9SAndroid Build Coastguard Worker
resetPeakStats(c10::DeviceIndex device)351*da0073e9SAndroid Build Coastguard Worker inline void resetPeakStats(c10::DeviceIndex device) {
352*da0073e9SAndroid Build Coastguard Worker return get()->resetPeakStats(device);
353*da0073e9SAndroid Build Coastguard Worker }
354*da0073e9SAndroid Build Coastguard Worker
snapshot()355*da0073e9SAndroid Build Coastguard Worker inline SnapshotInfo snapshot() {
356*da0073e9SAndroid Build Coastguard Worker return get()->snapshot();
357*da0073e9SAndroid Build Coastguard Worker }
358*da0073e9SAndroid Build Coastguard Worker
getCheckpointState(c10::DeviceIndex device,MempoolId_t id)359*da0073e9SAndroid Build Coastguard Worker inline std::shared_ptr<AllocatorState> getCheckpointState(
360*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
361*da0073e9SAndroid Build Coastguard Worker MempoolId_t id) {
362*da0073e9SAndroid Build Coastguard Worker return get()->getCheckpointState(device, id);
363*da0073e9SAndroid Build Coastguard Worker }
364*da0073e9SAndroid Build Coastguard Worker
setCheckpointPoolState(c10::DeviceIndex device,std::shared_ptr<AllocatorState> pps)365*da0073e9SAndroid Build Coastguard Worker inline CheckpointDelta setCheckpointPoolState(
366*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
367*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<AllocatorState> pps) {
368*da0073e9SAndroid Build Coastguard Worker return get()->setCheckpointPoolState(device, std::move(pps));
369*da0073e9SAndroid Build Coastguard Worker }
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker // CUDAGraph interactions
beginAllocateToPool(c10::DeviceIndex device,MempoolId_t mempool_id,std::function<bool (cudaStream_t)> filter)372*da0073e9SAndroid Build Coastguard Worker inline void beginAllocateToPool(
373*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
374*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
375*da0073e9SAndroid Build Coastguard Worker std::function<bool(cudaStream_t)> filter) {
376*da0073e9SAndroid Build Coastguard Worker get()->beginAllocateToPool(device, mempool_id, std::move(filter));
377*da0073e9SAndroid Build Coastguard Worker }
378*da0073e9SAndroid Build Coastguard Worker
endAllocateToPool(c10::DeviceIndex device,MempoolId_t mempool_id)379*da0073e9SAndroid Build Coastguard Worker inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
380*da0073e9SAndroid Build Coastguard Worker get()->endAllocateToPool(device, mempool_id);
381*da0073e9SAndroid Build Coastguard Worker }
382*da0073e9SAndroid Build Coastguard Worker
recordHistory(bool enabled,CreateContextFn context_recorder,size_t alloc_trace_max_entries,RecordContext when)383*da0073e9SAndroid Build Coastguard Worker inline void recordHistory(
384*da0073e9SAndroid Build Coastguard Worker bool enabled,
385*da0073e9SAndroid Build Coastguard Worker CreateContextFn context_recorder,
386*da0073e9SAndroid Build Coastguard Worker size_t alloc_trace_max_entries,
387*da0073e9SAndroid Build Coastguard Worker RecordContext when) {
388*da0073e9SAndroid Build Coastguard Worker return get()->recordHistory(
389*da0073e9SAndroid Build Coastguard Worker enabled, context_recorder, alloc_trace_max_entries, when);
390*da0073e9SAndroid Build Coastguard Worker }
391*da0073e9SAndroid Build Coastguard Worker
recordAnnotation(const std::vector<std::pair<std::string,std::string>> & md)392*da0073e9SAndroid Build Coastguard Worker inline void recordAnnotation(
393*da0073e9SAndroid Build Coastguard Worker const std::vector<std::pair<std::string, std::string>>& md) {
394*da0073e9SAndroid Build Coastguard Worker return get()->recordAnnotation(md);
395*da0073e9SAndroid Build Coastguard Worker }
396*da0073e9SAndroid Build Coastguard Worker
isHistoryEnabled()397*da0073e9SAndroid Build Coastguard Worker inline bool isHistoryEnabled() {
398*da0073e9SAndroid Build Coastguard Worker return get()->isHistoryEnabled();
399*da0073e9SAndroid Build Coastguard Worker }
400*da0073e9SAndroid Build Coastguard Worker
checkPoolLiveAllocations(c10::DeviceIndex device,MempoolId_t mempool_id,const std::unordered_set<void * > & expected_live_allocations)401*da0073e9SAndroid Build Coastguard Worker inline bool checkPoolLiveAllocations(
402*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex device,
403*da0073e9SAndroid Build Coastguard Worker MempoolId_t mempool_id,
404*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<void*>& expected_live_allocations) {
405*da0073e9SAndroid Build Coastguard Worker return get()->checkPoolLiveAllocations(
406*da0073e9SAndroid Build Coastguard Worker device, mempool_id, expected_live_allocations);
407*da0073e9SAndroid Build Coastguard Worker }
408*da0073e9SAndroid Build Coastguard Worker
attachOutOfMemoryObserver(OutOfMemoryObserver observer)409*da0073e9SAndroid Build Coastguard Worker inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
410*da0073e9SAndroid Build Coastguard Worker return get()->attachOutOfMemoryObserver(std::move(observer));
411*da0073e9SAndroid Build Coastguard Worker }
412*da0073e9SAndroid Build Coastguard Worker
attachAllocatorTraceTracker(AllocatorTraceTracker tracker)413*da0073e9SAndroid Build Coastguard Worker inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
414*da0073e9SAndroid Build Coastguard Worker return get()->attachAllocatorTraceTracker(std::move(tracker));
415*da0073e9SAndroid Build Coastguard Worker }
416*da0073e9SAndroid Build Coastguard Worker
releasePool(c10::DeviceIndex device,MempoolId_t mempool_id)417*da0073e9SAndroid Build Coastguard Worker inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
418*da0073e9SAndroid Build Coastguard Worker return get()->releasePool(device, mempool_id);
419*da0073e9SAndroid Build Coastguard Worker }
420*da0073e9SAndroid Build Coastguard Worker // Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
getIpcDevPtr(std::string handle)421*da0073e9SAndroid Build Coastguard Worker inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
422*da0073e9SAndroid Build Coastguard Worker return get()->getIpcDevPtr(std::move(handle));
423*da0073e9SAndroid Build Coastguard Worker }
424*da0073e9SAndroid Build Coastguard Worker
shareIpcHandle(void * ptr)425*da0073e9SAndroid Build Coastguard Worker inline ShareableHandle shareIpcHandle(void* ptr) {
426*da0073e9SAndroid Build Coastguard Worker return get()->shareIpcHandle(ptr);
427*da0073e9SAndroid Build Coastguard Worker }
428*da0073e9SAndroid Build Coastguard Worker
name()429*da0073e9SAndroid Build Coastguard Worker inline std::string name() {
430*da0073e9SAndroid Build Coastguard Worker return get()->name();
431*da0073e9SAndroid Build Coastguard Worker }
432*da0073e9SAndroid Build Coastguard Worker
memcpyAsync(void * dst,int dstDevice,const void * src,int srcDevice,size_t count,cudaStream_t stream,bool p2p_enabled)433*da0073e9SAndroid Build Coastguard Worker inline cudaError_t memcpyAsync(
434*da0073e9SAndroid Build Coastguard Worker void* dst,
435*da0073e9SAndroid Build Coastguard Worker int dstDevice,
436*da0073e9SAndroid Build Coastguard Worker const void* src,
437*da0073e9SAndroid Build Coastguard Worker int srcDevice,
438*da0073e9SAndroid Build Coastguard Worker size_t count,
439*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream,
440*da0073e9SAndroid Build Coastguard Worker bool p2p_enabled) {
441*da0073e9SAndroid Build Coastguard Worker return get()->memcpyAsync(
442*da0073e9SAndroid Build Coastguard Worker dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
443*da0073e9SAndroid Build Coastguard Worker }
444*da0073e9SAndroid Build Coastguard Worker
enablePeerAccess(c10::DeviceIndex dev,c10::DeviceIndex dev_to_access)445*da0073e9SAndroid Build Coastguard Worker inline void enablePeerAccess(
446*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex dev,
447*da0073e9SAndroid Build Coastguard Worker c10::DeviceIndex dev_to_access) {
448*da0073e9SAndroid Build Coastguard Worker return get()->enablePeerAccess(dev, dev_to_access);
449*da0073e9SAndroid Build Coastguard Worker }
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::CUDACachingAllocator
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker // MemPool represents a pool of memory in a caching allocator. Currently,
456*da0073e9SAndroid Build Coastguard Worker // it's just the ID of the pool object maintained in the CUDACachingAllocator.
457*da0073e9SAndroid Build Coastguard Worker //
458*da0073e9SAndroid Build Coastguard Worker // An allocator pointer can be passed to the MemPool to define how the
459*da0073e9SAndroid Build Coastguard Worker // allocations should be done in the pool. For example: using a different
460*da0073e9SAndroid Build Coastguard Worker // system allocator such as ncclMemAlloc.
461*da0073e9SAndroid Build Coastguard Worker struct C10_CUDA_API MemPool {
462*da0073e9SAndroid Build Coastguard Worker MemPool(
463*da0073e9SAndroid Build Coastguard Worker CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
464*da0073e9SAndroid Build Coastguard Worker bool is_user_created = true);
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker MempoolId_t id();
467*da0073e9SAndroid Build Coastguard Worker CUDACachingAllocator::CUDAAllocator* allocator();
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker private:
470*da0073e9SAndroid Build Coastguard Worker static std::atomic<CaptureId_t> uid_;
471*da0073e9SAndroid Build Coastguard Worker static std::atomic<CaptureId_t> uuid_;
472*da0073e9SAndroid Build Coastguard Worker CUDACachingAllocator::CUDAAllocator* allocator_;
473*da0073e9SAndroid Build Coastguard Worker bool is_user_created_;
474*da0073e9SAndroid Build Coastguard Worker MempoolId_t id_;
475*da0073e9SAndroid Build Coastguard Worker };
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker // MemPoolContext holds the currently active pool and stashes the previous
478*da0073e9SAndroid Build Coastguard Worker // pool. On deletion it makes the previous pool active.
479*da0073e9SAndroid Build Coastguard Worker struct C10_CUDA_API MemPoolContext {
480*da0073e9SAndroid Build Coastguard Worker MemPoolContext(MemPool* mempool);
481*da0073e9SAndroid Build Coastguard Worker
482*da0073e9SAndroid Build Coastguard Worker ~MemPoolContext();
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker // getActiveMemPool() can be used to get the currently active pool.
485*da0073e9SAndroid Build Coastguard Worker // For instance: in CUDACachingAllocator, we can route allocations
486*da0073e9SAndroid Build Coastguard Worker // to a user provided allocator, by doing:
487*da0073e9SAndroid Build Coastguard Worker //
488*da0073e9SAndroid Build Coastguard Worker // auto active_pool = MemPoolContext::getActiveMemPool();
489*da0073e9SAndroid Build Coastguard Worker // if (active_pool && active_pool->allocator()) {
490*da0073e9SAndroid Build Coastguard Worker // ptr = active_pool->allocator()->raw_alloc(size);
491*da0073e9SAndroid Build Coastguard Worker // }
492*da0073e9SAndroid Build Coastguard Worker //
493*da0073e9SAndroid Build Coastguard Worker static MemPool* getActiveMemPool();
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker private:
496*da0073e9SAndroid Build Coastguard Worker MemPool* prev_mempool_;
497*da0073e9SAndroid Build Coastguard Worker };
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
500