#pragma once
#include <c10/macros/Macros.h>

#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/Sort.h>
#include <ATen/native/StridedRandomAccessor.h>

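// The warp merge sort kernel below relies on cub::WarpMergeSort, which is
// only available in the CUB version bundled with CUDA 11.6 and newer.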
#define HAS_WARP_MERGE_SORT() (CUDA_VERSION >= 110600)


namespace at { namespace native {

template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T tmp = t1;
  t1 = t2;
  t2 = tmp;
}

template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
                                   K& kB, V& vB, bool& validB,
                                   bool dir,
                                   const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
}

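// In-place bitonic sort of Power2SortSize elements held in shared memory.
// Expects Power2SortSize == 2 * blockDim.x: each thread performs one
// compare-exchange per step. Entries flagged invalid sort to the back of
// the buffer (see bitonicSwap above).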
template <int Power2SortSize, typename IndexType, typename Comparator,
          typename K, typename V>
__device__ inline void bitonicSort(K *keys,
                                   V *values,
                                   bool *valid,
                                   const Comparator& comp) {
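  // Phase 1: build progressively larger bitonic sequences; the sort
  // direction alternates between adjacent subsequences of each size.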
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {

      __syncthreads();

      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwap<Comparator, K, V>(
        keys[pos], values[pos], valid[pos],
        keys[pos + stride], values[pos + stride], valid[pos + stride],
        flag, comp);
    }
  }

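  // Phase 2: merge the full bitonic sequence into sorted order under `comp`.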
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {

    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwap<Comparator, K, V>(
      keys[pos], values[pos], valid[pos],
      keys[pos + stride], values[pos + stride], valid[pos + stride],
      false, comp);
  }

  __syncthreads();

}

// at::cuda::detail::TensorInfo version
// Sorts (key, value) pairs (in different tensors) in-place; i.e.,
// modifies the input `keys` and `values`
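//
// Each row of blockDim.x threads sorts one slice of up to
// 2 * block_dim_x elements; blockDim.y independent slices are sorted per
// block.
//
// A launch sketch, for illustration only (the host-side names keyInfo,
// valueInfo and comp are assumptions, not part of this header):
//
//   dim3 block(block_dim_x, max_block_dim_y);
//   dim3 grid((keySlices + block.y - 1) / block.y);
//   bitonicSortKVInPlace<KeyDims, ValueDims, block_dim_x, max_block_dim_y>
//       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
//           keyInfo, keySlices, keySliceSize, keySliceStride,
//           valueInfo, valueSliceStride, comp);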
template <int KeyDims, int ValueDims, int block_dim_x, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_dim_x * max_block_dim_y)
__global__ void
bitonicSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                     IndexType keySlices,
                     IndexType keySliceSize,
                     IndexType keySliceStride,
                     at::cuda::detail::TensorInfo<V, IndexType> values,
                     IndexType valueSliceStride,
                     Comparator comp) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If the entire block is out of bounds, exit early
  if (blockIndex * blockDim.y >= keySlices) {
    return;
  }
  // It's also possible for some rows of a block to be out of bounds,
  // but all threads need to run for __syncthreads to work.
  const bool row_valid = linearIndex < keySlices;

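  // Each thread loads, sorts, and stores two elements, so one row of
  // block_dim_x threads handles a slice of up to 2 * block_dim_x keys.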
  constexpr int items_per_thread = 2;
  constexpr int Power2SortSize = block_dim_x * items_per_thread;

  // Storage for max_block_dim_y sorts performed in parallel
  __shared__ K blockSharedKeys[max_block_dim_y][Power2SortSize];
  __shared__ V blockSharedValues[max_block_dim_y][Power2SortSize];
  __shared__ bool blockSharedValid[max_block_dim_y][Power2SortSize];

  auto sharedKeys = blockSharedKeys[threadIdx.y];
  auto sharedValues = blockSharedValues[threadIdx.y];
  auto sharedValid = blockSharedValid[threadIdx.y];

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  // Load 2 values per thread into the shared workspace
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    bool valid = row_valid && idx < keySliceSize;

    sharedKeys[idx] = valid ?
        keys.data[idx * keySliceStride + keyStartOffset] : K{};
    sharedValues[idx] = valid ?
        values.data[idx * valueSliceStride + valueStartOffset] : V{};
    sharedValid[idx] = valid;
  }

  // Sort!
  bitonicSort<Power2SortSize, IndexType>(
      sharedKeys, sharedValues, sharedValid, comp);

  if (!row_valid) {
    return;
  }

  // Store outputs
  #pragma unroll
  for (int k = 0; k < items_per_thread; ++k) {
    auto idx = threadIdx.x + k * blockDim.x;
    if (idx < keySliceSize) {
      keys.data[idx * keySliceStride + keyStartOffset] = sharedKeys[idx];
      values.data[idx * valueSliceStride + valueStartOffset] = sharedValues[idx];
    }
  }
}

#if HAS_WARP_MERGE_SORT()

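// Sorts (key, value) pairs along the sort dimension in-place, one warp per
// slice, using cub::WarpMergeSort. blockDim.x must equal C10_WARP_SIZE
// (asserted below), and the kernel is intended for slices of at most
// sort_size elements: shorter slices are padded with `invalid_key`, which
// the caller chooses so that it sorts after every valid key under `comp`.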
template <int KeyDims, int ValueDims, int sort_size, int max_block_dim_y,
          typename K, typename V, typename Comparator, typename IndexType>
C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y)
__global__ void
warpMergeSortKVInPlace(
    at::cuda::detail::TensorInfo<K, IndexType> keys,
    IndexType keySlices,
    IndexType keySliceSize,
    IndexType keySliceStride,
    at::cuda::detail::TensorInfo<V, IndexType> values,
    IndexType valueSliceStride,
    Comparator comp,
    K invalid_key) {
  // Find the slice of the tensor that we are sorting
  // NOTE: blockDim.y may be less than max_block_dim_y
  const IndexType blockIndex = getLinearBlockId<IndexType>();
  const IndexType linearIndex = blockIndex * blockDim.y + threadIdx.y;

  // If this row is out of bounds exit early
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

  CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE);
  CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y);
  constexpr int items_per_thread = sort_size / C10_WARP_SIZE;
  static_assert(
      items_per_thread * C10_WARP_SIZE == sort_size,
      "sort_size must be a multiple of C10_WARP_SIZE");


  using LoadKeys = cub::WarpLoad<K, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using LoadValues = cub::WarpLoad<V, items_per_thread, cub::WARP_LOAD_TRANSPOSE>;
  using Sort = cub::WarpMergeSort<K, items_per_thread, C10_WARP_SIZE, V>;
  using StoreKeys = cub::WarpStore<K, items_per_thread, cub::WARP_STORE_TRANSPOSE>;
  using StoreValues = cub::WarpStore<V, items_per_thread, cub::WARP_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage[max_block_dim_y];
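  // Each warp (indexed by threadIdx.y) has its own TempStorage slot; the
  // slot is reused by the load, sort, and store phases, which is why
  // WARP_SYNC() separates them below.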

  auto& warp_storage = tmp_storage[threadIdx.y];

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  const auto invalid_value = V{};
  LoadKeys(warp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  WARP_SYNC();
  LoadValues(warp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  WARP_SYNC();

  // Sort! We use stable sort to ensure that invalid values are never
  // sorted before valid values. In testing it performed the same as
  // .Sort, so there is no downside.
  Sort(warp_storage.sort).StableSort(
      local_keys, local_values, comp, keySliceSize, invalid_key);
  WARP_SYNC();

  // Store outputs
  StoreKeys(warp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  WARP_SYNC();
  StoreValues(warp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}

#endif // HAS_WARP_MERGE_SORT()

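// Sorts (key, value) pairs along the sort dimension in-place, one thread
// block per slice, using cub::BlockRadixSort. Each block handles up to
// block_size * items_per_thread elements; shorter slices are padded with
// keys that sort to the end (see invalid_key below).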
template <int KeyDims, int ValueDims,
          int block_size, int items_per_thread,
          typename K, typename V, typename IndexType>
C10_LAUNCH_BOUNDS_1(block_size)
__global__ void
radixSortKVInPlace(at::cuda::detail::TensorInfo<K, IndexType> keys,
                   IndexType keySlices,
                   IndexType keySliceSize,
                   IndexType keySliceStride,
                   at::cuda::detail::TensorInfo<V, IndexType> values,
                   IndexType valueSliceStride,
                   bool descending) {
  static_assert(block_size > 0, "");

  // Find the slice of the tensor that we are sorting
  const IndexType linearIndex = getLinearBlockId<IndexType>();
  // Tiling the slices could have us be out of bounds, if there are a
  // lot of slices to sort
  if (linearIndex >= keySlices) {
    return;
  }

  const IndexType keyStartOffset =
    at::cuda::detail::IndexToOffset<K, IndexType, KeyDims>::get(linearIndex, keys);
  const IndexType valueStartOffset =
    at::cuda::detail::IndexToOffset<V, IndexType, ValueDims>::get(linearIndex, values);

  K *keys_slice = &keys.data[keyStartOffset];
  V *values_slice = &values.data[valueStartOffset];

  StridedRandomAccessor<K, IndexType> keys_iter(keys_slice, keySliceStride);
  StridedRandomAccessor<V, IndexType> values_iter(values_slice, valueSliceStride);

  namespace cub = ROCM_HIPCUB(at_cuda_detail::cub);

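  // key_t views K as a type whose bit layout cub's radix traits understand
  // (e.g. at::Half is treated as CUDA's native half type), which is why the
  // keys are reinterpret_cast below before sorting.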
  using key_t = typename at::cuda::cub::detail::cuda_type<K>::type;
  using LoadKeys = cub::BlockLoad<K, block_size, items_per_thread,
                                  cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
  using LoadValues = cub::BlockLoad<V, block_size, items_per_thread,
                                    cub::BlockLoadAlgorithm::BLOCK_LOAD_TRANSPOSE>;
  using Sort = cub::BlockRadixSort<key_t, block_size, items_per_thread, V>;
  using StoreKeys = cub::BlockStore<K, block_size, items_per_thread,
                                    cub::BLOCK_STORE_TRANSPOSE>;
  using StoreValues = cub::BlockStore<V, block_size, items_per_thread,
                                      cub::BLOCK_STORE_TRANSPOSE>;

  __shared__ union {
    typename LoadKeys::TempStorage load_keys;
    typename LoadValues::TempStorage load_values;
    typename Sort::TempStorage sort;
    typename StoreKeys::TempStorage store_keys;
    typename StoreValues::TempStorage store_values;
  } tmp_storage;

  // cub's Block operations operate on a fixed number of items, but the
  // actual slice we are sorting might be smaller. So, we need to pad the
  // slice with keys that always sort to the end.
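  // The pad key takes cub::Traits' extreme radix bit pattern for the key
  // type: the lowest possible key for a descending sort and the highest for
  // an ascending one, so it always lands after every real key.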
  const K invalid_key = [descending] {
    using radix_t = typename cub::Traits<key_t>::UnsignedBits;
    union {
      K key;
      radix_t radix;
    } tmp;
    tmp.radix = descending ?
        cub::Traits<key_t>::LOWEST_KEY :
        cub::Traits<key_t>::MAX_KEY;
    return tmp.key;
  }();
  const V invalid_value = static_cast<V>(0);

  // Load inputs
  K local_keys[items_per_thread];
  V local_values[items_per_thread];

  LoadKeys(tmp_storage.load_keys).Load(keys_iter, local_keys, keySliceSize, invalid_key);
  __syncthreads();
  LoadValues(tmp_storage.load_values).Load(values_iter, local_values, keySliceSize, invalid_value);
  __syncthreads();

  // Sort!
  if (descending) {
    Sort(tmp_storage.sort).SortDescending(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  } else {
    Sort(tmp_storage.sort).Sort(
        reinterpret_cast<key_t (&)[items_per_thread]>(local_keys),
        local_values);
  }
  __syncthreads();

  // Store outputs
  StoreKeys(tmp_storage.store_keys).Store(keys_iter, local_keys, keySliceSize);
  __syncthreads();
  StoreValues(tmp_storage.store_values).Store(values_iter, local_values, keySliceSize);
}

}} // at::native