#pragma once
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/NumericUtils.h>
#include <c10/macros/Macros.h>
#include <stdlib.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

namespace at {
namespace native {

// Is this questionable namespace pollution?
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;

#else
constexpr int MAX_BLOCK_SIZE = 1024;
#endif

// Maximum size per grid dimension that we assume (compute capability >= 2.0)
constexpr int64_t MAX_GRID_SIZE = 65535LL;

inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
    return false;
  }

  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
  int64_t gridY = 1;
  int64_t gridZ = 1;

  if (gridTiles > MAX_GRID_SIZE) {
    gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;

    if (gridTiles > MAX_GRID_SIZE) {
      gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
    }
  }

  grid = dim3(gridX, gridY, gridZ);
  return true;
}
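// For example, getGridFromTiles(100000, grid) returns true and sets grid to
// (65535, 2, 1): the x dimension is capped at MAX_GRID_SIZE and the remaining
// tiles spill over into the y dimension.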

template <typename scalar_t, bool handleNaN = false>
struct GTOp {
  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
    return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs);
  }
};

template <typename scalar_t, bool handleNaN = false>
struct LTOp {
  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
    return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs);
  }
};
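// With handleNaN = true, both comparators treat NaN as the largest value:
// GTOp orders NaN before any non-NaN element and LTOp orders any non-NaN
// element before NaN, so NaNs end up grouped at the end of an ascending sort
// (and at the front of a descending one).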

template <typename index_t>
__device__ __forceinline__ index_t getLinearBlockId() {
  return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
      blockIdx.x;
}
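// Flattens the 3D block coordinate produced by getGridFromTiles back into a
// single linear id, e.g. with gridDim = (65535, 2, 1), block (10, 1, 0) maps
// to 1 * 65535 + 10 = 65545.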

// For slice sorting in Thrust; extracts a slice index from a linear
// index and uses that for comparison
struct SliceComp {
  SliceComp(int64_t size) : sliceSize(size) {}

  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
    // Since the slices are guaranteed to be innermost, the segment index
    // is just the linear index divided (as int64_t) by the slice size
    int64_t segA = a / sliceSize;
    int64_t segB = b / sliceSize;
    return segA < segB;
  }

  const int64_t sliceSize;
};
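// For example, with sliceSize = 4, linear indices 0-3 belong to slice 0 and
// 4-7 to slice 1, so SliceComp(4)(2, 5) is true while SliceComp(4)(5, 2) and
// SliceComp(4)(1, 3) are both false.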

// For sorting in Thrust; extracts a within-slice index from a linear index
struct GlobalIndexToPerSliceIndex {
  GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}

  __device__ inline void operator()(int64_t& v) const {
    v = v % sliceSize;
  }

  const int64_t sliceSize;
};
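// For example, GlobalIndexToPerSliceIndex(4) rewrites the linear index 6 to 2,
// its offset within slice 1.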

// Returns 2^(ceil(lg(n))) from Stanford bit twiddling hacks
inline uint64_t nextHighestPowerOf2(uint64_t n) {
  n--;
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
#ifndef _MSC_VER
  n |= n >> 32;
#endif
  n++;

  return n;
}
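// For example, nextHighestPowerOf2(5) == 8 and nextHighestPowerOf2(8) == 8:
// the shifts smear every bit below the highest set bit of n - 1, and the final
// increment rounds up to the next power of two.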


// WARNING: This function assumes input tensors are contiguous
template <typename scalar_t, typename index_t, typename Launcher>
void run_launcher(
    const TensorBase &values,
    const TensorBase &indices,
    const TensorBase &self,
    int64_t dim,
    Launcher l) {
  auto self_info = cuda::detail::getTensorInfo<const scalar_t, index_t>(self);
  auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
  auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);

  int64_t slice_size = self.size(dim);
  /* We use these structures solely to find the offset to */
  /* each slice we are operating on */
  self_info.reduceDim(dim);
  values_info.reduceDim(dim);
  indices_info.reduceDim(dim);

  /* Collapse all other dims */
  int collapse_self_dim = self_info.collapseDims(dim);
  int collapse_values_dim = values_info.collapseDims(dim);
  int collapse_indices_dim = indices_info.collapseDims(dim);

  int64_t num_slices = 1;
  for (int i = 0; i < self_info.dims; ++i) {
    num_slices *= self_info.sizes[i];
  }

  /* This is used as a template parameter to calculate indices. */
  /* We only specialize it if all tensors collapse to the same */
  /* number of dims; otherwise, we use -1, which is the */
  /* specialization parameter for arbitrary dimensions */
  int all_dims = self_info.dims;
  if (values_info.dims != all_dims || indices_info.dims != all_dims) {
    all_dims = -1;
  }

  if (all_dims == 1) {
    l.template launch<scalar_t, index_t, 1>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else if (all_dims == 2) {
    l.template launch<scalar_t, index_t, 2>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else if (all_dims == 3) {
    l.template launch<scalar_t, index_t, 3>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else {
    l.template launch<scalar_t, index_t, -1>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  }
}
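// An illustrative sketch of the interface run_launcher expects from its
// Launcher argument (the name SortLauncher and its body are hypothetical;
// real launchers configure and run the actual sorting kernel):
//
//   struct SortLauncher {
//     template <typename scalar_t, typename index_t, int Dims>
//     void launch(
//         cuda::detail::TensorInfo<scalar_t, index_t> values_info,
//         int collapse_values_dim,
//         cuda::detail::TensorInfo<int64_t, index_t> indices_info,
//         int collapse_indices_dim,
//         cuda::detail::TensorInfo<const scalar_t, index_t> self_info,
//         int collapse_self_dim,
//         int64_t num_slices,
//         int64_t slice_size) {
//       // Map num_slices to a grid via getGridFromTiles and launch a kernel
//       // templated on Dims (-1 means arbitrary dimensionality).
//     }
//   };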

} // namespace native
} // namespace at