1 #pragma once
2
3 #include <ATen/ceil_div.h>
4 #include <ATen/cuda/DeviceUtils.cuh>
5 #include <ATen/cuda/AsmUtils.cuh>
6 #include <c10/macros/Macros.h>
7
8 // Collection of in-kernel scan / prefix sum utilities
9
namespace at::cuda {

// Inclusive prefix sum of a per-thread boolean across the whole block, using
// intra-warp voting (ballot + popcount) within each warp and shared memory to
// combine the per-warp totals.
//
// Template parameters:
//   T                 - integral type used for the counts.
//   KillWARDependency - if true, a trailing __syncthreads() is issued so the
//                       caller may immediately overwrite `smem` without racing
//                       against threads that are still reading it.
//   BinaryFunction    - combiner for counts. The cross-warp scan seeds its
//                       running value with 0, so `binop` is expected to be
//                       addition (0 must be its identity).
//
// Parameters:
//   smem  - block-visible scratch buffer (intended to be shared memory) with at
//           least ceil_div(blockDim.x, C10_WARP_SIZE) elements, one per warp.
//           On return, smem[w] holds the inclusive total of warps 0..w.
//   in    - this thread's predicate.
//   out   - receives the number of threads with `in == true` among threads
//           0..threadIdx.x (inclusive).
//   binop - see BinaryFunction.
//
// Contains __syncthreads(), so it must be reached by every thread of the block.
// NOTE(review): only threadIdx.x / blockDim.x are used, so this assumes a 1-D
// thread block.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Within a warp, the ballot gives one bit per lane. The popcount of the lanes
  // up to and including this one (getLaneMaskLe) is the in-warp inclusive
  // count. The popcount of the whole ballot is the warp's total. The mask is
  // kept in its native unsigned type rather than T, so a narrow T cannot
  // truncate it.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  const int warp = threadIdx.x / C10_WARP_SIZE;
  // Number of warps in the block, counting a trailing partial warp.
  const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

  // Lane 0 of each warp publishes that warp's total.
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+.
  // The loop turns smem[0..numWarps) into an inclusive scan of the warp
  // totals, in place. A trailing partial warp is included, so
  // smem[numWarps - 1] is the block-wide total.
  if (threadIdx.x == 0) {
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] = binop(smem[i], current);
      current = binop(current, v);
    }
  }

  __syncthreads();

  // Add the combined total of all preceding warps.
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }

  *out = index;

  // Optional barrier: prevents a caller from overwriting smem (write-after-read
  // hazard) while other threads are still reading smem[warp - 1] above.
  if (KillWARDependency) {
    __syncthreads();
  }
}

// Exclusive prefix sum of a per-thread boolean across the whole block, plus
// the block-wide total. Built on inclusiveBinaryPrefixScan.
//
// Parameters:
//   smem  - same requirements as for inclusiveBinaryPrefixScan: at least
//           ceil_div(blockDim.x, C10_WARP_SIZE) elements.
//   in    - this thread's predicate.
//   out   - receives the number of threads with `in == true` among threads
//           strictly before threadIdx.x.
//   carry - receives the total over all threads of the block. Every thread
//           gets the same value.
//   binop - must be addition: the inclusive-to-exclusive step subtracts `in`.
//
// KillWARDependency: if true, a trailing __syncthreads() is issued so `smem`
// may be reused immediately by the caller.
// Contains __syncthreads(), so it must be reached by every thread of the block.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // No trailing barrier in the inner call: smem is still read below.
  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

  // Inclusive to exclusive: drop this thread's own contribution.
  *out -= (T) in;

  // The outgoing carry for all threads is the scanned entry of the last
  // (possibly partial) warp, i.e. the block-wide total.
  *carry = smem[at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1];

  if (KillWARDependency) {
    __syncthreads();
  }
}

} // namespace at::cuda
79