xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUCachingAllocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/flat_hash_map.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUCachingAllocator.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <deque>
6*da0073e9SAndroid Build Coastguard Worker #include <mutex>
7*da0073e9SAndroid Build Coastguard Worker #include <set>
8*da0073e9SAndroid Build Coastguard Worker #include <vector>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu::XPUCachingAllocator {
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker using namespace c10::CachingDeviceAllocator;
13*da0073e9SAndroid Build Coastguard Worker 
// Alignment (in bytes) requested from sycl::aligned_alloc_device for every
// newly allocated device buffer.
constexpr size_t kDeviceAlignment = 512;
// Granule that every request size is rounded up to.
constexpr size_t kMinBlockSize = 512;
// Largest request size that counts as a "small" allocation (1 MiB).
constexpr size_t kSmallSize = 1024 * 1024;
// Device buffer size used to serve small requests (2 MiB).
constexpr size_t kSmallBuffer = 2 * 1024 * 1024;
// Device buffer size used for requests above kSmallSize but below
// kMinLargeAlloc (20 MiB).
constexpr size_t kLargeBuffer = 20 * 1024 * 1024;
// Requests of at least this size (10 MiB) get their own rounded buffer instead
// of a kLargeBuffer.
constexpr size_t kMinLargeAlloc = 10 * 1024 * 1024;
// Rounding granule for requests of kMinLargeAlloc or more (2 MiB).
constexpr size_t kRoundLarge = 2 * 1024 * 1024;
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker namespace {
30*da0073e9SAndroid Build Coastguard Worker using stream_set = ska::flat_hash_set<xpu::XPUStream>;
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker struct Block;
33*da0073e9SAndroid Build Coastguard Worker typedef bool (*Comparison)(const Block*, const Block*);
34*da0073e9SAndroid Build Coastguard Worker bool BlockComparatorSize(const Block* a, const Block* b);
35*da0073e9SAndroid Build Coastguard Worker 
36*da0073e9SAndroid Build Coastguard Worker struct BlockPool {
BlockPoolc10::xpu::XPUCachingAllocator::__anon8131195a0111::BlockPool37*da0073e9SAndroid Build Coastguard Worker   BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
38*da0073e9SAndroid Build Coastguard Worker   std::set<Block*, Comparison> blocks;
39*da0073e9SAndroid Build Coastguard Worker   const bool is_small;
40*da0073e9SAndroid Build Coastguard Worker };
41*da0073e9SAndroid Build Coastguard Worker 
42*da0073e9SAndroid Build Coastguard Worker struct Block {
43*da0073e9SAndroid Build Coastguard Worker   DeviceIndex device;
44*da0073e9SAndroid Build Coastguard Worker   sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
45*da0073e9SAndroid Build Coastguard Worker   stream_set stream_uses; // streams on which the block was used
46*da0073e9SAndroid Build Coastguard Worker   size_t size; // block size in bytes
47*da0073e9SAndroid Build Coastguard Worker   size_t requested_size; // memory originally requested
48*da0073e9SAndroid Build Coastguard Worker   BlockPool* pool{nullptr}; // owning memory pool
49*da0073e9SAndroid Build Coastguard Worker   void* ptr{nullptr}; // memory address
50*da0073e9SAndroid Build Coastguard Worker   bool allocated{false}; // in-use flag
51*da0073e9SAndroid Build Coastguard Worker   Block* prev{nullptr}; // prev block if split from a larger allocation
52*da0073e9SAndroid Build Coastguard Worker   Block* next{nullptr}; // next block if split from a larger allocation
53*da0073e9SAndroid Build Coastguard Worker   int event_count{0}; // number of outstanding XPU events
54*da0073e9SAndroid Build Coastguard Worker 
Blockc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block55*da0073e9SAndroid Build Coastguard Worker   Block(
56*da0073e9SAndroid Build Coastguard Worker       DeviceIndex device,
57*da0073e9SAndroid Build Coastguard Worker       sycl::queue* queue,
58*da0073e9SAndroid Build Coastguard Worker       size_t size,
59*da0073e9SAndroid Build Coastguard Worker       BlockPool* pool,
60*da0073e9SAndroid Build Coastguard Worker       void* ptr)
61*da0073e9SAndroid Build Coastguard Worker       : device(device),
62*da0073e9SAndroid Build Coastguard Worker         queue(queue),
63*da0073e9SAndroid Build Coastguard Worker         stream_uses(),
64*da0073e9SAndroid Build Coastguard Worker         size(size),
65*da0073e9SAndroid Build Coastguard Worker         requested_size(0),
66*da0073e9SAndroid Build Coastguard Worker         pool(pool),
67*da0073e9SAndroid Build Coastguard Worker         ptr(ptr) {}
68*da0073e9SAndroid Build Coastguard Worker 
69*da0073e9SAndroid Build Coastguard Worker   // constructor for search key
Blockc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block70*da0073e9SAndroid Build Coastguard Worker   Block(DeviceIndex device, sycl::queue* queue, size_t size)
71*da0073e9SAndroid Build Coastguard Worker       : device(device),
72*da0073e9SAndroid Build Coastguard Worker         queue(queue),
73*da0073e9SAndroid Build Coastguard Worker         stream_uses(),
74*da0073e9SAndroid Build Coastguard Worker         size(size),
75*da0073e9SAndroid Build Coastguard Worker         requested_size(0) {}
76*da0073e9SAndroid Build Coastguard Worker 
is_splitc10::xpu::XPUCachingAllocator::__anon8131195a0111::Block77*da0073e9SAndroid Build Coastguard Worker   bool is_split() const {
78*da0073e9SAndroid Build Coastguard Worker     return (prev != nullptr) || (next != nullptr);
79*da0073e9SAndroid Build Coastguard Worker   }
80*da0073e9SAndroid Build Coastguard Worker };
81*da0073e9SAndroid Build Coastguard Worker 
BlockComparatorSize(const Block * a,const Block * b)82*da0073e9SAndroid Build Coastguard Worker bool BlockComparatorSize(const Block* a, const Block* b) {
83*da0073e9SAndroid Build Coastguard Worker   if (a->queue != b->queue) {
84*da0073e9SAndroid Build Coastguard Worker     return reinterpret_cast<uintptr_t>(a->queue) <
85*da0073e9SAndroid Build Coastguard Worker         reinterpret_cast<uintptr_t>(b->queue);
86*da0073e9SAndroid Build Coastguard Worker   }
87*da0073e9SAndroid Build Coastguard Worker   if (a->size != b->size) {
88*da0073e9SAndroid Build Coastguard Worker     return a->size < b->size;
89*da0073e9SAndroid Build Coastguard Worker   }
90*da0073e9SAndroid Build Coastguard Worker   return reinterpret_cast<uintptr_t>(a->ptr) <
91*da0073e9SAndroid Build Coastguard Worker       reinterpret_cast<uintptr_t>(b->ptr);
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker 
94*da0073e9SAndroid Build Coastguard Worker struct AllocParams {
AllocParamsc10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams95*da0073e9SAndroid Build Coastguard Worker   AllocParams(
96*da0073e9SAndroid Build Coastguard Worker       DeviceIndex device,
97*da0073e9SAndroid Build Coastguard Worker       size_t size,
98*da0073e9SAndroid Build Coastguard Worker       sycl::queue* queue,
99*da0073e9SAndroid Build Coastguard Worker       BlockPool* pool,
100*da0073e9SAndroid Build Coastguard Worker       size_t alloc_size)
101*da0073e9SAndroid Build Coastguard Worker       : search_key(device, queue, size),
102*da0073e9SAndroid Build Coastguard Worker         pool(pool),
103*da0073e9SAndroid Build Coastguard Worker         alloc_size(alloc_size),
104*da0073e9SAndroid Build Coastguard Worker         block(nullptr) {}
105*da0073e9SAndroid Build Coastguard Worker 
devicec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams106*da0073e9SAndroid Build Coastguard Worker   DeviceIndex device() const {
107*da0073e9SAndroid Build Coastguard Worker     return search_key.device;
108*da0073e9SAndroid Build Coastguard Worker   }
109*da0073e9SAndroid Build Coastguard Worker 
queuec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams110*da0073e9SAndroid Build Coastguard Worker   sycl::queue* queue() const {
111*da0073e9SAndroid Build Coastguard Worker     return search_key.queue;
112*da0073e9SAndroid Build Coastguard Worker   }
113*da0073e9SAndroid Build Coastguard Worker 
sizec10::xpu::XPUCachingAllocator::__anon8131195a0111::AllocParams114*da0073e9SAndroid Build Coastguard Worker   size_t size() const {
115*da0073e9SAndroid Build Coastguard Worker     return search_key.size;
116*da0073e9SAndroid Build Coastguard Worker   }
117*da0073e9SAndroid Build Coastguard Worker 
118*da0073e9SAndroid Build Coastguard Worker   Block search_key;
119*da0073e9SAndroid Build Coastguard Worker   BlockPool* pool;
120*da0073e9SAndroid Build Coastguard Worker   size_t alloc_size;
121*da0073e9SAndroid Build Coastguard Worker   Block* block;
122*da0073e9SAndroid Build Coastguard Worker   StatTypes stat_types = {};
123*da0073e9SAndroid Build Coastguard Worker };
124*da0073e9SAndroid Build Coastguard Worker 
125*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
126*da0073e9SAndroid Build Coastguard Worker 
127*da0073e9SAndroid Build Coastguard Worker class DeviceCachingAllocator {
128*da0073e9SAndroid Build Coastguard Worker  private:
129*da0073e9SAndroid Build Coastguard Worker   mutable std::recursive_mutex mutex;
130*da0073e9SAndroid Build Coastguard Worker   DeviceStats stats;
131*da0073e9SAndroid Build Coastguard Worker   BlockPool large_blocks; // unallocated cached blocks larger than 1 MB
132*da0073e9SAndroid Build Coastguard Worker   BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller
133*da0073e9SAndroid Build Coastguard Worker   ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream
134*da0073e9SAndroid Build Coastguard Worker   ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>>
135*da0073e9SAndroid Build Coastguard Worker       xpu_events;
136*da0073e9SAndroid Build Coastguard Worker   DeviceIndex device_index;
137*da0073e9SAndroid Build Coastguard Worker 
try_merge_blocks(Block * dst,Block * src,BlockPool & pool)138*da0073e9SAndroid Build Coastguard Worker   size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
139*da0073e9SAndroid Build Coastguard Worker     if (!src || src->allocated || src->event_count > 0 ||
140*da0073e9SAndroid Build Coastguard Worker         !src->stream_uses.empty()) {
141*da0073e9SAndroid Build Coastguard Worker       return 0;
142*da0073e9SAndroid Build Coastguard Worker     }
143*da0073e9SAndroid Build Coastguard Worker 
144*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(dst->is_split() && src->is_split());
145*da0073e9SAndroid Build Coastguard Worker     if (dst->prev == src) { // [src dst]
146*da0073e9SAndroid Build Coastguard Worker       dst->ptr = src->ptr;
147*da0073e9SAndroid Build Coastguard Worker       dst->prev = src->prev;
148*da0073e9SAndroid Build Coastguard Worker       if (dst->prev) {
149*da0073e9SAndroid Build Coastguard Worker         dst->prev->next = dst;
150*da0073e9SAndroid Build Coastguard Worker       }
151*da0073e9SAndroid Build Coastguard Worker     } else { // [dst src]
152*da0073e9SAndroid Build Coastguard Worker       dst->next = src->next;
153*da0073e9SAndroid Build Coastguard Worker       if (dst->next) {
154*da0073e9SAndroid Build Coastguard Worker         dst->next->prev = dst;
155*da0073e9SAndroid Build Coastguard Worker       }
156*da0073e9SAndroid Build Coastguard Worker     }
157*da0073e9SAndroid Build Coastguard Worker     const size_t subsumed_size = src->size;
158*da0073e9SAndroid Build Coastguard Worker     dst->size += subsumed_size;
159*da0073e9SAndroid Build Coastguard Worker     auto erased = pool.blocks.erase(src);
160*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
161*da0073e9SAndroid Build Coastguard Worker     delete src;
162*da0073e9SAndroid Build Coastguard Worker 
163*da0073e9SAndroid Build Coastguard Worker     return subsumed_size;
164*da0073e9SAndroid Build Coastguard Worker   }
165*da0073e9SAndroid Build Coastguard Worker 
free_block(Block * block)166*da0073e9SAndroid Build Coastguard Worker   void free_block(Block* block) {
167*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
168*da0073e9SAndroid Build Coastguard Worker         !block->allocated && block->event_count == 0 &&
169*da0073e9SAndroid Build Coastguard Worker         block->stream_uses.empty());
170*da0073e9SAndroid Build Coastguard Worker 
171*da0073e9SAndroid Build Coastguard Worker     size_t original_block_size = block->size;
172*da0073e9SAndroid Build Coastguard Worker     size_t requested_size = block->requested_size;
173*da0073e9SAndroid Build Coastguard Worker     auto& pool = *block->pool;
174*da0073e9SAndroid Build Coastguard Worker     const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
175*da0073e9SAndroid Build Coastguard Worker     for (Block* merge_candidate : merge_candidates) {
176*da0073e9SAndroid Build Coastguard Worker       try_merge_blocks(block, merge_candidate, pool);
177*da0073e9SAndroid Build Coastguard Worker     }
178*da0073e9SAndroid Build Coastguard Worker 
179*da0073e9SAndroid Build Coastguard Worker     active_blocks.erase(block);
180*da0073e9SAndroid Build Coastguard Worker     bool inserted = pool.blocks.insert(block).second;
181*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
182*da0073e9SAndroid Build Coastguard Worker 
183*da0073e9SAndroid Build Coastguard Worker     StatTypes stat_types = get_stat_types_for_pool(pool);
184*da0073e9SAndroid Build Coastguard Worker     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
185*da0073e9SAndroid Build Coastguard Worker       stats.active_bytes[stat_type].decrease(original_block_size);
186*da0073e9SAndroid Build Coastguard Worker       stats.requested_bytes[stat_type].decrease(requested_size);
187*da0073e9SAndroid Build Coastguard Worker     });
188*da0073e9SAndroid Build Coastguard Worker   }
189*da0073e9SAndroid Build Coastguard Worker 
process_events()190*da0073e9SAndroid Build Coastguard Worker   void process_events() {
191*da0073e9SAndroid Build Coastguard Worker     using namespace sycl::info;
192*da0073e9SAndroid Build Coastguard Worker     for (auto it = xpu_events.begin(); it != xpu_events.end();) {
193*da0073e9SAndroid Build Coastguard Worker       while (!it->second.empty()) {
194*da0073e9SAndroid Build Coastguard Worker         auto& e = it->second.front();
195*da0073e9SAndroid Build Coastguard Worker         auto event = e.first;
196*da0073e9SAndroid Build Coastguard Worker         auto* block = e.second;
197*da0073e9SAndroid Build Coastguard Worker         if (event.get_info<event::command_execution_status>() !=
198*da0073e9SAndroid Build Coastguard Worker             event_command_status::complete) {
199*da0073e9SAndroid Build Coastguard Worker           break;
200*da0073e9SAndroid Build Coastguard Worker         }
201*da0073e9SAndroid Build Coastguard Worker         block->event_count--;
202*da0073e9SAndroid Build Coastguard Worker         if (block->event_count == 0) {
203*da0073e9SAndroid Build Coastguard Worker           free_block(block);
204*da0073e9SAndroid Build Coastguard Worker         }
205*da0073e9SAndroid Build Coastguard Worker         it->second.pop_front();
206*da0073e9SAndroid Build Coastguard Worker       }
207*da0073e9SAndroid Build Coastguard Worker 
208*da0073e9SAndroid Build Coastguard Worker       if (it->second.empty()) {
209*da0073e9SAndroid Build Coastguard Worker         it = xpu_events.erase(it);
210*da0073e9SAndroid Build Coastguard Worker       } else {
211*da0073e9SAndroid Build Coastguard Worker         it++;
212*da0073e9SAndroid Build Coastguard Worker       }
213*da0073e9SAndroid Build Coastguard Worker     }
214*da0073e9SAndroid Build Coastguard Worker   }
215*da0073e9SAndroid Build Coastguard Worker 
round_size(size_t size)216*da0073e9SAndroid Build Coastguard Worker   static size_t round_size(size_t size) {
217*da0073e9SAndroid Build Coastguard Worker     if (size < kMinBlockSize) {
218*da0073e9SAndroid Build Coastguard Worker       return kMinBlockSize;
219*da0073e9SAndroid Build Coastguard Worker     } else {
220*da0073e9SAndroid Build Coastguard Worker       return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
221*da0073e9SAndroid Build Coastguard Worker     }
222*da0073e9SAndroid Build Coastguard Worker   }
223*da0073e9SAndroid Build Coastguard Worker 
get_allocation_size(size_t size)224*da0073e9SAndroid Build Coastguard Worker   static size_t get_allocation_size(size_t size) {
225*da0073e9SAndroid Build Coastguard Worker     if (size <= kSmallSize) {
226*da0073e9SAndroid Build Coastguard Worker       return kSmallBuffer;
227*da0073e9SAndroid Build Coastguard Worker     } else if (size < kMinLargeAlloc) {
228*da0073e9SAndroid Build Coastguard Worker       return kLargeBuffer;
229*da0073e9SAndroid Build Coastguard Worker     } else {
230*da0073e9SAndroid Build Coastguard Worker       return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
231*da0073e9SAndroid Build Coastguard Worker     }
232*da0073e9SAndroid Build Coastguard Worker   }
233*da0073e9SAndroid Build Coastguard Worker 
get_pool(size_t size)234*da0073e9SAndroid Build Coastguard Worker   BlockPool& get_pool(size_t size) {
235*da0073e9SAndroid Build Coastguard Worker     if (size < kSmallSize) {
236*da0073e9SAndroid Build Coastguard Worker       return small_blocks;
237*da0073e9SAndroid Build Coastguard Worker     } else {
238*da0073e9SAndroid Build Coastguard Worker       return large_blocks;
239*da0073e9SAndroid Build Coastguard Worker     }
240*da0073e9SAndroid Build Coastguard Worker   }
241*da0073e9SAndroid Build Coastguard Worker 
get_free_block(AllocParams & p)242*da0073e9SAndroid Build Coastguard Worker   bool get_free_block(AllocParams& p) {
243*da0073e9SAndroid Build Coastguard Worker     BlockPool& pool = *p.pool;
244*da0073e9SAndroid Build Coastguard Worker     auto it = pool.blocks.lower_bound(&p.search_key);
245*da0073e9SAndroid Build Coastguard Worker     if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
246*da0073e9SAndroid Build Coastguard Worker       return false;
247*da0073e9SAndroid Build Coastguard Worker     }
248*da0073e9SAndroid Build Coastguard Worker     p.block = *it;
249*da0073e9SAndroid Build Coastguard Worker     pool.blocks.erase(it);
250*da0073e9SAndroid Build Coastguard Worker     return true;
251*da0073e9SAndroid Build Coastguard Worker   }
252*da0073e9SAndroid Build Coastguard Worker 
alloc_block(AllocParams & p)253*da0073e9SAndroid Build Coastguard Worker   bool alloc_block(AllocParams& p) {
254*da0073e9SAndroid Build Coastguard Worker     auto size = p.alloc_size;
255*da0073e9SAndroid Build Coastguard Worker     auto device = p.device();
256*da0073e9SAndroid Build Coastguard Worker     void* ptr = sycl::aligned_alloc_device(
257*da0073e9SAndroid Build Coastguard Worker         kDeviceAlignment,
258*da0073e9SAndroid Build Coastguard Worker         size,
259*da0073e9SAndroid Build Coastguard Worker         xpu::get_raw_device(device),
260*da0073e9SAndroid Build Coastguard Worker         xpu::get_device_context());
261*da0073e9SAndroid Build Coastguard Worker     if (!ptr) {
262*da0073e9SAndroid Build Coastguard Worker       return false;
263*da0073e9SAndroid Build Coastguard Worker     }
264*da0073e9SAndroid Build Coastguard Worker     p.block = new Block(device, p.queue(), size, p.pool, ptr);
265*da0073e9SAndroid Build Coastguard Worker     for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
266*da0073e9SAndroid Build Coastguard Worker       stats.reserved_bytes[stat_type].increase(size);
267*da0073e9SAndroid Build Coastguard Worker     });
268*da0073e9SAndroid Build Coastguard Worker     return true;
269*da0073e9SAndroid Build Coastguard Worker   }
270*da0073e9SAndroid Build Coastguard Worker 
synchronize_and_free_events()271*da0073e9SAndroid Build Coastguard Worker   void synchronize_and_free_events() {
272*da0073e9SAndroid Build Coastguard Worker     for (auto& xe : xpu_events) {
273*da0073e9SAndroid Build Coastguard Worker       for (auto& e : xe.second) {
274*da0073e9SAndroid Build Coastguard Worker         auto event = e.first;
275*da0073e9SAndroid Build Coastguard Worker         auto* block = e.second;
276*da0073e9SAndroid Build Coastguard Worker         event.wait();
277*da0073e9SAndroid Build Coastguard Worker         block->event_count--;
278*da0073e9SAndroid Build Coastguard Worker         if (block->event_count == 0) {
279*da0073e9SAndroid Build Coastguard Worker           free_block(block);
280*da0073e9SAndroid Build Coastguard Worker         }
281*da0073e9SAndroid Build Coastguard Worker       }
282*da0073e9SAndroid Build Coastguard Worker     }
283*da0073e9SAndroid Build Coastguard Worker     xpu_events.clear();
284*da0073e9SAndroid Build Coastguard Worker   }
285*da0073e9SAndroid Build Coastguard Worker 
release_block(Block * block)286*da0073e9SAndroid Build Coastguard Worker   void release_block(Block* block) {
287*da0073e9SAndroid Build Coastguard Worker     /*
288*da0073e9SAndroid Build Coastguard Worker      * Note [Safe to Free Blocks on BlockPool]
289*da0073e9SAndroid Build Coastguard Worker      *
290*da0073e9SAndroid Build Coastguard Worker      * Callers must ensure that all accesses to the block, whose raw pointer is
291*da0073e9SAndroid Build Coastguard Worker      * allocated by SYCL APIs, have been completed before invoking sycl::free.
292*da0073e9SAndroid Build Coastguard Worker      *
293*da0073e9SAndroid Build Coastguard Worker      * We have to do a device-level synchronization before free these blocks to
294*da0073e9SAndroid Build Coastguard Worker      * guarantee that all kernels can access to the blocks have finished.
295*da0073e9SAndroid Build Coastguard Worker      */
296*da0073e9SAndroid Build Coastguard Worker     sycl::free(block->ptr, xpu::get_device_context());
297*da0073e9SAndroid Build Coastguard Worker     auto* pool = block->pool;
298*da0073e9SAndroid Build Coastguard Worker     pool->blocks.erase(block);
299*da0073e9SAndroid Build Coastguard Worker 
300*da0073e9SAndroid Build Coastguard Worker     StatTypes stat_types = get_stat_types_for_pool(*pool);
301*da0073e9SAndroid Build Coastguard Worker     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
302*da0073e9SAndroid Build Coastguard Worker       stats.reserved_bytes[stat_type].decrease(block->size);
303*da0073e9SAndroid Build Coastguard Worker     });
304*da0073e9SAndroid Build Coastguard Worker 
305*da0073e9SAndroid Build Coastguard Worker     delete block;
306*da0073e9SAndroid Build Coastguard Worker   }
307*da0073e9SAndroid Build Coastguard Worker 
release_blocks(BlockPool & pool)308*da0073e9SAndroid Build Coastguard Worker   void release_blocks(BlockPool& pool) {
309*da0073e9SAndroid Build Coastguard Worker     auto it = pool.blocks.begin();
310*da0073e9SAndroid Build Coastguard Worker     while (it != pool.blocks.end()) {
311*da0073e9SAndroid Build Coastguard Worker       Block* block = *it;
312*da0073e9SAndroid Build Coastguard Worker       ++it;
313*da0073e9SAndroid Build Coastguard Worker       if (!block->prev && !block->next) {
314*da0073e9SAndroid Build Coastguard Worker         release_block(block);
315*da0073e9SAndroid Build Coastguard Worker       }
316*da0073e9SAndroid Build Coastguard Worker     }
317*da0073e9SAndroid Build Coastguard Worker   }
318*da0073e9SAndroid Build Coastguard Worker 
release_cached_blocks()319*da0073e9SAndroid Build Coastguard Worker   bool release_cached_blocks() {
320*da0073e9SAndroid Build Coastguard Worker     synchronize_and_free_events();
321*da0073e9SAndroid Build Coastguard Worker     // See Note [Safe to Free Blocks on BlockPool]
322*da0073e9SAndroid Build Coastguard Worker     c10::xpu::syncStreamsOnDevice(device_index);
323*da0073e9SAndroid Build Coastguard Worker 
324*da0073e9SAndroid Build Coastguard Worker     release_blocks(large_blocks);
325*da0073e9SAndroid Build Coastguard Worker     release_blocks(small_blocks);
326*da0073e9SAndroid Build Coastguard Worker     return true;
327*da0073e9SAndroid Build Coastguard Worker   }
328*da0073e9SAndroid Build Coastguard Worker 
should_split(const Block * block,size_t size)329*da0073e9SAndroid Build Coastguard Worker   bool should_split(const Block* block, size_t size) {
330*da0073e9SAndroid Build Coastguard Worker     size_t remaining = block->size - size;
331*da0073e9SAndroid Build Coastguard Worker     if (block->pool->is_small) {
332*da0073e9SAndroid Build Coastguard Worker       return remaining >= kMinBlockSize;
333*da0073e9SAndroid Build Coastguard Worker     } else {
334*da0073e9SAndroid Build Coastguard Worker       return remaining > kSmallSize;
335*da0073e9SAndroid Build Coastguard Worker     }
336*da0073e9SAndroid Build Coastguard Worker   }
337*da0073e9SAndroid Build Coastguard Worker 
get_stat_types_for_pool(const BlockPool & pool)338*da0073e9SAndroid Build Coastguard Worker   StatTypes get_stat_types_for_pool(const BlockPool& pool) {
339*da0073e9SAndroid Build Coastguard Worker     StatTypes stat_types = {};
340*da0073e9SAndroid Build Coastguard Worker     stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
341*da0073e9SAndroid Build Coastguard Worker     stat_types[static_cast<size_t>(
342*da0073e9SAndroid Build Coastguard Worker         pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
343*da0073e9SAndroid Build Coastguard Worker     return stat_types;
344*da0073e9SAndroid Build Coastguard Worker   }
345*da0073e9SAndroid Build Coastguard Worker 
alloc_found_block(AllocParams params,size_t orig_size,bool split_remainder)346*da0073e9SAndroid Build Coastguard Worker   Block* alloc_found_block(
347*da0073e9SAndroid Build Coastguard Worker       AllocParams params,
348*da0073e9SAndroid Build Coastguard Worker       size_t orig_size,
349*da0073e9SAndroid Build Coastguard Worker       bool split_remainder) {
350*da0073e9SAndroid Build Coastguard Worker     auto size = params.size();
351*da0073e9SAndroid Build Coastguard Worker     auto device = params.device();
352*da0073e9SAndroid Build Coastguard Worker     BlockPool* pool = params.pool;
353*da0073e9SAndroid Build Coastguard Worker     sycl::queue* queue = params.queue();
354*da0073e9SAndroid Build Coastguard Worker 
355*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
356*da0073e9SAndroid Build Coastguard Worker         params.block != nullptr && params.block->ptr != nullptr);
357*da0073e9SAndroid Build Coastguard Worker     Block* block = params.block;
358*da0073e9SAndroid Build Coastguard Worker     Block* remaining = nullptr;
359*da0073e9SAndroid Build Coastguard Worker 
360*da0073e9SAndroid Build Coastguard Worker     if (split_remainder) {
361*da0073e9SAndroid Build Coastguard Worker       remaining = block;
362*da0073e9SAndroid Build Coastguard Worker 
363*da0073e9SAndroid Build Coastguard Worker       block = new Block(device, queue, size, pool, block->ptr);
364*da0073e9SAndroid Build Coastguard Worker       block->prev = remaining->prev;
365*da0073e9SAndroid Build Coastguard Worker       if (block->prev) {
366*da0073e9SAndroid Build Coastguard Worker         block->prev->next = block;
367*da0073e9SAndroid Build Coastguard Worker       }
368*da0073e9SAndroid Build Coastguard Worker       block->next = remaining;
369*da0073e9SAndroid Build Coastguard Worker 
370*da0073e9SAndroid Build Coastguard Worker       remaining->prev = block;
371*da0073e9SAndroid Build Coastguard Worker       remaining->ptr = static_cast<char*>(remaining->ptr) + size;
372*da0073e9SAndroid Build Coastguard Worker       remaining->size -= size;
373*da0073e9SAndroid Build Coastguard Worker       bool inserted = pool->blocks.insert(remaining).second;
374*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
375*da0073e9SAndroid Build Coastguard Worker     }
376*da0073e9SAndroid Build Coastguard Worker 
377*da0073e9SAndroid Build Coastguard Worker     block->allocated = true;
378*da0073e9SAndroid Build Coastguard Worker     block->requested_size = orig_size;
379*da0073e9SAndroid Build Coastguard Worker     bool inserted = active_blocks.insert(block).second;
380*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted)
381*da0073e9SAndroid Build Coastguard Worker 
382*da0073e9SAndroid Build Coastguard Worker     for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
383*da0073e9SAndroid Build Coastguard Worker       stats.allocated_bytes[stat_type].increase(block->size);
384*da0073e9SAndroid Build Coastguard Worker       stats.active_bytes[stat_type].increase(block->size);
385*da0073e9SAndroid Build Coastguard Worker       stats.requested_bytes[stat_type].increase(block->requested_size);
386*da0073e9SAndroid Build Coastguard Worker     });
387*da0073e9SAndroid Build Coastguard Worker 
388*da0073e9SAndroid Build Coastguard Worker     return block;
389*da0073e9SAndroid Build Coastguard Worker   }
390*da0073e9SAndroid Build Coastguard Worker 
insert_events(Block * block)391*da0073e9SAndroid Build Coastguard Worker   void insert_events(Block* block) {
392*da0073e9SAndroid Build Coastguard Worker     stream_set streams(std::move(block->stream_uses));
393*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(block->stream_uses.empty());
394*da0073e9SAndroid Build Coastguard Worker     for (auto& stream : streams) {
395*da0073e9SAndroid Build Coastguard Worker       block->event_count++;
396*da0073e9SAndroid Build Coastguard Worker       xpu_events[stream].emplace_back(
397*da0073e9SAndroid Build Coastguard Worker           stream.queue().ext_oneapi_submit_barrier(), block);
398*da0073e9SAndroid Build Coastguard Worker     }
399*da0073e9SAndroid Build Coastguard Worker   }
400*da0073e9SAndroid Build Coastguard Worker 
401*da0073e9SAndroid Build Coastguard Worker  public:
DeviceCachingAllocator(DeviceIndex device_index)402*da0073e9SAndroid Build Coastguard Worker   DeviceCachingAllocator(DeviceIndex device_index)
403*da0073e9SAndroid Build Coastguard Worker       : large_blocks(/* small */ false),
404*da0073e9SAndroid Build Coastguard Worker         small_blocks(/* small */ true),
405*da0073e9SAndroid Build Coastguard Worker         device_index(device_index) {}
406*da0073e9SAndroid Build Coastguard Worker 
malloc(DeviceIndex device,size_t orig_size,sycl::queue & queue)407*da0073e9SAndroid Build Coastguard Worker   Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
408*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
409*da0073e9SAndroid Build Coastguard Worker     process_events();
410*da0073e9SAndroid Build Coastguard Worker     size_t size = round_size(orig_size);
411*da0073e9SAndroid Build Coastguard Worker     auto& pool = get_pool(size);
412*da0073e9SAndroid Build Coastguard Worker     const size_t alloc_size = get_allocation_size(size);
413*da0073e9SAndroid Build Coastguard Worker     AllocParams params(device, size, &queue, &pool, alloc_size);
414*da0073e9SAndroid Build Coastguard Worker     params.stat_types = get_stat_types_for_pool(pool);
415*da0073e9SAndroid Build Coastguard Worker 
416*da0073e9SAndroid Build Coastguard Worker     // First, try to get a block from the existing pool.
417*da0073e9SAndroid Build Coastguard Worker     bool block_found = get_free_block(params);
418*da0073e9SAndroid Build Coastguard Worker     // Can't reuse an existing block, try to get a new one.
419*da0073e9SAndroid Build Coastguard Worker     if (!block_found) {
420*da0073e9SAndroid Build Coastguard Worker       block_found = alloc_block(params) ||
421*da0073e9SAndroid Build Coastguard Worker           (release_cached_blocks() && alloc_block(params));
422*da0073e9SAndroid Build Coastguard Worker     }
423*da0073e9SAndroid Build Coastguard Worker     if (!block_found) {
424*da0073e9SAndroid Build Coastguard Worker       c10::xpu::DeviceProp device_prop;
425*da0073e9SAndroid Build Coastguard Worker       c10::xpu::get_device_properties(&device_prop, device);
426*da0073e9SAndroid Build Coastguard Worker       auto device_total = device_prop.global_mem_size;
427*da0073e9SAndroid Build Coastguard Worker       auto allocated_bytes =
428*da0073e9SAndroid Build Coastguard Worker           stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
429*da0073e9SAndroid Build Coastguard Worker               .current;
430*da0073e9SAndroid Build Coastguard Worker       auto reserved_bytes =
431*da0073e9SAndroid Build Coastguard Worker           stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
432*da0073e9SAndroid Build Coastguard Worker               .current;
433*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK_WITH(
434*da0073e9SAndroid Build Coastguard Worker           OutOfMemoryError,
435*da0073e9SAndroid Build Coastguard Worker           false,
436*da0073e9SAndroid Build Coastguard Worker           "XPU out of memory. Tried to allocate ",
437*da0073e9SAndroid Build Coastguard Worker           format_size(alloc_size),
438*da0073e9SAndroid Build Coastguard Worker           ". GPU ",
439*da0073e9SAndroid Build Coastguard Worker           static_cast<int>(device),
440*da0073e9SAndroid Build Coastguard Worker           " has a total capacity of ",
441*da0073e9SAndroid Build Coastguard Worker           format_size(device_total),
442*da0073e9SAndroid Build Coastguard Worker           ". Of the allocated memory ",
443*da0073e9SAndroid Build Coastguard Worker           format_size(allocated_bytes),
444*da0073e9SAndroid Build Coastguard Worker           " is allocated by PyTorch, and ",
445*da0073e9SAndroid Build Coastguard Worker           format_size(reserved_bytes - allocated_bytes),
446*da0073e9SAndroid Build Coastguard Worker           " is reserved by PyTorch but unallocated.",
447*da0073e9SAndroid Build Coastguard Worker           " Please use `empty_cache` to release all unoccupied cached memory.");
448*da0073e9SAndroid Build Coastguard Worker     }
449*da0073e9SAndroid Build Coastguard Worker     bool split_remainder = should_split(params.block, params.size());
450*da0073e9SAndroid Build Coastguard Worker     return alloc_found_block(std::move(params), orig_size, split_remainder);
451*da0073e9SAndroid Build Coastguard Worker   }
452*da0073e9SAndroid Build Coastguard Worker 
free(Block * block)453*da0073e9SAndroid Build Coastguard Worker   void free(Block* block) {
454*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
455*da0073e9SAndroid Build Coastguard Worker     block->allocated = false;
456*da0073e9SAndroid Build Coastguard Worker 
457*da0073e9SAndroid Build Coastguard Worker     StatTypes stat_types = get_stat_types_for_pool(*block->pool);
458*da0073e9SAndroid Build Coastguard Worker     for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
459*da0073e9SAndroid Build Coastguard Worker       stats.allocated_bytes[stat_type].decrease(block->size);
460*da0073e9SAndroid Build Coastguard Worker     });
461*da0073e9SAndroid Build Coastguard Worker 
462*da0073e9SAndroid Build Coastguard Worker     if (!block->stream_uses.empty()) {
463*da0073e9SAndroid Build Coastguard Worker       insert_events(block);
464*da0073e9SAndroid Build Coastguard Worker     } else {
465*da0073e9SAndroid Build Coastguard Worker       free_block(block);
466*da0073e9SAndroid Build Coastguard Worker     }
467*da0073e9SAndroid Build Coastguard Worker   }
468*da0073e9SAndroid Build Coastguard Worker 
recordStream(Block * block,xpu::XPUStream stream)469*da0073e9SAndroid Build Coastguard Worker   void recordStream(Block* block, xpu::XPUStream stream) {
470*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
471*da0073e9SAndroid Build Coastguard Worker     if (stream.queue() == *block->queue) {
472*da0073e9SAndroid Build Coastguard Worker       return;
473*da0073e9SAndroid Build Coastguard Worker     }
474*da0073e9SAndroid Build Coastguard Worker     block->stream_uses.insert(stream);
475*da0073e9SAndroid Build Coastguard Worker   }
476*da0073e9SAndroid Build Coastguard Worker 
emptyCache()477*da0073e9SAndroid Build Coastguard Worker   void emptyCache() {
478*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
479*da0073e9SAndroid Build Coastguard Worker     release_cached_blocks();
480*da0073e9SAndroid Build Coastguard Worker   }
481*da0073e9SAndroid Build Coastguard Worker 
getStats()482*da0073e9SAndroid Build Coastguard Worker   DeviceStats getStats() {
483*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
484*da0073e9SAndroid Build Coastguard Worker     return stats;
485*da0073e9SAndroid Build Coastguard Worker   }
486*da0073e9SAndroid Build Coastguard Worker 
resetAccumulatedStats()487*da0073e9SAndroid Build Coastguard Worker   void resetAccumulatedStats() {
488*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
489*da0073e9SAndroid Build Coastguard Worker 
490*da0073e9SAndroid Build Coastguard Worker     for (const auto statType :
491*da0073e9SAndroid Build Coastguard Worker          c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
492*da0073e9SAndroid Build Coastguard Worker       stats.allocated_bytes[statType].reset_accumulated();
493*da0073e9SAndroid Build Coastguard Worker       stats.reserved_bytes[statType].reset_accumulated();
494*da0073e9SAndroid Build Coastguard Worker       stats.active_bytes[statType].reset_accumulated();
495*da0073e9SAndroid Build Coastguard Worker       stats.requested_bytes[statType].reset_accumulated();
496*da0073e9SAndroid Build Coastguard Worker     }
497*da0073e9SAndroid Build Coastguard Worker   }
498*da0073e9SAndroid Build Coastguard Worker 
resetPeakStats()499*da0073e9SAndroid Build Coastguard Worker   void resetPeakStats() {
500*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::recursive_mutex> lock(mutex);
501*da0073e9SAndroid Build Coastguard Worker 
502*da0073e9SAndroid Build Coastguard Worker     for (const auto statType :
503*da0073e9SAndroid Build Coastguard Worker          c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
504*da0073e9SAndroid Build Coastguard Worker       stats.allocated_bytes[statType].reset_peak();
505*da0073e9SAndroid Build Coastguard Worker       stats.reserved_bytes[statType].reset_peak();
506*da0073e9SAndroid Build Coastguard Worker       stats.active_bytes[statType].reset_peak();
507*da0073e9SAndroid Build Coastguard Worker       stats.requested_bytes[statType].reset_peak();
508*da0073e9SAndroid Build Coastguard Worker     }
509*da0073e9SAndroid Build Coastguard Worker   }
510*da0073e9SAndroid Build Coastguard Worker };
511*da0073e9SAndroid Build Coastguard Worker 
512*da0073e9SAndroid Build Coastguard Worker void local_raw_delete(void* ptr);
513*da0073e9SAndroid Build Coastguard Worker 
514*da0073e9SAndroid Build Coastguard Worker class XPUAllocator : public Allocator {
515*da0073e9SAndroid Build Coastguard Worker  private:
516*da0073e9SAndroid Build Coastguard Worker   std::mutex mutex;
517*da0073e9SAndroid Build Coastguard Worker   ska::flat_hash_map<void*, Block*> allocated_blocks;
518*da0073e9SAndroid Build Coastguard Worker 
add_allocated_block(Block * block)519*da0073e9SAndroid Build Coastguard Worker   void add_allocated_block(Block* block) {
520*da0073e9SAndroid Build Coastguard Worker     std::lock_guard<std::mutex> lock(mutex);
521*da0073e9SAndroid Build Coastguard Worker     allocated_blocks[block->ptr] = block;
522*da0073e9SAndroid Build Coastguard Worker   }
523*da0073e9SAndroid Build Coastguard Worker 
get_allocated_block(void * ptr,bool remove=false)524*da0073e9SAndroid Build Coastguard Worker   Block* get_allocated_block(void* ptr, bool remove = false) {
525*da0073e9SAndroid Build Coastguard Worker     std::scoped_lock<std::mutex> lock(mutex);
526*da0073e9SAndroid Build Coastguard Worker     auto it = allocated_blocks.find(ptr);
527*da0073e9SAndroid Build Coastguard Worker     if (it == allocated_blocks.end()) {
528*da0073e9SAndroid Build Coastguard Worker       return nullptr;
529*da0073e9SAndroid Build Coastguard Worker     }
530*da0073e9SAndroid Build Coastguard Worker     Block* block = it->second;
531*da0073e9SAndroid Build Coastguard Worker     if (remove) {
532*da0073e9SAndroid Build Coastguard Worker       allocated_blocks.erase(it);
533*da0073e9SAndroid Build Coastguard Worker     }
534*da0073e9SAndroid Build Coastguard Worker     return block;
535*da0073e9SAndroid Build Coastguard Worker   }
536*da0073e9SAndroid Build Coastguard Worker 
537*da0073e9SAndroid Build Coastguard Worker  public:
538*da0073e9SAndroid Build Coastguard Worker   std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
539*da0073e9SAndroid Build Coastguard Worker 
init(DeviceIndex device_count)540*da0073e9SAndroid Build Coastguard Worker   void init(DeviceIndex device_count) {
541*da0073e9SAndroid Build Coastguard Worker     const auto size = static_cast<DeviceIndex>(device_allocators.size());
542*da0073e9SAndroid Build Coastguard Worker     if (size < device_count) {
543*da0073e9SAndroid Build Coastguard Worker       device_allocators.resize(device_count);
544*da0073e9SAndroid Build Coastguard Worker       for (const auto i : c10::irange(size, device_count)) {
545*da0073e9SAndroid Build Coastguard Worker         device_allocators[i] = std::make_unique<DeviceCachingAllocator>(i);
546*da0073e9SAndroid Build Coastguard Worker       }
547*da0073e9SAndroid Build Coastguard Worker     }
548*da0073e9SAndroid Build Coastguard Worker   }
549*da0073e9SAndroid Build Coastguard Worker 
malloc(void ** devPtr,DeviceIndex device,size_t size,sycl::queue & queue)550*da0073e9SAndroid Build Coastguard Worker   void malloc(
551*da0073e9SAndroid Build Coastguard Worker       void** devPtr,
552*da0073e9SAndroid Build Coastguard Worker       DeviceIndex device,
553*da0073e9SAndroid Build Coastguard Worker       size_t size,
554*da0073e9SAndroid Build Coastguard Worker       sycl::queue& queue) {
555*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
556*da0073e9SAndroid Build Coastguard Worker         0 <= device && static_cast<size_t>(device) < device_allocators.size(),
557*da0073e9SAndroid Build Coastguard Worker         "Allocator not initialized for device ",
558*da0073e9SAndroid Build Coastguard Worker         static_cast<int16_t>(device),
559*da0073e9SAndroid Build Coastguard Worker         ": did you call init?");
560*da0073e9SAndroid Build Coastguard Worker     Block* block = device_allocators[device]->malloc(device, size, queue);
561*da0073e9SAndroid Build Coastguard Worker     add_allocated_block(block);
562*da0073e9SAndroid Build Coastguard Worker     *devPtr = block->ptr;
563*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
564*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
565*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_memory_allocation(
566*da0073e9SAndroid Build Coastguard Worker           c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr));
567*da0073e9SAndroid Build Coastguard Worker     }
568*da0073e9SAndroid Build Coastguard Worker   }
569*da0073e9SAndroid Build Coastguard Worker 
free(void * ptr)570*da0073e9SAndroid Build Coastguard Worker   void free(void* ptr) {
571*da0073e9SAndroid Build Coastguard Worker     if (!ptr) {
572*da0073e9SAndroid Build Coastguard Worker       return;
573*da0073e9SAndroid Build Coastguard Worker     }
574*da0073e9SAndroid Build Coastguard Worker     Block* block = get_allocated_block(ptr, /* remove */ true);
575*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(block, "invalid device pointer: ", ptr);
576*da0073e9SAndroid Build Coastguard Worker     device_allocators[block->device]->free(block);
577*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
578*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
579*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_memory_deallocation(
580*da0073e9SAndroid Build Coastguard Worker           c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr));
581*da0073e9SAndroid Build Coastguard Worker     }
582*da0073e9SAndroid Build Coastguard Worker   }
583*da0073e9SAndroid Build Coastguard Worker 
emptyCache()584*da0073e9SAndroid Build Coastguard Worker   void emptyCache() {
585*da0073e9SAndroid Build Coastguard Worker     for (auto& da : device_allocators) {
586*da0073e9SAndroid Build Coastguard Worker       da->emptyCache();
587*da0073e9SAndroid Build Coastguard Worker     }
588*da0073e9SAndroid Build Coastguard Worker   }
589*da0073e9SAndroid Build Coastguard Worker 
recordStream(const DataPtr & ptr,XPUStream stream)590*da0073e9SAndroid Build Coastguard Worker   void recordStream(const DataPtr& ptr, XPUStream stream) {
591*da0073e9SAndroid Build Coastguard Worker     if (!ptr.get()) {
592*da0073e9SAndroid Build Coastguard Worker       return;
593*da0073e9SAndroid Build Coastguard Worker     }
594*da0073e9SAndroid Build Coastguard Worker     if (ptr.get_deleter() != &local_raw_delete) {
595*da0073e9SAndroid Build Coastguard Worker       return;
596*da0073e9SAndroid Build Coastguard Worker     }
597*da0073e9SAndroid Build Coastguard Worker 
598*da0073e9SAndroid Build Coastguard Worker     Block* block = get_allocated_block(ptr.get());
599*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(block, "No allocated block can be found.");
600*da0073e9SAndroid Build Coastguard Worker     device_allocators[block->device]->recordStream(block, stream);
601*da0073e9SAndroid Build Coastguard Worker   }
602*da0073e9SAndroid Build Coastguard Worker 
allocate(size_t size)603*da0073e9SAndroid Build Coastguard Worker   DataPtr allocate(size_t size) override {
604*da0073e9SAndroid Build Coastguard Worker     auto device = c10::xpu::current_device();
605*da0073e9SAndroid Build Coastguard Worker     void* r = nullptr;
606*da0073e9SAndroid Build Coastguard Worker     if (size != 0) {
607*da0073e9SAndroid Build Coastguard Worker       this->malloc(&r, device, size, xpu::getCurrentXPUStream(device));
608*da0073e9SAndroid Build Coastguard Worker     }
609*da0073e9SAndroid Build Coastguard Worker     return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)};
610*da0073e9SAndroid Build Coastguard Worker   }
611*da0073e9SAndroid Build Coastguard Worker 
raw_deleter() const612*da0073e9SAndroid Build Coastguard Worker   DeleterFnPtr raw_deleter() const override {
613*da0073e9SAndroid Build Coastguard Worker     return &local_raw_delete;
614*da0073e9SAndroid Build Coastguard Worker   }
615*da0073e9SAndroid Build Coastguard Worker 
raw_alloc(size_t size)616*da0073e9SAndroid Build Coastguard Worker   void* raw_alloc(size_t size) {
617*da0073e9SAndroid Build Coastguard Worker     if (size == 0) {
618*da0073e9SAndroid Build Coastguard Worker       return nullptr;
619*da0073e9SAndroid Build Coastguard Worker     }
620*da0073e9SAndroid Build Coastguard Worker     auto device = c10::xpu::current_device();
621*da0073e9SAndroid Build Coastguard Worker     void* r = nullptr;
622*da0073e9SAndroid Build Coastguard Worker     malloc(&r, device, size, xpu::getCurrentXPUStream(device));
623*da0073e9SAndroid Build Coastguard Worker     return r;
624*da0073e9SAndroid Build Coastguard Worker   }
625*da0073e9SAndroid Build Coastguard Worker 
raw_alloc_with_stream(size_t size,XPUStream stream)626*da0073e9SAndroid Build Coastguard Worker   void* raw_alloc_with_stream(size_t size, XPUStream stream) {
627*da0073e9SAndroid Build Coastguard Worker     if (size == 0) {
628*da0073e9SAndroid Build Coastguard Worker       return nullptr;
629*da0073e9SAndroid Build Coastguard Worker     }
630*da0073e9SAndroid Build Coastguard Worker     auto device = c10::xpu::current_device();
631*da0073e9SAndroid Build Coastguard Worker     void* r = nullptr;
632*da0073e9SAndroid Build Coastguard Worker     malloc(&r, device, size, stream);
633*da0073e9SAndroid Build Coastguard Worker     return r;
634*da0073e9SAndroid Build Coastguard Worker   }
635*da0073e9SAndroid Build Coastguard Worker 
raw_delete(void * ptr)636*da0073e9SAndroid Build Coastguard Worker   void raw_delete(void* ptr) {
637*da0073e9SAndroid Build Coastguard Worker     this->free(ptr);
638*da0073e9SAndroid Build Coastguard Worker   }
639*da0073e9SAndroid Build Coastguard Worker 
copy_data(void * dest,const void * src,std::size_t count) const640*da0073e9SAndroid Build Coastguard Worker   void copy_data(void* dest, const void* src, std::size_t count) const final {
641*da0073e9SAndroid Build Coastguard Worker     xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
642*da0073e9SAndroid Build Coastguard Worker   }
643*da0073e9SAndroid Build Coastguard Worker 
assertValidDevice(DeviceIndex device)644*da0073e9SAndroid Build Coastguard Worker   void assertValidDevice(DeviceIndex device) {
645*da0073e9SAndroid Build Coastguard Worker     const auto device_num = device_allocators.size();
646*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(
647*da0073e9SAndroid Build Coastguard Worker         0 <= device && device < static_cast<int64_t>(device_num),
648*da0073e9SAndroid Build Coastguard Worker         "Invalid device argument ",
649*da0073e9SAndroid Build Coastguard Worker         device,
650*da0073e9SAndroid Build Coastguard Worker         ": did you call init?");
651*da0073e9SAndroid Build Coastguard Worker   }
652*da0073e9SAndroid Build Coastguard Worker 
getDeviceStats(DeviceIndex device)653*da0073e9SAndroid Build Coastguard Worker   DeviceStats getDeviceStats(DeviceIndex device) {
654*da0073e9SAndroid Build Coastguard Worker     assertValidDevice(device);
655*da0073e9SAndroid Build Coastguard Worker     return device_allocators[device]->getStats();
656*da0073e9SAndroid Build Coastguard Worker   }
657*da0073e9SAndroid Build Coastguard Worker 
resetPeakStats(DeviceIndex device)658*da0073e9SAndroid Build Coastguard Worker   void resetPeakStats(DeviceIndex device) {
659*da0073e9SAndroid Build Coastguard Worker     assertValidDevice(device);
660*da0073e9SAndroid Build Coastguard Worker     device_allocators[device]->resetPeakStats();
661*da0073e9SAndroid Build Coastguard Worker   }
662*da0073e9SAndroid Build Coastguard Worker 
resetAccumulatedStats(DeviceIndex device)663*da0073e9SAndroid Build Coastguard Worker   void resetAccumulatedStats(DeviceIndex device) {
664*da0073e9SAndroid Build Coastguard Worker     assertValidDevice(device);
665*da0073e9SAndroid Build Coastguard Worker     device_allocators[device]->resetAccumulatedStats();
666*da0073e9SAndroid Build Coastguard Worker   }
667*da0073e9SAndroid Build Coastguard Worker };
668*da0073e9SAndroid Build Coastguard Worker 
669*da0073e9SAndroid Build Coastguard Worker static XPUAllocator allocator;
670*da0073e9SAndroid Build Coastguard Worker 
local_raw_delete(void * ptr)671*da0073e9SAndroid Build Coastguard Worker void local_raw_delete(void* ptr) {
672*da0073e9SAndroid Build Coastguard Worker   allocator.free(ptr);
673*da0073e9SAndroid Build Coastguard Worker }
674*da0073e9SAndroid Build Coastguard Worker 
get()675*da0073e9SAndroid Build Coastguard Worker Allocator* get() {
676*da0073e9SAndroid Build Coastguard Worker   return &allocator;
677*da0073e9SAndroid Build Coastguard Worker }
678*da0073e9SAndroid Build Coastguard Worker 
init(DeviceIndex device_count)679*da0073e9SAndroid Build Coastguard Worker void init(DeviceIndex device_count) {
680*da0073e9SAndroid Build Coastguard Worker   return allocator.init(device_count);
681*da0073e9SAndroid Build Coastguard Worker }
682*da0073e9SAndroid Build Coastguard Worker 
emptyCache()683*da0073e9SAndroid Build Coastguard Worker void emptyCache() {
684*da0073e9SAndroid Build Coastguard Worker   return allocator.emptyCache();
685*da0073e9SAndroid Build Coastguard Worker }
686*da0073e9SAndroid Build Coastguard Worker 
resetPeakStats(DeviceIndex device)687*da0073e9SAndroid Build Coastguard Worker void resetPeakStats(DeviceIndex device) {
688*da0073e9SAndroid Build Coastguard Worker   return allocator.resetPeakStats(device);
689*da0073e9SAndroid Build Coastguard Worker }
690*da0073e9SAndroid Build Coastguard Worker 
resetAccumulatedStats(DeviceIndex device)691*da0073e9SAndroid Build Coastguard Worker void resetAccumulatedStats(DeviceIndex device) {
692*da0073e9SAndroid Build Coastguard Worker   return allocator.resetAccumulatedStats(device);
693*da0073e9SAndroid Build Coastguard Worker }
694*da0073e9SAndroid Build Coastguard Worker 
getDeviceStats(DeviceIndex device)695*da0073e9SAndroid Build Coastguard Worker DeviceStats getDeviceStats(DeviceIndex device) {
696*da0073e9SAndroid Build Coastguard Worker   return allocator.getDeviceStats(device);
697*da0073e9SAndroid Build Coastguard Worker }
698*da0073e9SAndroid Build Coastguard Worker 
raw_alloc(size_t size)699*da0073e9SAndroid Build Coastguard Worker void* raw_alloc(size_t size) {
700*da0073e9SAndroid Build Coastguard Worker   return allocator.raw_alloc(size);
701*da0073e9SAndroid Build Coastguard Worker }
702*da0073e9SAndroid Build Coastguard Worker 
raw_delete(void * ptr)703*da0073e9SAndroid Build Coastguard Worker void raw_delete(void* ptr) {
704*da0073e9SAndroid Build Coastguard Worker   return allocator.raw_delete(ptr);
705*da0073e9SAndroid Build Coastguard Worker }
706*da0073e9SAndroid Build Coastguard Worker 
recordStream(const DataPtr & dataPtr,XPUStream stream)707*da0073e9SAndroid Build Coastguard Worker void recordStream(const DataPtr& dataPtr, XPUStream stream) {
708*da0073e9SAndroid Build Coastguard Worker   return allocator.recordStream(dataPtr, stream);
709*da0073e9SAndroid Build Coastguard Worker }
710*da0073e9SAndroid Build Coastguard Worker 
711*da0073e9SAndroid Build Coastguard Worker REGISTER_ALLOCATOR(kXPU, &allocator)
712*da0073e9SAndroid Build Coastguard Worker 
713*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu::XPUCachingAllocator
714