// xref: /aosp_15_r20/external/pytorch/c10/core/DeviceArray.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/Allocator.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
7 namespace c10 {
8 
9 template <typename T>
10 class DeviceArray {
11  public:
DeviceArray(c10::Allocator & allocator,size_t size)12   DeviceArray(c10::Allocator& allocator, size_t size)
13       : data_ptr_(allocator.allocate(size * sizeof(T))) {
14     static_assert(std::is_trivial<T>::value, "T must be a trivial type");
15     TORCH_INTERNAL_ASSERT(
16         0 == (reinterpret_cast<intptr_t>(data_ptr_.get()) % alignof(T)),
17         "c10::DeviceArray: Allocated memory is not aligned for this data type");
18   }
19 
get()20   T* get() {
21     return static_cast<T*>(data_ptr_.get());
22   }
23 
24  private:
25   c10::DataPtr data_ptr_;
26 };
27 
28 } // namespace c10
29