#pragma once
#include <c10/macros/Macros.h>

#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/Sort.h>
#include <ATen/native/StridedRandomAccessor.h>

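// cub::WarpMergeSort was added in CUB 1.15 (bundled with CUDA 11.6), hence the
// version gate below.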
#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600)

namespace at { namespace native {

template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T tmp = t1;
  t1 = t2;
  t2 = tmp;
}

template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
                                   K& kB, V& vB, bool& validB,
                                   bool dir,
                                   const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
}

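// Sorts Power2SortSize (key, value, valid) triples held in shared memory using
// a bitonic sorting network. The body contains block-wide __syncthreads(), so
// every thread of the block must reach the call; the indexing assumes
// blockDim.x == Power2SortSize / 2, with each thread comparing two elements
// per step. Entries flagged invalid always sort to the end.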
template <int Power2SortSize, typename IndexType, typename Comparator,
          typename K, typename V>
__device__ inline void bitonicSort(K *keys,
                                   V *values,
                                   bool *valid,
                                   const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {

      __syncthreads();

      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwap<Comparator, K, V>(
        keys[pos], values[pos], valid[pos],
        keys[pos + stride], values[pos + stride], valid[pos + stride],
        flag, comp);
    }
  }

#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {

    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwap<Comparator, K, V>(
      keys[pos], values[pos], valid[pos],
      keys[pos + stride], values[pos + stride], valid[pos + stride],
      false, comp);
  }

  __syncthreads();

}

// at::cuda::detail::TensorInfo version
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
template <int KeyDims, int ValueDims, int block_dim_x, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
__global__ void
bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                     IndexType keySlices,
                     IndexType keySliceSize,
                     IndexType keySliceStride,
                     at::cuda::detail::TensorInfo<V, IndexType> values,
                     IndexType valueSliceStride,
                     Comparator comp) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If the entire block is out of bounds, exit early
  if (blockIndex * blockDim.y >= keySlices) {
    return;
  }
  // It's also possible for some rows of a block to be out of bounds,
  // but all threads need to run for __syncthreads to work.
  const bool row_valid = linearIndex < keySlices;

  constexpr int items_per_thread = 2;
  constexpr int Power2SortSize = block_dim_x * items_per_thread;

  // Storage for max_block_dim_y sorts performed in parallel
  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];

  auto sharedKeys = blockSharedKeys[threadIdx.y];
  auto sharedValues = blockSharedValues[threadIdx.y];
  auto sharedValid = blockSharedValid[threadIdx.y];

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  // Load 2 values per thread into the shared workspace
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    bool valid = row_valid && idx < keySliceSize;

    sharedKeys[idx] = valid ?
        keys.data[idx * keySliceStride + keyStartOffset] : K{};
    sharedValues[idx] = valid ?
        values.data[idx * valueSliceStride + valueStartOffset] : V{};
    sharedValid[idx] = valid;
  }

  // Sort!
  bitonicSort<Power2SortSize, IndexType>(
      sharedKeys, sharedValues, sharedValid, comp);

  if (!row_valid) {
    return;
  }

  // Store outputs
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    if (idx < keySliceSize) {
      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
    }
  }
}
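
// A minimal launch sketch (hypothetical sizes; the real dispatch lives in the
// host-side sort code). Each block sorts up to blockDim.y slices, and each
// slice must satisfy keySliceSize <= 2 * block_dim_x. KeyDims, ValueDims,
// keyInfo, and valueInfo below are placeholders, and the comparator is one of
// the options from SortingCommon.cuh:
//
//   constexpr int block_dim_x = 16;       // Power2SortSize == 32
//   constexpr int max_block_dim_y = 32;
//   dim3 block(block_dim_x, std::min<int>((int)keySlices, max_block_dim_y));
//   dim3 grid(at::ceil_div(keySlices, (IndexType)block.y));  // flattened by getLinearBlockId()
//   bitonicSortKVInPlace<KeyDims, ValueDims, block_dim_x, max_block_dim_y>
//       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, keySlices, keySliceSize, keySliceStride,
//           valueInfo, valueSliceStride, GTOp<K, true>());
//   C10_CUDA_KERNEL_LAUNCH_CHECK();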

#if HAS_WARP_MERGE_SORT()

template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
__global__ void
warpMergeSortKVInPlace(
    at::cuda::detail::TensorInfo<K, IndexType> keys,
    IndexType keySlices,
    IndexType keySliceSize,
    IndexType keySliceStride,
    at::cuda::detail::TensorInfo<V, IndexType> values,
    IndexType valueSliceStride,
    Comparator comp,
    K invalid_key) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If this row is out of bounds, exit early
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

  CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE);
  CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y);
  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
  static_assert(
      items_per_thread * C10_WARP_SIZE == sort_size,
      "sort_size must be a multiple of C10_WARP_SIZE");

  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage[max_block_dim_y];

  auto& warp_storage = tmp_storage[threadIdx.y];

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  const auto invalid_value = V{};
  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  WARP_SYNC();
  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  WARP_SYNC();

  // Sort! We use stable sort to ensure that invalid values are never
  // sorted before valid values. In testing it performed the same as
  // .Sort, so there is no downside.
  Sort(warp_storage.sort).StableSort(
      local_keys, local_values, comp, keySliceSize, invalid_key);
  WARP_SYNC();

  // Store outputs
  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  WARP_SYNC();
  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}
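
// A minimal launch sketch (hypothetical sizes; the real dispatch lives in the
// host-side sort code). One warp sorts one slice, so blockDim.x must equal
// C10_WARP_SIZE and keySliceSize must not exceed sort_size. KeyDims, ValueDims,
// keyInfo, valueInfo, and the comparator / invalid_key pairing below are
// placeholders; invalid_key only has to sort after every real key under comp:
//
//   constexpr int sort_size = 128;            // multiple of C10_WARP_SIZE
//   constexpr int max_block_dim_y = 8;
//   dim3 block(C10_WARP_SIZE, std::min<int>((int)keySlices, max_block_dim_y));
//   dim3 grid(at::ceil_div(keySlices, (IndexType)block.y));
//   warpMergeSortKVInPlace<KeyDims, ValueDims, sort_size, max_block_dim_y>
//       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, keySlices, keySliceSize, keySliceStride,
//           valueInfo, valueSliceStride,
//           LTOp<K, true>(), std::numeric_limits<K>::max());
//   C10_CUDA_KERNEL_LAUNCH_CHECK();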

#endif // HAS_WARP_MERGE_SORT()

template <int KeyDims, int ValueDims,
          int block_size, int items_per_thread,
          typename K, typename V, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_size)
__global__ void
radixSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                   IndexType keySlices,
                   IndexType keySliceSize,
                   IndexType keySliceStride,
                   at::cuda::detail::TensorInfo<V, IndexType> values,
                   IndexType valueSliceStride,
                   bool descending) {
  static_assert(block_size > 0, "");

  // Find the slice of the tensor that we are sorting
  const IndexType linearIndex = getLinearBlockId<IndexType>();
  // Tiling the slices could put us out of bounds if there are a
  // lot of slices to sort
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

  using key_t = typename at::cuda::cub::detail::cuda_type<K>::type;
  using LoadKeys = cub::BlockLoad<K, block_size, items_per_thread,
                                  cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
  using LoadValues = cub::BlockLoad<V, block_size, items_per_thread,
                                    cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
  using Sort = cub::BlockRadixSort<key_t, block_size, items_per_thread, V>;
  using StoreKeys = cub::BlockStore<K, block_size, items_per_thread,
                                    cub::BLOCK_STORE_TRANSPOSE>;
  using StoreValues = cub::BlockStore<V, block_size, items_per_thread,
                                      cub::BLOCK_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage;

  // cub's Block operations operate on a fixed number of items, but the
  // actual slice we are sorting might be smaller. So, we need to make
  // up the difference with padding keys that always sort to the end.
  const K invalid_key = [descending] {
    using radix_t = typename cub::Traits<key_t>::UnsignedBits;
    union {
      K key;
      radix_t radix;
    } tmp;
    tmp.radix = descending ?
        cub::Traits<key_t>::LOWEST_KEY :
        cub::Traits<key_t>::MAX_KEY;
    return tmp.key;
  }();
  const V invalid_value = static_cast<V>(0);

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  __syncthreads();
  LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  __syncthreads();

  // Sort!
  if (descending) {
    Sort(tmp_storage.sort).SortDescending(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  } else {
    Sort(tmp_storage.sort).Sort(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  }
  __syncthreads();

  // Store outputs
  StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  __syncthreads();
  StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}
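
// A minimal launch sketch (hypothetical sizes; the real dispatch lives in the
// host-side sort code). One block sorts one slice, and keySliceSize must not
// exceed block_size * items_per_thread. KeyDims, ValueDims, keyInfo, and
// valueInfo below are placeholders:
//
//   constexpr int block_size = 256;
//   constexpr int items_per_thread = 4;   // handles slices of up to 1024 elements
//   dim3 grid((unsigned int)keySlices);   // flattened by getLinearBlockId()
//   radixSortKVInPlace<KeyDims, ValueDims, block_size, items_per_thread>
//       <<<grid, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, keySlices, keySliceSize, keySliceStride,
//           valueInfo, valueSliceStride, /*descending=*/false);
//   C10_CUDA_KERNEL_LAUNCH_CHECK();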

}} // at::native