#include <c10/core/impl/GPUTrace.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>

#include <array>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <set>
#include <vector>

namespace c10::xpu::XPUCachingAllocator {

using namespace c10::CachingDeviceAllocator;

// newly allocated memory with 512-byte alignment.
constexpr size_t kDeviceAlignment = 512;
// all sizes are rounded to at least 512 bytes
constexpr size_t kMinBlockSize = 512;
// largest "small" allocation is 1 MiB
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
constexpr size_t kRoundLarge = 2097152;

namespace {

using stream_set = ska::flat_hash_set<xpu::XPUStream>;

struct Block;
typedef bool (*Comparison)(const Block*, const Block*);
bool BlockComparatorSize(const Block* a, const Block* b);

struct BlockPool {
  BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
  std::set<Block*, Comparison> blocks;
  const bool is_small;
};

struct Block {
  DeviceIndex device;
  sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
  stream_set stream_uses; // streams on which the block was used
  size_t size; // block size in bytes
  size_t requested_size; // memory originally requested
  BlockPool* pool{nullptr}; // owning memory pool
  void* ptr{nullptr}; // memory address
  bool allocated{false}; // in-use flag
  Block* prev{nullptr}; // prev block if split from a larger allocation
  Block* next{nullptr}; // next block if split from a larger allocation
  int event_count{0}; // number of outstanding XPU events

  Block(
      DeviceIndex device,
      sycl::queue* queue,
      size_t size,
      BlockPool* pool,
      void* ptr)
      : device(device),
        queue(queue),
        stream_uses(),
        size(size),
        requested_size(0),
        pool(pool),
        ptr(ptr) {}

  // constructor for search key
  Block(DeviceIndex device, sycl::queue* queue, size_t size)
      : device(device),
        queue(queue),
        stream_uses(),
        size(size),
        requested_size(0) {}

  bool is_split() const {
    return (prev != nullptr) || (next != nullptr);
  }
};

bool BlockComparatorSize(const Block* a, const Block* b) {
  if (a->queue != b->queue) {
    return reinterpret_cast<uintptr_t>(a->queue) <
        reinterpret_cast<uintptr_t>(b->queue);
  }
  if (a->size != b->size) {
    return a->size < b->size;
  }
  return reinterpret_cast<uintptr_t>(a->ptr) <
      reinterpret_cast<uintptr_t>(b->ptr);
}

struct AllocParams {
  AllocParams(
      DeviceIndex device,
      size_t size,
      sycl::queue* queue,
      BlockPool* pool,
      size_t alloc_size)
      : search_key(device, queue, size),
        pool(pool),
        alloc_size(alloc_size),
        block(nullptr) {}

  DeviceIndex device() const {
    return search_key.device;
  }

  sycl::queue* queue() const {
    return search_key.queue;
  }

  size_t size() const {
    return search_key.size;
  }

  Block search_key;
  BlockPool* pool;
  size_t alloc_size;
  Block* block;
  StatTypes stat_types = {};
};

} // anonymous namespace

class DeviceCachingAllocator {
 private:
  mutable std::recursive_mutex mutex;
  DeviceStats stats;
  BlockPool large_blocks; // unallocated cached blocks larger than 1 MB
  BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller
  ska::flat_hash_set<Block*> active_blocks; // allocated or in use by a stream
  ska::flat_hash_map<
      xpu::XPUStream,
      std::deque<std::pair<sycl::event, Block*>>>
      xpu_events;
  DeviceIndex device_index;

  // Try to merge the free neighbor `src` into `dst`; returns the number of
  // bytes subsumed, or 0 if `src` cannot be merged (still allocated, has
  // outstanding events, or is in use by other streams).
  size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
    if (!src || src->allocated || src->event_count > 0 ||
        !src->stream_uses.empty()) {
      return 0;
    }
    TORCH_INTERNAL_ASSERT(dst->is_split() && src->is_split());
    if (dst->prev == src) { // [src dst]
      dst->ptr = src->ptr;
      dst->prev = src->prev;
      if (dst->prev) {
        dst->prev->next = dst;
      }
    } else { // [dst src]
      dst->next = src->next;
      if (dst->next) {
        dst->next->prev = dst;
      }
    }
    const size_t subsumed_size = src->size;
    dst->size += subsumed_size;
    auto erased = pool.blocks.erase(src);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
    delete src;
    return subsumed_size;
  }

