1 #include <ATen/ceil_div.h>
2 #include <ATen/cuda/Atomic.cuh>
3 #include <ATen/cuda/DeviceUtils.cuh>
4 #include <ATen/cuda/AsmUtils.cuh>
5 #include <c10/macros/Macros.h>
6
7 namespace at {
8 namespace native {
9
// Maps a scalar type to an unsigned integer "radix" representation whose
// unsigned ordering matches the ordering of the original scalar values.
// Each supported type provides a specialization below with convert() /
// deconvert() as mutually inverse mappings (modulo NaN payloads).
template <typename scalar_t>
struct TopKTypeConfig {};
12
template <>
struct TopKTypeConfig<float> {
  typedef uint32_t RadixType;

  // Converts a float to an integer representation with the same
  // sorting; i.e., for floats f1, f2:
  // if f1 < f2 then convert(f1) < convert(f2)
  // We use this to enable radix selection of floating-point values.
  // This also gives a relative order for NaNs, but that's ok, as they
  // will all be adjacent
  // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
  // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
  // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
  // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff...
  static inline __device__ RadixType convert(float v) {
    RadixType x = __float_as_int(v);
    // Negatives (sign bit set): flip every bit so their order reverses.
    // Non-negatives: flip only the sign bit so they sort above negatives.
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;

    // NaNs (v != v) collapse to the maximum radix value.
    return (v == v) ? (x ^ mask) : 0xffffffff;
  }

  // Inverse of convert() (NaN payloads are not restored).
  static inline __device__ float deconvert(RadixType v) {
    // After convert(), a set high bit means the original was non-negative.
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;

    return __int_as_float(v ^ mask);
  }
};
40
template <>
struct TopKTypeConfig<uint8_t> {
  typedef uint32_t RadixType;

  // Unsigned bytes already sort correctly as unsigned integers, so the
  // radix representation is simply the widened value.
  static inline __device__ RadixType convert(uint8_t v) {
    return static_cast<RadixType>(v);
  }

  // Inverse of convert(): truncate back to a byte.
  static inline __device__ uint8_t deconvert(RadixType v) {
    return static_cast<uint8_t>(v);
  }
};
53
template <>
struct TopKTypeConfig<int8_t> {
  typedef uint32_t RadixType;

  // Bias by 2^7 so the signed range [-128, 127] maps onto the increasing
  // unsigned range [0, 255]; unsigned order then matches signed order.
  static inline __device__ RadixType convert(int8_t v) {
    constexpr RadixType kBias = 128u;
    return kBias + static_cast<RadixType>(v);
  }

  // Inverse of convert(): remove the bias.
  static inline __device__ int8_t deconvert(RadixType v) {
    return v - 128;
  }
};
66
template <>
struct TopKTypeConfig<int16_t> {
  typedef uint32_t RadixType;

  // Bias by 2^15 so the signed range [-32768, 32767] maps onto the
  // increasing unsigned range [0, 65535].
  static inline __device__ RadixType convert(int16_t v) {
    static_assert(sizeof(short) == 2, "");
    constexpr RadixType kBias = 32768u;
    return kBias + static_cast<RadixType>(v);
  }

  // Inverse of convert(): remove the bias.
  static inline __device__ int16_t deconvert(RadixType v) {
    return v - 32768;
  }
};
80
template <>
struct TopKTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  // Bias by 2^31 so the full signed 32-bit range maps onto the increasing
  // unsigned 32-bit range.
  static inline __device__ RadixType convert(int32_t v) {
    static_assert(sizeof(int) == 4, "");
    constexpr RadixType kBias = 2147483648u;
    return kBias + static_cast<RadixType>(v);
  }

  // Inverse of convert(): remove the bias.
  static inline __device__ int32_t deconvert(RadixType v) {
    return static_cast<int32_t>(v - 2147483648u);
  }
};
94
template <>
struct TopKTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  // Bias by 2^63 so the full signed 64-bit range maps onto the increasing
  // unsigned 64-bit range.
  static inline __device__ RadixType convert(int64_t v) {
    static_assert(sizeof(int64_t) == 8, "");
    constexpr RadixType kBias = 9223372036854775808ull;
    return kBias + static_cast<RadixType>(v);
  }

  // Inverse of convert(): remove the bias.
  static inline __device__ int64_t deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};
108
template <>
struct TopKTypeConfig<double> {
  typedef uint64_t RadixType;

  // Same scheme as the float specialization, in 64 bits: negatives have
  // all bits flipped, non-negatives have only the sign bit flipped, and
  // NaNs map to the maximum radix value so they sort adjacently at the end.
  static inline __device__ RadixType convert(double v) {
    RadixType x = __double_as_longlong(v);
    // (x >> 63) is 1 for negatives, so -(x >> 63) is all-ones for
    // negatives and zero otherwise; OR-ing the sign bit makes the mask
    // all-ones (flip everything) or just the sign bit.
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
  }

  // Inverse of convert() (NaN payloads are not restored).
  static inline __device__ double deconvert(RadixType v) {
    // ((v >> 63) - 1) is zero when the converted high bit is set
    // (originally non-negative) and all-ones otherwise.
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};
124
template <>
struct TopKTypeConfig<at::Half> {
  typedef uint32_t RadixType;

