#pragma once

#include <ATen/CollapseDims.h>
#include <c10/util/Exception.h> // for TORCH_CHECK used below

namespace at::cuda::detail {

#define MAX_TENSORINFO_DIMS 25

// CUDA kernel argument that defines tensor layout
template <typename T, typename IndexType>
struct TensorInfo {
  TensorInfo();
  TensorInfo(T* p,
             int dim,
             IndexType sz[MAX_TENSORINFO_DIMS],
             IndexType st[MAX_TENSORINFO_DIMS]);

  // Set the size of the given dimension to 1, as if it were a
  // reduction dim (allows you to calculate offsets of the reduction
  // slice)
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // Contiguous tensors of more than one dimension are collapsed down
  // to one dimension (by collapseDims), so contiguity here means a
  // single dimension with unit stride.
  __host__ __device__ inline bool isContiguous() const {
    return (dims == 1 && strides[0] == 1);
  }

  T* data;
  IndexType sizes[MAX_TENSORINFO_DIMS];
  IndexType strides[MAX_TENSORINFO_DIMS];
  int dims;
};
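
// Illustrative usage sketch (not part of the original header; the variable
// names below are hypothetical): filling a TensorInfo for a 2 x 3 contiguous
// float tensor and collapsing it.
//
//   float* devPtr = /* device pointer to 6 floats */;
//   unsigned int sz[MAX_TENSORINFO_DIMS] = {2, 3};
//   unsigned int st[MAX_TENSORINFO_DIMS] = {3, 1};
//   TensorInfo<float, unsigned int> info(devPtr, 2, sz, st);
//   info.collapseDims();      // contiguous dims merge: dims == 1,
//                             // sizes[0] == 6, strides[0] == 1
//   bool contig = info.isContiguous();  // true after collapsing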

template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo() {
  data = nullptr;
  dims = 0;
}

template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
                                     int dim,
                                     IndexType sz[MAX_TENSORINFO_DIMS],
                                     IndexType st[MAX_TENSORINFO_DIMS]) {
  data = p;
  dims = dim;
  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");

  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}

template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
  TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
  sizes[dim] = 1;
}
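
// Illustrative example: with sizes {4, 5, 6}, reduceDim(1) leaves sizes
// {4, 1, 6} and strides untouched, so subsequent offset computations
// address the 0th slice along the reduced dimension.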

template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
  auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
  dims = std::get<1>(result);
  return std::get<0>(result);
}
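
// Illustrative sketch of the effect (see the note on [collapse dims] for the
// authoritative description): adjacent dimensions that are laid out
// contiguously with respect to each other are merged. For a 4 x 5 tensor with
// strides {5, 1}, collapseDims() would leave dims == 1, sizes[0] == 20,
// strides[0] == 1; the returned value is the remapped index of excludeDim.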

// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
  static __host__ __device__ IndexType get(
      IndexType linearId,
      const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    // Uses static dims
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};

// Uses dynamic (runtime) instead of static (compile-time) dims
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
  static inline __host__ __device__ IndexType get(
      IndexType linearId,
      const TensorInfo<T, IndexType>& info) {

    IndexType offset = 0;

    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
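
// Worked example (illustrative): for a TensorInfo with dims == 2,
// sizes {2, 3}, strides {3, 1}, and linearId == 4:
//   i == 1: curDimIndex = 4 % 3 = 1, offset += 1 * 1 = 1, linearId = 4 / 3 = 1
//   return 1 + 1 * 3 = 4
// i.e. the element at coordinates (1, 1) of the row-major 2 x 3 tensor.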

} // namespace at::cuda::detail