1 #include <c10/core/Allocator.h> 2 #include <c10/util/Exception.h> 3 #include <cstddef> 4 #include <cstdint> 5 #include <type_traits> 6 7 namespace c10 { 8 9 template <typename T> 10 class DeviceArray { 11 public: DeviceArray(c10::Allocator & allocator,size_t size)12 DeviceArray(c10::Allocator& allocator, size_t size) 13 : data_ptr_(allocator.allocate(size * sizeof(T))) { 14 static_assert(std::is_trivial<T>::value, "T must be a trivial type"); 15 TORCH_INTERNAL_ASSERT( 16 0 == (reinterpret_cast<intptr_t>(data_ptr_.get()) % alignof(T)), 17 "c10::DeviceArray: Allocated memory is not aligned for this data type"); 18 } 19 get()20 T* get() { 21 return static_cast<T*>(data_ptr_.get()); 22 } 23 24 private: 25 c10::DataPtr data_ptr_; 26 }; 27 28 } // namespace c10 29