1 #pragma once 2 3 #include <cstddef> 4 #include <c10/cuda/CUDACachingAllocator.h> 5 6 namespace at::cuda { 7 8 /// Allocator for Thrust to re-route its internal device allocations 9 /// to the THC allocator 10 class ThrustAllocator { 11 public: 12 typedef char value_type; 13 allocate(std::ptrdiff_t size)14 char* allocate(std::ptrdiff_t size) { 15 return static_cast<char*>(c10::cuda::CUDACachingAllocator::raw_alloc(size)); 16 } 17 deallocate(char * p,size_t size)18 void deallocate(char* p, size_t size) { 19 c10::cuda::CUDACachingAllocator::raw_delete(p); 20 } 21 }; 22 23 } // namespace at::cuda 24