xref: /aosp_15_r20/external/pytorch/aten/src/ATen/xpu/CachingHostAllocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/xpu/CachingHostAllocator.h>

#include <new>
2 
3 namespace at::xpu {
4 namespace {
5 
6 constexpr size_t kHostAlignment = 512;
7 
8 using Block = HostBlock<XPUStream>;
9 
10 struct XPUCachingHostAllocatorImpl
11     : public CachingHostAllocatorImpl<XPUStream, XPUEvent> {
12   /* These following functions are runtime-related. */
allocate_host_memoryat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl13   void allocate_host_memory(size_t size, void** ptr) override {
14     *ptr = sycl::aligned_alloc_host(
15         kHostAlignment, size, c10::xpu::get_device_context());
16   }
17 
free_blockat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl18   void free_block(Block* block) override {
19     sycl::free(block->ptr_, c10::xpu::get_device_context());
20   }
21 
record_streamat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl22   void record_stream(
23       std::optional<std::vector<XPUEvent>>& events,
24       XPUStream stream) override {
25     XPUEvent event;
26     event.record(stream);
27     events->push_back(std::move(event));
28   }
29 
query_eventat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl30   bool query_event(XPUEvent& event) override {
31     return event.query();
32   }
33 };
34 
// Deleter attached to every DataPtr built by XPUCachingHostAllocator.
// It is declared ahead of that struct so allocate() can take its address.
// It receives the DataPtr's context pointer, not the data pointer.
void raw_local_deleter(void* ctx);
36 
37 struct XPUCachingHostAllocator final
38     : public CachingHostAllocatorInterface<XPUCachingHostAllocatorImpl> {
allocateat::xpu::__anond8fa822d0111::XPUCachingHostAllocator39   at::DataPtr allocate(size_t size) override {
40     auto ptr_and_ctx = impl_->allocate(size);
41     return {
42         ptr_and_ctx.first,
43         ptr_and_ctx.second,
44         &raw_local_deleter,
45         at::DeviceType::CPU};
46   }
47 };
48 
49 static XPUCachingHostAllocator caching_host_allocator;
50 
getXPUCachingHostAllocator()51 static inline XPUCachingHostAllocator& getXPUCachingHostAllocator() {
52   return caching_host_allocator;
53 }
54 
raw_local_deleter(void * ptr)55 void raw_local_deleter(void* ptr) {
56   getXPUCachingHostAllocator().free(ptr);
57 }
58 
59 } // anonymous namespace
60 
CachingHostAllocator_recordEvent(void * ptr,void * ctx,c10::xpu::XPUStream stream)61 bool CachingHostAllocator_recordEvent(
62     void* ptr,
63     void* ctx,
64     c10::xpu::XPUStream stream) {
65   return getXPUCachingHostAllocator().record_event(ptr, ctx, stream);
66 }
67 
CachingHostAllocator_emptyCache()68 void CachingHostAllocator_emptyCache() {
69   getXPUCachingHostAllocator().empty_cache();
70 }
71 
getCachingHostAllocator()72 at::Allocator* getCachingHostAllocator() {
73   return &getXPUCachingHostAllocator();
74 }
75 
76 } // namespace at::xpu
77