#pragma once

#include <array>
#include <cstdint>
#include <type_traits>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/IntegerDivider.cuh>

// If element_sizes is nullptr, then the strides will be in bytes, otherwise
// the strides will be in # of elements.
// Operands share the same shape, but may have different strides.
// OffsetCalculator iterates the tensor in a column-major order, i.e. the
// innermost (fastest-varying) dimension comes first; see the worked example
// below.

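// Illustrative example (not part of the implementation; the shape and strides
// below are hypothetical): suppose a single contiguous float operand is seen
// with shape [2, 3] (innermost dimension first, as TensorIterator orders it)
// and byte strides [4, 8]. Then get(4) peels off one dimension at a time:
//   dim 0: divmod(4, 2) -> div = 2, mod = 0, offset += 0 * 4 = 0
//   dim 1: divmod(2, 3) -> div = 0, mod = 2, offset += 2 * 8 = 16
// giving a 16-byte offset, i.e. element 4 of the contiguous buffer. Had
// element_sizes been supplied ({4} for float), the stored strides would be
// [1, 2] and the same call would return an offset of 4 elements.
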
#if defined(USE_ROCM)
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;
#endif

template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
  // We allow having negative strides to implement some operations like torch.flip
  using stride_t = std::conditional_t<signed_strides,
                                      std::make_signed_t<index_t>,
                                      index_t>;
  // The offset for each argument. Wrapper around fixed-size array.
  // On CUDA, zero sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size 1 offset to avoid compiler failure.
  // This size 1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;

  // if element_sizes is nullptr, then the strides will be in bytes, otherwise
  // the strides will be in # of elements.
  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
    for (int i=0; i < dims; i++){
      sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
      for (int arg = 0; arg < NARGS; arg++) {
        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
        strides_[i][arg] = strides[arg][i] / element_size;
      }
    }
  }

  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }

    #pragma unroll
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }

    }
    return offsets;
  }

  int dims;
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};

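// Trivial case: the offset of every argument is simply the linear index itself
// (in # of elements), as when all operands are contiguous and traversed
// elementwise.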
template <int NARGS, typename index_t = uint32_t>
struct TrivialOffsetCalculator {
  // The offset for each argument. Wrapper around fixed-size array.
  // The offsets are in # of elements, not in bytes.
  // On CUDA, zero sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size 1 offset to avoid compiler failure.
  // This size 1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;

  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = linear_idx;
    }
    return offsets;
  }
};

// Make an OffsetCalculator with byte offsets
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
  std::array<const int64_t*, N> strides;
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i).data();
  }
  return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
}
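
// Minimal usage sketch (illustrative only; `iter`, `idx`, `out_base` and
// `in_base` are hypothetical names): the calculator is built on the host from
// a TensorIterator and queried per thread with the thread's linear index; the
// returned byte offsets are added to char* base pointers before casting to
// the element type.
//
//   auto offset_calc = make_offset_calculator<2>(iter);
//   // in the kernel, for linear index `idx`:
//   //   auto offsets = offset_calc.get(idx);
//   //   float* out = reinterpret_cast<float*>(out_base + offsets[0]);
//   //   const float* in = reinterpret_cast<const float*>(in_base + offsets[1]);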

// Make an OffsetCalculator with element offsets
template<int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
    const at::TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
  std::array<const int64_t*, N> strides;
  std::array<int64_t, N> element_sizes;
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i).data();
    element_sizes[i] = iter.element_size(i);
  }
  return OffsetCalculator<N, uint32_t, signed_strides>(
      iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
}
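
// By contrast, an element-offset calculator indexes typed pointers directly.
// Illustrative sketch only; `iter`, `idx`, `out_ptr` and `in_ptr` are
// hypothetical names:
//
//   auto offset_calc = make_element_offset_calculator<2>(iter);
//   // in the kernel, for linear index `idx`:
//   //   auto offsets = offset_calc.get(idx);
//   //   out_ptr[offsets[0]] = in_ptr[offsets[1]];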