// aten/src/ATen/cuda/detail/TensorInfo.cuh
#pragma once

#include <ATen/CollapseDims.h>
#include <c10/util/Exception.h> // TORCH_CHECK
#include <utility>              // std::get on std::pair

namespace at::cuda::detail {

#define MAX_TENSORINFO_DIMS 25

// CUDA kernel argument that defines tensor layout
template <typename T, typename IndexType>
struct TensorInfo {
  TensorInfo();
  TensorInfo(T* p,
             int dim,
             IndexType sz[MAX_TENSORINFO_DIMS],
             IndexType st[MAX_TENSORINFO_DIMS]);

  // Set the size of the given dimension to 1, as if it were a
  // reduction dim (allows you to calculate offsets of the reduction
  // slice)
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // Contiguous tensors of more than one dimension are collapsed down
  // to one dimension
  __host__ __device__ inline bool isContiguous() const {
    return (dims == 1 && strides[0] == 1);
  }

  T* data;
  IndexType sizes[MAX_TENSORINFO_DIMS];
  IndexType strides[MAX_TENSORINFO_DIMS];
  int dims;
};
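
// Example (illustrative sketch, not part of the upstream header): building a
// TensorInfo for a hypothetical contiguous 2 x 3 float tensor. The pointer,
// sizes, and strides here are assumptions for illustration only.
//
//   float* p = /* device pointer */;
//   uint32_t sz[MAX_TENSORINFO_DIMS] = {2, 3};
//   uint32_t st[MAX_TENSORINFO_DIMS] = {3, 1};
//   TensorInfo<float, uint32_t> info(p, 2, sz, st);
//   // isContiguous() is still false (dims == 2); collapseDims() merges the
//   // two contiguous dims into one, after which it returns true.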

template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo() {
  data = nullptr;
  dims = 0;
}

template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
                                     int dim,
                                     IndexType sz[MAX_TENSORINFO_DIMS],
                                     IndexType st[MAX_TENSORINFO_DIMS]) {
  data = p;
  dims = dim;
  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors must have fewer than ", MAX_TENSORINFO_DIMS, " dimensions");

  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}

template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
  TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
  sizes[dim] = 1;
}
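
// Example (illustrative, not upstream code): for a 4 x 5 tensor with sizes
// {4, 5} and strides {5, 1}, reducing over dim 1:
//
//   info.reduceDim(1);  // sizes become {4, 1}; strides stay {5, 1}
//
// Linear ids 0..3 now map through IndexToOffset to offsets 0, 5, 10, 15,
// i.e. the first element of each row's reduction slice.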

template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
  auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
  dims = std::get<1>(result);
  return std::get<0>(result);
}
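
// Example (illustrative, not upstream code): a contiguous 2 x 3 x 4 tensor
// has sizes {2, 3, 4} and strides {12, 4, 1}; each stride equals the product
// of the sizes to its right, so all three dims merge into one:
//
//   info.collapseDims();
//   // info.dims == 1, info.sizes[0] == 24, info.strides[0] == 1
//
// With an excludeDim argument, that dim is kept separate and the return
// value is its index after collapsing (see note on [collapse dims]).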

// Translate a linear index for the apply utilities to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
  static __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    // Uses static dims
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
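
// Worked example (illustrative, not upstream code): take a transposed view
// with sizes {3, 4} and strides {1, 3}. For linearId = 7, the loop peels the
// innermost dim first:
//
//   i = 1: 7 % 4 = 3, offset += 3 * strides[1] = 9, linearId = 7 / 4 = 1
//   i = 0: return 9 + 1 * strides[0] = 10
//
// so element (1, 3) of the view lives at data offset 10.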

// Uses dynamic (runtime) instead of static (compile-time) dims
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
  static inline __host__ __device__ IndexType get(
    IndexType linearId,
    const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
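
// Usage sketch (illustrative, not upstream code): elementwise kernels
// typically pair a grid-stride loop with IndexToOffset. A positive Dims
// selects the primary template, whose fixed trip count nvcc can unroll;
// Dims == -1 selects the dynamic specialization above. The kernel below is a
// hypothetical example.
//
//   template <typename T, typename IndexType, int Dims>
//   __global__ void fillKernel(TensorInfo<T, IndexType> info,
//                              IndexType totalElements, T value) {
//     for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
//          linearId < totalElements;
//          linearId += gridDim.x * blockDim.x) {
//       IndexType offset =
//           IndexToOffset<T, IndexType, Dims>::get(linearId, info);
//       info.data[offset] = value;
//     }
//   }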

} // namespace at::cuda::detail