  // Same scheme as the float specialization, applied to the 16-bit half
  // pattern: flip all 16 bits of negatives, only the sign bit of
  // non-negatives; NaNs map to the maximum radix value 0xffff.
  static inline __device__ RadixType convert(at::Half v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType x = __half_as_ushort(v);
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
#else
    // Host-side compilation stub; this path must never execute at runtime.
    CUDA_KERNEL_ASSERT(false);
    return 0u;
#endif
  }

  // Inverse of convert() (NaN payloads are not restored).
  static inline __device__ at::Half deconvert(RadixType v) {
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
    return __ushort_as_half(v ^ mask);
#else
    // Host-side compilation stub; this path must never execute at runtime.
    CUDA_KERNEL_ASSERT(false);
    return static_cast<at::Half>(0);
#endif
  }
};
150
template <>
struct TopKTypeConfig<at::BFloat16> {
  typedef uint32_t RadixType;

  // Same scheme as at::Half: flip all 16 bits of negatives and only the
  // sign bit of non-negatives so that unsigned radix order matches
  // bfloat16 order; NaNs collapse to the maximum radix value 0xffff.
  static inline __device__ RadixType convert(at::BFloat16 v) {
    if (!(v == v)) {
      return 0xffff; // NaN: sorts after every real value
    }
    RadixType bits = v.x;
    return bits ^ ((bits & 0x00008000) ? 0x0000ffff : 0x00008000);
  }

  // Inverse of convert() (NaN payloads are not restored).
  static inline __device__ at::BFloat16 deconvert(RadixType v) {
    at::BFloat16 out;
    out.x = v ^ ((v & 0x00008000) ? 0x00008000 : 0x0000ffff);
    return out;
  }
};
168
// This function counts the distribution of all input values in a
// slice we are selecting by radix digit at `radixDigitPos`, but only
// those that pass the filter `((v & desiredMask) == desired)`.
// This produces and broadcasts the seen counts for a single block only.
// `smem` must have at least `RadixSize` elements.
// On return, every thread's `counts[d]` holds the block-wide count of
// filtered values whose digit at `radixDigitPos` equals `d`.
// Must be called by all threads of the block (contains __syncthreads()).
template <
    typename scalar_t,
    typename bitwise_t,
    typename index_t,
    typename CountType,
    int RadixSize,
    int RadixBits>
__device__ void countRadixUsingMask(
    CountType counts[RadixSize],
    CountType* smem,
    bitwise_t desired,
    bitwise_t desiredMask,
    int radixDigitPos,
    index_t sliceSize,
    index_t withinSliceStride,
    const scalar_t* data) {
  // Clear out per-thread counts from a previous round
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  // Also clear the per-block shared-memory accumulators (one per digit).
  if (threadIdx.x < RadixSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();

  // Scan over all the data. Upon a read, the warp will accumulate
  // counts per each digit in the radix using warp voting.
#if !defined(USE_ROCM)
  // Must be called outside of loop to ensure all threads participate
  unsigned mask = WARP_BALLOT(threadIdx.x < sliceSize);
#endif
  for (index_t i = threadIdx.x; i < sliceSize;) {
    bitwise_t val =
        TopKTypeConfig<scalar_t>::convert(doLdg(&data[i * withinSliceStride]));

    bool hasVal = ((val & desiredMask) == desired);
    bitwise_t digitInRadix = at::cuda::Bitfield<bitwise_t>::getBitfield(
        val, radixDigitPos, RadixBits);

    // One ballot per possible digit value: each lane votes whether its
    // element (if it passed the filter) has digit j; the popcount of the
    // ballot is this warp's count for digit j.
#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = hasVal && (digitInRadix == j);
#if defined(USE_ROCM)
      counts[j] += __popcll(WARP_BALLOT(vote));
#else
      counts[j] += __popc(WARP_BALLOT(vote, mask));
#endif
    }
    i += blockDim.x;
#if !defined(USE_ROCM)
    // Recompute the participation mask: some lanes may now be past the
    // end of the slice and must not take part in the next round of votes.
    mask = WARP_BALLOT(i < sliceSize, mask);
#endif
  }

  // Now, for each warp, sum values
  // (every lane holds the identical warp total, so only lane 0 publishes)
  if (at::cuda::getLaneId() == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      gpuAtomicAddNoReturn(&smem[i], counts[i]);
    }
  }

  __syncthreads();

