xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/ThrustAllocator.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstddef>
4 #include <c10/cuda/CUDACachingAllocator.h>
5 
6 namespace at::cuda {
7 
8 /// Allocator for Thrust to re-route its internal device allocations
9 /// to the THC allocator
10 class ThrustAllocator {
11 public:
12   typedef char value_type;
13 
allocate(std::ptrdiff_t size)14   char* allocate(std::ptrdiff_t size) {
15     return static_cast<char*>(c10::cuda::CUDACachingAllocator::raw_alloc(size));
16   }
17 
deallocate(char * p,size_t size)18   void deallocate(char* p, size_t size) {
19     c10::cuda::CUDACachingAllocator::raw_delete(p);
20   }
21 };
22 
23 } // namespace at::cuda
24