1 #pragma once 2 3 #include <ATen/core/TensorBase.h> 4 #include <ATen/cuda/detail/TensorInfo.cuh> 5 #include <ATen/native/CanUse32BitIndexMath.h> 6 7 namespace at::cuda::detail { 8 9 TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t); 10 using at::native::canUse32BitIndexMath; 11 12 template <typename scalar, typename IndexType> 13 TensorInfo<scalar, IndexType> getTensorInfo(const at::TensorBase & t)14getTensorInfo(const at::TensorBase &t) { 15 IndexType sz[MAX_TENSORINFO_DIMS]; 16 IndexType st[MAX_TENSORINFO_DIMS]; 17 18 int dims = t.dim(); 19 for (int i = 0; i < dims; ++i) { 20 sz[i] = t.size(i); 21 st[i] = t.stride(i); 22 } 23 24 scalar* data_ptr = nullptr; 25 26 if constexpr (std::is_const<scalar>::value) { 27 data_ptr = t.const_data_ptr<scalar>(); 28 } else { 29 data_ptr = t.mutable_data_ptr<scalar>(); 30 } 31 32 return TensorInfo<scalar, IndexType>( 33 data_ptr, dims, sz, st); 34 } 35 36 } // namespace at::cuda::detail 37