1 #pragma once 2 #include <cutlass/cutlass.h> 3 4 /** 5 * A Functor class to create a sort for fixed sized arrays/containers with a 6 * compile time generated Bose-Nelson sorting network. 7 * \tparam NumElements The number of elements in the array or container to 8 * sort. \tparam T The element type. \tparam Compare A 9 * comparator functor class that returns true if lhs < rhs. 10 */ 11 template <unsigned NumElements> 12 class StaticSort { 13 template <class A> 14 struct Swap { 15 template <class T> sSwap16 CUTLASS_HOST_DEVICE void s(T& v0, T& v1) { 17 // Explicitly code out the Min and Max to nudge the compiler 18 // to generate branchless code. 19 T t = v0 < v1 ? v0 : v1; // Min 20 v1 = v0 < v1 ? v1 : v0; // Max 21 v0 = t; 22 } 23 SwapSwap24 CUTLASS_HOST_DEVICE Swap(A& a, const int& i0, const int& i1) { 25 s(a[i0], a[i1]); 26 } 27 }; 28 29 template <class A, int I, int J, int X, int Y> 30 struct PB { PBPB31 CUTLASS_HOST_DEVICE PB(A& a) { 32 enum { 33 L = X >> 1, 34 M = (X & 1 ? Y : Y + 1) >> 1, 35 IAddL = I + L, 36 XSubL = X - L 37 }; 38 PB<A, I, J, L, M> p0(a); 39 PB<A, IAddL, J + M, XSubL, Y - M> p1(a); 40 PB<A, IAddL, J, XSubL, M> p2(a); 41 } 42 }; 43 44 template <class A, int I, int J> 45 struct PB<A, I, J, 1, 1> { 46 CUTLASS_HOST_DEVICE PB(A& a) { 47 Swap<A> s(a, I - 1, J - 1); 48 } 49 }; 50 51 template <class A, int I, int J> 52 struct PB<A, I, J, 1, 2> { 53 CUTLASS_HOST_DEVICE PB(A& a) { 54 Swap<A> s0(a, I - 1, J); 55 Swap<A> s1(a, I - 1, J - 1); 56 } 57 }; 58 59 template <class A, int I, int J> 60 struct PB<A, I, J, 2, 1> { 61 CUTLASS_HOST_DEVICE PB(A& a) { 62 Swap<A> s0(a, I - 1, J - 1); 63 Swap<A> s1(a, I, J - 1); 64 } 65 }; 66 67 template <class A, int I, int M, bool Stop = false> 68 struct PS { 69 CUTLASS_HOST_DEVICE PS(A& a) { 70 enum { L = M >> 1, IAddL = I + L, MSubL = M - L }; 71 PS<A, I, L, (L <= 1)> ps0(a); 72 PS<A, IAddL, MSubL, (MSubL <= 1)> ps1(a); 73 PB<A, I, IAddL, L, MSubL> pb(a); 74 } 75 }; 76 77 template <class A, int I, int M> 78 struct PS<A, I, M, true> { 79 CUTLASS_HOST_DEVICE PS(A& a) {} 80 }; 81 82 public: 83 /** 84 * Sorts the array/container arr. 85 * \param arr The array/container to be sorted. 86 */ 87 template <class Container> 88 CUTLASS_HOST_DEVICE void operator()(Container& arr) const { 89 PS<Container, 1, NumElements, (NumElements <= 1)> ps(arr); 90 }; 91 92 /** 93 * Sorts the array arr. 94 * \param arr The array to be sorted. 95 */ 96 template <class T> 97 CUTLASS_HOST_DEVICE void operator()(T* arr) const { 98 PS<T*, 1, NumElements, (NumElements <= 1)> ps(arr); 99 }; 100 }; 101