xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/ScanUtils.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ceil_div.h>
4 #include <ATen/cuda/DeviceUtils.cuh>
5 #include <ATen/cuda/AsmUtils.cuh>
6 #include <c10/macros/Macros.h>
7 
8 // Collection of in-kernel scan / prefix sum utilities
9 
10 namespace at::cuda {
11 
// Inclusive prefix sum for binary vars using intra-warp voting +
// shared memory.
//
// Every thread contributes one boolean `in`. On return, `*out` holds the
// number of `true` inputs among threads 0..threadIdx.x (inclusive), where
// warp totals are combined with `binop`.
//
// Contract:
//  - Must be reached by every thread of the block: it contains
//    __syncthreads().
//  - Only threadIdx.x / blockDim.x are consulted, so the block is treated
//    as 1-D.
//  - `smem` is scratch written and read by different threads across
//    barriers. It is normally __shared__ memory and must hold at least
//    ceil_div(blockDim.x, C10_WARP_SIZE) elements of T. On return, smem[w]
//    holds the combined total of warps 0..w.
//  - Inside a warp the count is always a popcount (an addition), and the
//    cross-warp accumulator starts at 0. `binop` is therefore expected to be
//    addition, or at least to have 0 as its identity.
//    NOTE(review): confirm that callers only pass addition.
//  - If KillWARDependency is true, a trailing __syncthreads() is issued so
//    the caller may overwrite `smem` right after returning.
//  - NOTE(review): if blockDim.x is not a multiple of C10_WARP_SIZE, confirm
//    that WARP_BALLOT's lane mask is valid for the partial last warp.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Within-warp, we use warp voting. `vote` has one bit set per lane whose
  // `in` is true:
  //  - the popcount of the bits at or below this lane is the inclusive
  //    in-warp count;
  //  - the popcount of all bits is the warp's total.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  // Keep the ballot as an unsigned lane mask. Storing it in T could drop the
  // upper lanes when T is narrower than 32 bits.
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  const int warp = threadIdx.x / C10_WARP_SIZE;
  // Round up so that a trailing partial warp is folded into the prefix too.
  // This matches the last-warp index that exclusiveBinaryPrefixScan reads
  // its carry from.
  const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

  // Per each warp, write out a value
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+.
  // Each smem[i] is replaced by the total of warps 0..i. The accumulator
  // uses the same type T as the values it combines.
  if (threadIdx.x == 0) {
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] = binop(smem[i], current);
      current = binop(current, v);
    }
  }

  __syncthreads();

  // load the carry from the preceding warp
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }

  *out = index;

  // Optional barrier so no thread is still reading smem when the caller
  // starts overwriting it (write-after-read hazard).
  if constexpr (KillWARDependency) {
    __syncthreads();
  }
}
60 
// Exclusive prefix sum for binary vars using intra-warp voting +
// shared memory.
//
// Runs inclusiveBinaryPrefixScan (without its trailing barrier) and then
// removes the calling thread's own input.
//
// Outputs:
//  - `*out`: the count of `true` inputs from threads strictly before the
//    caller.
//  - `*carry`: written by every thread, taken from smem at the last
//    (possibly partial) warp's slot, i.e. the last warp's accumulated sum.
//
// Contract:
//  - Subtracting `in` is only the inverse of the inclusive step when `binop`
//    is addition, so addition is assumed here.
//  - Every thread of the 1-D block must call this, since it contains
//    __syncthreads().
//  - `smem` must hold at least ceil_div(blockDim.x, C10_WARP_SIZE) elements.
//  - KillWARDependency adds a trailing __syncthreads() so `smem` may be
//    reused immediately after returning.
//  - NOTE(review): when blockDim.x is not a multiple of C10_WARP_SIZE,
//    `*carry` is the block-wide total only if the inclusive scan folded the
//    partial last warp into its prefix. Confirm this for such launches.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // Request no trailing barrier from the inclusive scan: smem is still read
  // below, and the optional barrier is issued at the end of this function.
  inclusiveBinaryPrefixScan<T, /*KillWARDependency=*/false>(smem, in, out, binop);

  // Turn the inclusive count into an exclusive one by dropping this thread's
  // own contribution.
  *out -= static_cast<T>(in);

  // Every thread takes the same outgoing carry: the entry for the final warp.
  const int lastWarp = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1;
  *carry = smem[lastWarp];

  // Optional barrier so the caller can safely overwrite smem afterwards.
  if constexpr (KillWARDependency) {
    __syncthreads();
  }
}
77 
78 }  // namespace at::cuda
79