xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUCachingAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/CachingDeviceAllocator.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUStream.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu::XPUCachingAllocator {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker C10_XPU_API Allocator* get();
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void init(DeviceIndex device_count);
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void emptyCache();
13*da0073e9SAndroid Build Coastguard Worker 
14*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void resetPeakStats(DeviceIndex device);
15*da0073e9SAndroid Build Coastguard Worker 
16*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void resetAccumulatedStats(DeviceIndex device);
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
19*da0073e9SAndroid Build Coastguard Worker     DeviceIndex device);
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void* raw_alloc(size_t size);
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void raw_delete(void* ptr);
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu::XPUCachingAllocator
28