  void free_block(Block* block) {
    TORCH_INTERNAL_ASSERT(
        !block->allocated && block->event_count == 0 &&
        block->stream_uses.empty());
    size_t original_block_size = block->size;
    size_t requested_size = block->requested_size;
    auto& pool = *block->pool;
    const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
    for (Block* merge_candidate : merge_candidates) {
      try_merge_blocks(block, merge_candidate, pool);
    }
    active_blocks.erase(block);
    bool inserted = pool.blocks.insert(block).second;
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
    StatTypes stat_types = get_stat_types_for_pool(pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.active_bytes[stat_type].decrease(original_block_size);
      stats.requested_bytes[stat_type].decrease(requested_size);
    });
  }

  // Poll outstanding events and return blocks whose events have all completed
  // back to their pools.
  void process_events() {
    using namespace sycl::info;
    for (auto it = xpu_events.begin(); it != xpu_events.end();) {
      while (!it->second.empty()) {
        auto& e = it->second.front();
        auto event = e.first;
        auto* block = e.second;
        if (event.get_info<event::command_execution_status>() !=
            event_command_status::complete) {
          break;
        }
        block->event_count--;
        if (block->event_count == 0) {
          free_block(block);
        }
        it->second.pop_front();
      }
      if (it->second.empty()) {
        it = xpu_events.erase(it);
      } else {
        it++;
      }
    }
  }

  static size_t round_size(size_t size) {
    if (size < kMinBlockSize) {
      return kMinBlockSize;
    } else {
      return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
    }
  }

  static size_t get_allocation_size(size_t size) {
    if (size <= kSmallSize) {
      return kSmallBuffer;
    } else if (size < kMinLargeAlloc) {
      return kLargeBuffer;
    } else {
      return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
    }
  }

  BlockPool& get_pool(size_t size) {
    if (size < kSmallSize) {
      return small_blocks;
    } else {
      return large_blocks;
    }
  }

  bool get_free_block(AllocParams& p) {
    BlockPool& pool = *p.pool;
    auto it = pool.blocks.lower_bound(&p.search_key);
    if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
      return false;
    }
    p.block = *it;
    pool.blocks.erase(it);
    return true;
  }

  bool alloc_block(AllocParams& p) {
    auto size = p.alloc_size;
    auto device = p.device();
    void* ptr = sycl::aligned_alloc_device(
        kDeviceAlignment,
        size,
        xpu::get_raw_device(device),
        xpu::get_device_context());
    if (!ptr) {
      return false;
    }
    p.block = new Block(device, p.queue(), size, p.pool, ptr);
    for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(size);
    });
    return true;
  }

  void synchronize_and_free_events() {
    for (auto& xe : xpu_events) {
      for (auto& e : xe.second) {
        auto event = e.first;
        auto* block = e.second;
        event.wait();
        block->event_count--;
        if (block->event_count == 0) {
          free_block(block);
        }
      }
    }
    xpu_events.clear();
  }

  void release_block(Block* block) {
    /*
     * Note [Safe to Free Blocks on BlockPool]
     *
     * Callers must ensure that all accesses to the block, whose raw pointer
     * was allocated by SYCL APIs, have completed before invoking sycl::free.
     *
     * We have to do a device-level synchronization before freeing these
     * blocks to guarantee that all kernels that can access these blocks have
     * finished.
     */
    sycl::free(block->ptr, xpu::get_device_context());
    auto* pool = block->pool;
    pool->blocks.erase(block);
    StatTypes stat_types = get_stat_types_for_pool(*pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].decrease(block->size);
    });
    delete block;
  }

  void release_blocks(BlockPool& pool) {
    auto it = pool.blocks.begin();
    while (it != pool.blocks.end()) {
      Block* block = *it;
      ++it;
      if (!block->prev && !block->next) {
        release_block(block);
      }
    }
  }

  bool release_cached_blocks() {
    synchronize_and_free_events();
    // See Note [Safe to Free Blocks on BlockPool]
    c10::xpu::syncStreamsOnDevice(device_index);
    release_blocks(large_blocks);
    release_blocks(small_blocks);
    return true;
  }

  bool should_split(const Block* block, size_t size) {
    size_t remaining = block->size - size;
    if (block->pool->is_small) {
      return remaining >= kMinBlockSize;
    } else {
      return remaining > kSmallSize;
    }
  }

  StatTypes get_stat_types_for_pool(const BlockPool& pool) {
    StatTypes stat_types = {};
    stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
    stat_types[static_cast<size_t>(
        pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
    return stat_types;
  }

  Block* alloc_found_block(
      AllocParams params,
      size_t orig_size,
      bool split_remainder) {
    auto size = params.size();
    auto device = params.device();
    BlockPool* pool = params.pool;
    sycl::queue* queue = params.queue();
    TORCH_INTERNAL_ASSERT(
        params.block != nullptr && params.block->ptr != nullptr);

    Block* block = params.block;
    Block* remaining = nullptr;
    if (split_remainder) {
      remaining = block;
      block = new Block(device, queue, size, pool, block->ptr);
      block->prev = remaining->prev;
      if (block->prev) {
        block->prev->next = block;
      }
      block->next = remaining;
      remaining->prev = block;
      remaining->ptr = static_cast<char*>(remaining->ptr) + size;
      remaining->size -= size;
      bool inserted = pool->blocks.insert(remaining).second;
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
    }

    block->allocated = true;
    block->requested_size = orig_size;
    bool inserted = active_blocks.insert(block).second;
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);

    for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
      stats.allocated_bytes[stat_type].increase(block->size);
      stats.active_bytes[stat_type].increase(block->size);
      stats.requested_bytes[stat_type].increase(block->requested_size);
    });

    return block;
  }

  // Record a barrier event on every stream that used the block; the block is
  // only returned to its pool once all recorded events have completed.
  void insert_events(Block* block) {
    stream_set streams(std::move(block->stream_uses));
    TORCH_INTERNAL_ASSERT(block->stream_uses.empty());
    for (auto& stream : streams) {
      block->event_count++;
      xpu_events[stream].emplace_back(
          stream.queue().ext_oneapi_submit_barrier(), block);
    }
  }

 public:
  DeviceCachingAllocator(DeviceIndex device_index)
      : large_blocks(/* small */ false),
        small_blocks(/* small */ true),
        device_index(device_index) {}

  Block* malloc(DeviceIndex device, size_t orig_size, sycl::queue& queue) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    process_events();
    size_t size = round_size(orig_size);
    auto& pool = get_pool(size);
    const size_t alloc_size = get_allocation_size(size);
    AllocParams params(device, size, &queue, &pool, alloc_size);
    params.stat_types = get_stat_types_for_pool(pool);

    // First, try to get a block from the existing pool.
    bool block_found = get_free_block(params);
    // Can't reuse an existing block, try to get a new one.
    if (!block_found) {
      block_found = alloc_block(params) ||
          (release_cached_blocks() && alloc_block(params));
    }
    if (!block_found) {
      c10::xpu::DeviceProp device_prop;
      c10::xpu::get_device_properties(&device_prop, device);
      auto device_total = device_prop.global_mem_size;
      auto allocated_bytes =
          stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      auto reserved_bytes =
          stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
              .current;
      TORCH_CHECK_WITH(
          OutOfMemoryError,
          false,
          "XPU out of memory. Tried to allocate ",
          format_size(alloc_size),
          ". GPU ",
          static_cast<int>(device),
          " has a total capacity of ",
          format_size(device_total),
          ". Of the allocated memory ",
          format_size(allocated_bytes),
          " is allocated by PyTorch, and ",
          format_size(reserved_bytes - allocated_bytes),
          " is reserved by PyTorch but unallocated.",
          " Please use `empty_cache` to release all unoccupied cached memory.");
    }
    bool split_remainder = should_split(params.block, params.size());
    return alloc_found_block(std::move(params), orig_size, split_remainder);
  }

  void free(Block* block) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    block->allocated = false;
    StatTypes stat_types = get_stat_types_for_pool(*block->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.allocated_bytes[stat_type].decrease(block->size);
    });
    if (!block->stream_uses.empty()) {
      insert_events(block);
    } else {
      free_block(block);
    }
  }

  void recordStream(Block* block, xpu::XPUStream stream) {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    if (stream.queue() == *block->queue) {
      return;
    }
    block->stream_uses.insert(stream);
  }

  void emptyCache() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    release_cached_blocks();
  }

  DeviceStats getStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    return stats;
  }

  void resetAccumulatedStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    for (const auto statType :
         c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
      stats.allocated_bytes[statType].reset_accumulated();
      stats.reserved_bytes[statType].reset_accumulated();
      stats.active_bytes[statType].reset_accumulated();
      stats.requested_bytes[statType].reset_accumulated();
    }
  }

  void resetPeakStats() {
    std::scoped_lock<std::recursive_mutex> lock(mutex);
    for (const auto statType :
         c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
      stats.allocated_bytes[statType].reset_peak();
      stats.reserved_bytes[statType].reset_peak();
      stats.active_bytes[statType].reset_peak();
      stats.requested_bytes[statType].reset_peak();
    }
  }
};

