// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/detail/IndexUtils.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/TensorBase.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <type_traits>

7 namespace at::cuda::detail {
8 
9 TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
10 using at::native::canUse32BitIndexMath;
11 
12 template <typename scalar, typename IndexType>
13 TensorInfo<scalar, IndexType>
getTensorInfo(const at::TensorBase & t)14 getTensorInfo(const at::TensorBase &t) {
15   IndexType sz[MAX_TENSORINFO_DIMS];
16   IndexType st[MAX_TENSORINFO_DIMS];
17 
18   int dims = t.dim();
19   for (int i = 0; i < dims; ++i) {
20     sz[i] = t.size(i);
21     st[i] = t.stride(i);
22   }
23 
24   scalar* data_ptr = nullptr;
25 
26   if constexpr (std::is_const<scalar>::value) {
27     data_ptr = t.const_data_ptr<scalar>();
28   } else {
29     data_ptr = t.mutable_data_ptr<scalar>();
30   }
31 
32   return TensorInfo<scalar, IndexType>(
33     data_ptr, dims, sz, st);
34 }
35 
36 } // namespace at::cuda::detail
37