// aten/src/ATen/native/cuda/SortingRadixSelect.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ceil_div.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/AsmUtils.cuh>
#include <c10/macros/Macros.h>

namespace at {
namespace native {

template <typename scalar_t>
struct TopKTypeConfig {};

template <>
struct TopKTypeConfig<float> {
  typedef uint32_t RadixType;

  // Converts a float to an integer representation with the same
  // sorting; i.e., for floats f1, f2:
  // if f1 < f2 then convert(f1) < convert(f2)
  // We use this to enable radix selection of floating-point values.
  // This also gives a relative order for NaNs, but that's ok, as they
  // will all be adjacent
  // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
  // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
  // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
  // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff...
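  // For illustration (example values, not from the original source):
  //   convert(-1.0f) == 0x407fffff  <  convert(-0.0f) == 0x7fffffff
  //   <  convert(+0.0f) == 0x80000000  <  convert(+1.0f) == 0xbf800000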
  static inline __device__ RadixType convert(float v) {
    RadixType x = __float_as_int(v);
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;

    return (v == v) ? (x ^ mask) : 0xffffffff;
  }

  static inline __device__ float deconvert(RadixType v) {
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;

    return __int_as_float(v ^ mask);
  }
};

template <>
struct TopKTypeConfig<uint8_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(uint8_t v) {
    return v;
  }

  static inline __device__ uint8_t deconvert(RadixType v) {
    return v;
  }
};

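// For the signed integer specializations below, adding a bias of
// 2^(bits - 1) maps the signed range onto [0, 2^bits - 1] while preserving
// order, so the biased value can be compared directly as an unsigned radix
// key; deconvert subtracts the same bias.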
template <>
struct TopKTypeConfig<int8_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(int8_t v) {
    return 128u + v;
  }

  static inline __device__ int8_t deconvert(RadixType v) {
    return v - 128;
  }
};

template <>
struct TopKTypeConfig<int16_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(int16_t v) {
    static_assert(sizeof(short) == 2, "");
    return 32768u + v;
  }

  static inline __device__ int16_t deconvert(RadixType v) {
    return v - 32768;
  }
};

template <>
struct TopKTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(int32_t v) {
    static_assert(sizeof(int) == 4, "");
    return 2147483648u + v;
  }

  static inline __device__ int32_t deconvert(RadixType v) {
    return v - 2147483648u;
  }
};

template <>
struct TopKTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType convert(int64_t v) {
    static_assert(sizeof(int64_t) == 8, "");
    return 9223372036854775808ull + v;
  }

  static inline __device__ int64_t deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};

template <>
struct TopKTypeConfig<double> {
  typedef uint64_t RadixType;

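  // Same order-preserving trick as for float, in 64-bit arithmetic:
  // -(x >> 63) is all ones when the sign bit is set (flip every bit) and
  // zero otherwise, so the OR leaves only the sign bit to flip for
  // non-negative values.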
  static inline __device__ RadixType convert(double v) {
    RadixType x = __double_as_longlong(v);
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
  }

  static inline __device__ double deconvert(RadixType v) {
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};

template <>
struct TopKTypeConfig<at::Half> {
  typedef uint32_t RadixType;

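  // The half <-> ushort reinterpret intrinsics are only usable in device
  // code, so the fallback branches below simply assert; they are not
  // expected to be reached.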
  static inline __device__ RadixType convert(at::Half v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType x = __half_as_ushort(v);
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
#else
    CUDA_KERNEL_ASSERT(false);
    return 0u;
#endif
  }

  static inline __device__ at::Half deconvert(RadixType v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    return __ushort_as_half(v ^ mask);
#else
    CUDA_KERNEL_ASSERT(false);
    return static_cast<at::Half>(0);
#endif
  }
};

template <>
struct TopKTypeConfig<at::BFloat16> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType convert(at::BFloat16 v) {
    RadixType x = v.x;
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
  }

  static inline __device__ at::BFloat16 deconvert(RadixType v) {
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    at::BFloat16 r;
    r.x = (v ^ mask);
    return r;
  }
};

// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radixDigitPos`, but only
// those that pass the filter `((v & desiredMask) == desired)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RadixSize` elements.
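// Each thread first accumulates per-warp counts via warp ballots, lane 0 of
// every warp then atomically adds its counts into `smem`, and finally all
// threads read the block-wide totals back into `counts`.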
template <
    typename scalar_t,
    typename bitwise_t,
    typename index_t,
    typename CountType,
    int RadixSize,
    int RadixBits>
__device__ void countRadixUsingMask(
    CountType counts[RadixSize],
    CountType* smem,
    bitwise_t desired,
    bitwise_t desiredMask,
    int radixDigitPos,
    index_t sliceSize,
    index_t withinSliceStride,
    const scalar_t* data) {
  // Clear out per-thread counts from a previous round
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  if (threadIdx.x < RadixSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();

  // Scan over all the data. Upon a read, the warp will accumulate
  // counts for each digit in the radix using warp voting.
#if !defined(USE_ROCM)
  // Must be called outside of the loop to ensure all threads participate
  unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize);
#endif
  for (index_t i = threadIdx.x; i < sliceSize;) {
    bitwise_t val =
        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));

    bool hasVal = ((val & desiredMask) == desired);
    bitwise_t digitInRadix = at::cuda::Bitfield<bitwise_t>::getBitfield(
        val, radixDigitPos, RadixBits);

#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = hasVal && (digitInRadix == j);
#if defined(USE_ROCM)
      counts[j] += __popcll(WARP_BALLOT(vote));
#else
      counts[j] += __popc(WARP_BALLOT(vote, mask));
#endif
    }
    i += blockDim.x;
#if !defined(USE_ROCM)
    mask = WARP_BALLOT(i < sliceSize, mask);
#endif
  }

  // Now, lane 0 of each warp adds its warp's counts into shared memory
  if (at::cuda::getLaneId() == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      gpuAtomicAddNoReturn(&smem[i], counts[i]);
    }
  }

  __syncthreads();

  // For each thread, read in the total counts
#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = smem[i];
  }

  __syncthreads();
}

// Over what radix we are selecting values
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
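
// With RADIX_BITS == 2, radixSelect below makes at most
// sizeof(scalar_t) * 8 / RADIX_BITS passes over the data (16 for 32-bit
// types, 32 for 64-bit types), narrowing the match by one 2-bit digit per
// pass and returning early once the remaining match is unique.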

// This finds the unique value `v` that matches the pattern
// ((v & desiredMask) == desired) in our sorted int format
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ scalar_t findPattern(
    scalar_t* smem,
    const scalar_t* data,
    index_t sliceSize,
    index_t withinSliceStride,
    bitwise_t desired,
    bitwise_t desiredMask) {
  if (threadIdx.x < 2) {
    smem[threadIdx.x] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // All threads participate in the loop, in order to sync on the flag
  index_t numIterations =
      round_up(sliceSize, static_cast<index_t>(blockDim.x));
  for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < sliceSize);
    scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
                         : static_cast<scalar_t>(0);

    if (inRange &&
        ((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
      // There should not be conflicts if we are using findPattern,
      // since the result is unique
      smem[0] = static_cast<scalar_t>(1);
      smem[1] = v; // can't use val as the flag, since it could be 0
    }

    __syncthreads();

    scalar_t found = smem[0];
    scalar_t val = smem[1];

    __syncthreads();

    // Check to see if a thread found the value
    if (found != static_cast<scalar_t>(0)) {
      // all threads return this value
      return val;
    }
  }

  // should not get here
  CUDA_KERNEL_ASSERT(false);
  return static_cast<scalar_t>(0);
}

// Finds the top-Kth element in the data using radix selection and writes it
// to `*topK`
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ void radixSelect(
    const scalar_t* data,
    index_t k,
    bool largest,
    index_t sliceSize,
    index_t withinSliceStride,
    int* smem,
    scalar_t* topK) {
  // Per-thread buckets into which we accumulate digit counts in our
  // radix
  int counts[RADIX_SIZE];

  // We only consider elements x such that (x & desiredMask) == desired
  // Initially, we consider all elements of the array, so the above
  // statement is true regardless of input.
  bitwise_t desired = 0;
  bitwise_t desiredMask = 0;

  // We are looking for the top kToFind-th element when iterating over
  // digits; this count gets reduced by elimination when counting
  // successive digits
  int kToFind = k;

  // We start at the most significant digit in our radix, scanning
  // through to the least significant digit
  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
       digitPos -= RADIX_BITS) {
    // Count radix distribution for the current position and reduce
    // across all threads
    countRadixUsingMask<
        scalar_t,
        bitwise_t,
        index_t,
        int,
        RADIX_SIZE,
        RADIX_BITS>(
        counts,
        smem,
        desired,
        desiredMask,
        digitPos,
        sliceSize,
        withinSliceStride,
        data);

    auto found_unique = [&](int i, int count) -> bool {
      /* All threads have the same value in counts here, so all */
      /* threads will return from the function. */
      if (count == 1 && kToFind == 1) {
        /* There is a unique answer. */
        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The answer is now the unique element v such that: */
        /* (v & desiredMask) == desired */
        /* However, we do not yet know what the actual element is. We */
        /* need to perform a search through the data to find the */
        /* element that matches this pattern. */
        *topK = findPattern<scalar_t, bitwise_t, index_t>(
            (scalar_t*)smem,
            data,
            sliceSize,
            withinSliceStride,
            desired,
            desiredMask);
        return true;
      }
      return false;
    };
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= kToFind) {
        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The top-Kth element v must now be one such that: */
        /* ((v & desiredMask) == desired) */
        /* but we haven't narrowed it down; we must check the next */
        /* least-significant digit */
        return true;
      }
      kToFind -= count;
      return false; // continue the loop
    };

    // All threads participate in the comparisons below to know the
    // final result
    if (largest) {
      // Process in descending order
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {
      // Process in ascending order
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  } // end digitPos for

  // There is no unique result, but there is a non-unique result
  // matching `desired` exactly
  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
}
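
// A minimal usage sketch (hypothetical kernel, not part of this header),
// assuming one block per contiguous slice and a shared buffer of at least
// RADIX_SIZE ints (findPattern also reuses it as two scalar_t slots):
//
//   template <typename T, typename IndexType>
//   __global__ void kthValueKernel(
//       const T* slices, IndexType sliceSize, IndexType k, T* out) {
//     __shared__ int smem[64];
//     T kthValue;
//     radixSelect<T, typename TopKTypeConfig<T>::RadixType, IndexType>(
//         slices + blockIdx.x * sliceSize,
//         k,
//         /*largest=*/false,
//         sliceSize,
//         /*withinSliceStride=*/static_cast<IndexType>(1),
//         smem,
//         &kthValue);
//     if (threadIdx.x == 0) {
//       out[blockIdx.x] = kthValue;
//     }
//   }
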
} // namespace native
} // namespace at