xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/StaticSort.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <cutlass/cutlass.h>
3 
/**
 * A functor that sorts a fixed-size array or indexable container in place
 * with a Bose-Nelson sorting network generated at compile time.
 *
 * The network is unrolled entirely through template recursion, so the
 * sequence of compare-exchange steps depends only on NumElements, never on
 * the data. Elements end up in ascending order according to the element
 * type's operator<. The network is not stable: elements that compare equal
 * (neither a < b nor b < a) may be exchanged.
 *
 * \tparam NumElements  The number of elements to sort. Only the elements at
 *                      indices [0, NumElements) of the argument are accessed;
 *                      for NumElements <= 1 the call is a no-op.
 */
template <unsigned NumElements>
class StaticSort {
  // Compare-exchange step of the network: after construction, a[i0] holds the
  // smaller and a[i1] the larger of the two original values (by operator<).
  template <class A>
  struct Swap {
    template <class T>
    CUTLASS_HOST_DEVICE void s(T& v0, T& v1) {
      // Explicitly code out the Min and Max to nudge the compiler
      // to generate branchless code.
      T t = v0 < v1 ? v0 : v1; // Min
      v1 = v0 < v1 ? v1 : v0; // Max
      v0 = t;
    }

    // Every call site passes indices derived from template parameters, so
    // they are taken by value rather than by const reference.
    CUTLASS_HOST_DEVICE Swap(A& a, int i0, int i1) {
      s(a[i0], a[i1]);
    }
  };

  // Bose-Nelson merge ("P-bracket"): merges the sorted run of X elements
  // starting at 1-based position I with the sorted run of Y elements starting
  // at 1-based position J, by splitting it into three smaller merges. As
  // instantiated from PS, X and Y differ by at most one at every level, so the
  // recursion always ends in one of the base-case specializations below.
  template <class A, int I, int J, int X, int Y>
  struct PB {
    CUTLASS_HOST_DEVICE PB(A& a) {
      enum {
        L = X >> 1,
        M = (X & 1 ? Y : Y + 1) >> 1,
        IAddL = I + L,
        XSubL = X - L
      };
      PB<A, I, J, L, M> p0(a);
      PB<A, IAddL, J + M, XSubL, Y - M> p1(a);
      PB<A, IAddL, J, XSubL, M> p2(a);
    }
  };

  // Base case 1 x 1: a single compare-exchange (1-based -> 0-based indices).
  template <class A, int I, int J>
  struct PB<A, I, J, 1, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s(a, I - 1, J - 1);
    }
  };

  // Base case 1 x 2: merge one element against a sorted pair.
  template <class A, int I, int J>
  struct PB<A, I, J, 1, 2> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J);
      Swap<A> s1(a, I - 1, J - 1);
    }
  };

  // Base case 2 x 1: merge a sorted pair against one element.
  template <class A, int I, int J>
  struct PB<A, I, J, 2, 1> {
    CUTLASS_HOST_DEVICE PB(A& a) {
      Swap<A> s0(a, I - 1, J - 1);
      Swap<A> s1(a, I, J - 1);
    }
  };

  // Bose-Nelson sort ("P-star"): sorts the M elements starting at 1-based
  // position I by sorting both halves recursively and merging them with PB.
  // Stop is true once the range holds at most one element.
  template <class A, int I, int M, bool Stop = false>
  struct PS {
    CUTLASS_HOST_DEVICE PS(A& a) {
      enum { L = M >> 1, IAddL = I + L, MSubL = M - L };
      PS<A, I, L, (L <= 1)> ps0(a);
      PS<A, IAddL, MSubL, (MSubL <= 1)> ps1(a);
      PB<A, I, IAddL, L, MSubL> pb(a);
    }
  };

  // Zero or one element is already sorted: nothing to do. The argument is
  // left unnamed to avoid unused-parameter warnings.
  template <class A, int I, int M>
  struct PS<A, I, M, true> {
    CUTLASS_HOST_DEVICE PS(A&) {}
  };

 public:
  /**
   * Sorts the array/container arr in place.
   * \param  arr  The array or container to be sorted; must support arr[i]
   *              (yielding an assignable element) for i in [0, NumElements).
   */
  template <class Container>
  CUTLASS_HOST_DEVICE void operator()(Container& arr) const {
    PS<Container, 1, NumElements, (NumElements <= 1)> ps(arr);
  }

  /**
   * Sorts the array pointed to by arr in place.
   * \param  arr  Pointer to the first of at least NumElements elements.
   */
  template <class T>
  CUTLASS_HOST_DEVICE void operator()(T* arr) const {
    PS<T*, 1, NumElements, (NumElements <= 1)> ps(arr);
  }
};
101