  // For each thread, read in the total counts
#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = smem[i];
  }

  // Barrier so smem can be safely reused (e.g. cleared) by a later call.
  __syncthreads();
}
248
// Over what radix we are selecting values: the selection below consumes
// RADIX_BITS bits of the converted key per pass, most-significant first.
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
253
// This finds the unique value `v` that matches the pattern
// ((v & desiredMask) == desired) in our sorted int format.
// (Note: earlier comment had the operands swapped; the code below checks
// (convert(v) & desiredMask) == desired.)
// `smem` is scratch space: smem[0] is a found-flag, smem[1] the value.
// Must be called by all threads of the block; all threads return the
// same value. Asserts if no element matches the pattern.
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ scalar_t findPattern(
    scalar_t* smem,
    const scalar_t* data,
    index_t sliceSize,
    index_t withinSliceStride,
    bitwise_t desired,
    bitwise_t desiredMask) {
  // Reset the (flag, value) pair in shared memory.
  if (threadIdx.x < 2) {
    smem[threadIdx.x] = static_cast<scalar_t>(0);
  }
  __syncthreads();

  // All threads participate in the loop, in order to sync on the flag
  index_t numIterations =
      round_up(sliceSize, static_cast<index_t>(blockDim.x));
  for (index_t i = threadIdx.x; i < numIterations; i += blockDim.x) {
    bool inRange = (i < sliceSize);
    scalar_t v = inRange ? doLdg(&data[i * withinSliceStride])
                         : static_cast<scalar_t>(0);

    if (inRange &&
        ((TopKTypeConfig<scalar_t>::convert(v) & desiredMask) == desired)) {
      // There should not be conflicts if we are using findPattern,
      // since the result is unique
      smem[0] = static_cast<scalar_t>(1);
      smem[1] = v; // can't use val as the flag, since it could be 0
    }

    __syncthreads();

    scalar_t found = smem[0];
    scalar_t val = smem[1];

    // Second barrier: everyone must have read (found, val) before a later
    // iteration is allowed to overwrite them.
    __syncthreads();

    // Check to see if a thread found the value
    if (found != static_cast<scalar_t>(0)) {
      // all threads return this value
      return val;
    }
  }

  // should not get here
  CUDA_KERNEL_ASSERT(false);
  return static_cast<scalar_t>(0);
}
303
// Returns the top-Kth element found in the data using radix selection,
// writing it to *topK.
// `largest` selects the k-th largest (true) or k-th smallest (false).
// `smem` must hold at least RADIX_SIZE ints; it is also reinterpreted as
// scalar_t scratch (2 slots) when findPattern is invoked.
// Must be called by all threads of the block (contains block barriers via
// countRadixUsingMask / findPattern); all threads see the same result.
template <typename scalar_t, typename bitwise_t, typename index_t>
__device__ void radixSelect(
    const scalar_t* data,
    index_t k,
    bool largest,
    index_t sliceSize,
    index_t withinSliceStride,
    int* smem,
    scalar_t* topK) {
  // Per-thread buckets into which we accumulate digit counts in our
  // radix
  int counts[RADIX_SIZE];

  // We only consider elements x such that (x & desiredMask) == desired
  // Initially, we consider all elements of the array, so the above
  // statement is true regardless of input.
  bitwise_t desired = 0;
  bitwise_t desiredMask = 0;

  // We are looking for the top kToFind-th element when iterating over
  // digits; this count gets reduced by elimination when counting
  // successive digits
  int kToFind = k;

  // We start at the most significant digit in our radix, scanning
  // through to the least significant digit
  for (int digitPos = sizeof(scalar_t) * 8 - RADIX_BITS; digitPos >= 0;
       digitPos -= RADIX_BITS) {
    // Count radix distribution for the current position and reduce
    // across all threads
    countRadixUsingMask<
        scalar_t,
        bitwise_t,
        index_t,
        int,
        RADIX_SIZE,
        RADIX_BITS>(
        counts,
        smem,
        desired,
        desiredMask,
        digitPos,
        sliceSize,
        withinSliceStride,
        data);

    auto found_unique = [&](int i, int count) -> bool {
      /* All threads have the same value in counts here, so all */
      /* threads will return from the function. */
      if (count == 1 && kToFind == 1) {
        /* There is a unique answer. */
        desired = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The answer is now the unique element v such that: */
        /* (v & desiredMask) == desired */
        /* However, we do not yet know what the actual element is. We */
        /* need to perform a search through the data to find the */
        /* element that matches this pattern. */
        *topK = findPattern<scalar_t, bitwise_t, index_t>(
            (scalar_t*)smem,
            data,
            sliceSize,
            withinSliceStride,
            desired,
            desiredMask);
        return true;
      }
      return false;
    };
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= kToFind) {
        desired =
            at::cuda::Bitfield<bitwise_t>::setBitfield(
                desired, i, digitPos, RADIX_BITS);
        desiredMask = at::cuda::Bitfield<bitwise_t>::setBitfield(
            desiredMask, RADIX_MASK, digitPos, RADIX_BITS);

        /* The top-Kth element v must now be one such that: */
        /* ((v & desiredMask) == desired) */
        /* but we haven't narrowed it down; we must check the next */
        /* least-significant digit */
        return true;
      }
      /* This digit's bucket is entirely better-ranked than the answer; */
      /* skip past all of its elements. */
      kToFind -= count;
      return false; // continue the loop
    };

    // All threads participate in the comparisons below to know the
    // final result
    if (largest) {
      // Process in descending order
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {
      // Process in ascending order
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  } // end digitPos for

  // There is no unique result, but there is a non-unique result
  // matching `desired` exactly
  *topK = TopKTypeConfig<scalar_t>::deconvert(desired);
}
428 } // namespace native
429 } // namespace at
430