void local_raw_delete(void* ptr);

class XPUAllocator : public Allocator {
 private:
  std::mutex mutex;
  ska::flat_hash_map<void*, Block*> allocated_blocks;

  void add_allocated_block(Block* block) {
    std::lock_guard<std::mutex> lock(mutex);
    allocated_blocks[block->ptr] = block;
  }

  Block* get_allocated_block(void* ptr, bool remove = false) {
    std::scoped_lock<std::mutex> lock(mutex);
    auto it = allocated_blocks.find(ptr);
    if (it == allocated_blocks.end()) {
      return nullptr;
    }
    Block* block = it->second;
    if (remove) {
      allocated_blocks.erase(it);
    }
    return block;
  }

 public:
  std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;

  void init(DeviceIndex device_count) {
    const auto size = static_cast<DeviceIndex>(device_allocators.size());
    if (size < device_count) {
      device_allocators.resize(device_count);
      for (const auto i : c10::irange(size, device_count)) {
        device_allocators[i] = std::make_unique<DeviceCachingAllocator>(i);
      }
    }
  }

  void malloc(
      void** devPtr,
      DeviceIndex device,
      size_t size,
      sycl::queue& queue) {
    TORCH_INTERNAL_ASSERT(
        0 <= device && static_cast<size_t>(device) < device_allocators.size(),
        "Allocator not initialized for device ",
        static_cast<int>(device),
        ": did you call init?");
    Block* block = device_allocators[device]->malloc(device, size, queue);
    add_allocated_block(block);
    *devPtr = block->ptr;
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_memory_allocation(
          c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr));
    }
  }

  void free(void* ptr) {
    if (!ptr) {
      return;
    }
    Block* block = get_allocated_block(ptr, /* remove */ true);
    TORCH_CHECK(block, "invalid device pointer: ", ptr);
    device_allocators[block->device]->free(block);
    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
    if (C10_UNLIKELY(interp)) {
      (*interp)->trace_gpu_memory_deallocation(
          c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr));
    }
  }

  void emptyCache() {
    for (auto& da : device_allocators) {
      da->emptyCache();
    }
  }

  void recordStream(const DataPtr& ptr, XPUStream stream) {
    if (!ptr.get()) {
      return;
    }
    if (ptr.get_deleter() != &local_raw_delete) {
      return;
    }
    Block* block = get_allocated_block(ptr.get());
    TORCH_CHECK(block, "No allocated block can be found.");
    device_allocators[block->device]->recordStream(block, stream);
  }

  DataPtr allocate(size_t size) override {
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    if (size != 0) {
      this->malloc(&r, device, size, xpu::getCurrentXPUStream(device));
    }
    return {r, r, &local_raw_delete, Device(DeviceType::XPU, device)};
  }

  DeleterFnPtr raw_deleter() const override {
    return &local_raw_delete;
  }

  void* raw_alloc(size_t size) {
    if (size == 0) {
      return nullptr;
    }
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    malloc(&r, device, size, xpu::getCurrentXPUStream(device));
    return r;
  }

  void* raw_alloc_with_stream(size_t size, XPUStream stream) {
    if (size == 0) {
      return nullptr;
    }
    auto device = c10::xpu::current_device();
    void* r = nullptr;
    malloc(&r, device, size, stream);
    return r;
  }

  void raw_delete(void* ptr) {
    this->free(ptr);
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
  }

  void assertValidDevice(DeviceIndex device) {
    const auto device_num = device_allocators.size();
    TORCH_CHECK(
        0 <= device && device < static_cast<int64_t>(device_num),
        "Invalid device argument ",
        device,
        ": did you call init?");
  }

  DeviceStats getDeviceStats(DeviceIndex device) {
    assertValidDevice(device);
    return device_allocators[device]->getStats();
  }

  void resetPeakStats(DeviceIndex device) {
    assertValidDevice(device);
    device_allocators[device]->resetPeakStats();
  }

  void resetAccumulatedStats(DeviceIndex device) {
    assertValidDevice(device);
    device_allocators[device]->resetAccumulatedStats();
  }
};

static XPUAllocator allocator;

void local_raw_delete(void* ptr) {
  allocator.free(ptr);
}

Allocator* get() {
  return &allocator;
}

void init(DeviceIndex device_count) {
  return allocator.init(device_count);
}

void emptyCache() {
  return allocator.emptyCache();
}

void resetPeakStats(DeviceIndex device) {
  return allocator.resetPeakStats(device);
}

void resetAccumulatedStats(DeviceIndex device) {
  return allocator.resetAccumulatedStats(device);
}

DeviceStats getDeviceStats(DeviceIndex device) {
  return allocator.getDeviceStats(device);
}

void* raw_alloc(size_t size) {
  return allocator.raw_alloc(size);
}

void raw_delete(void* ptr) {
  return allocator.raw_delete(ptr);
}

void recordStream(const DataPtr& dataPtr, XPUStream stream) {
  return allocator.recordStream(dataPtr, stream);
}

REGISTER_ALLOCATOR(kXPU, &allocator)

} // namespace c10::xpu::XPUCachingAllocator