1 #include <ATen/xpu/CachingHostAllocator.h>
2
3 namespace at::xpu {
4 namespace {
5
6 constexpr size_t kHostAlignment = 512;
7
8 using Block = HostBlock<XPUStream>;
9
10 struct XPUCachingHostAllocatorImpl
11 : public CachingHostAllocatorImpl<XPUStream, XPUEvent> {
12 /* These following functions are runtime-related. */
allocate_host_memoryat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl13 void allocate_host_memory(size_t size, void** ptr) override {
14 *ptr = sycl::aligned_alloc_host(
15 kHostAlignment, size, c10::xpu::get_device_context());
16 }
17
free_blockat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl18 void free_block(Block* block) override {
19 sycl::free(block->ptr_, c10::xpu::get_device_context());
20 }
21
record_streamat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl22 void record_stream(
23 std::optional<std::vector<XPUEvent>>& events,
24 XPUStream stream) override {
25 XPUEvent event;
26 event.record(stream);
27 events->push_back(std::move(event));
28 }
29
query_eventat::xpu::__anond8fa822d0111::XPUCachingHostAllocatorImpl30 bool query_event(XPUEvent& event) override {
31 return event.query();
32 }
33 };
34
35 void raw_local_deleter(void* ptr);
36
37 struct XPUCachingHostAllocator final
38 : public CachingHostAllocatorInterface<XPUCachingHostAllocatorImpl> {
allocateat::xpu::__anond8fa822d0111::XPUCachingHostAllocator39 at::DataPtr allocate(size_t size) override {
40 auto ptr_and_ctx = impl_->allocate(size);
41 return {
42 ptr_and_ctx.first,
43 ptr_and_ctx.second,
44 &raw_local_deleter,
45 at::DeviceType::CPU};
46 }
47 };
48
49 static XPUCachingHostAllocator caching_host_allocator;
50
getXPUCachingHostAllocator()51 static inline XPUCachingHostAllocator& getXPUCachingHostAllocator() {
52 return caching_host_allocator;
53 }
54
raw_local_deleter(void * ptr)55 void raw_local_deleter(void* ptr) {
56 getXPUCachingHostAllocator().free(ptr);
57 }
58
59 } // anonymous namespace
60
CachingHostAllocator_recordEvent(void * ptr,void * ctx,c10::xpu::XPUStream stream)61 bool CachingHostAllocator_recordEvent(
62 void* ptr,
63 void* ctx,
64 c10::xpu::XPUStream stream) {
65 return getXPUCachingHostAllocator().record_event(ptr, ctx, stream);
66 }
67
CachingHostAllocator_emptyCache()68 void CachingHostAllocator_emptyCache() {
69 getXPUCachingHostAllocator().empty_cache();
70 }
71
getCachingHostAllocator()72 at::Allocator* getCachingHostAllocator() {
73 return &getXPUCachingHostAllocator();
74 }
75
76 } // namespace at::xpu
77