xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/cub-RadixSortPairs.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAConfig.h>
3 #include <ATen/cuda/cub.cuh>
4 
5 namespace at::cuda::cub::detail {
6 
// Radix-sorts `n` (key, value) pairs by key through cub::DeviceRadixSort,
// issuing the work on the current CUDA stream.
//
//   keys_in / values_in   : input keys and their associated values.
//   keys_out / values_out : destinations for the sorted keys and values.
//                           `keys_out` may be null, in which case a temporary
//                           key buffer local to this call is used instead.
//   n                     : number of pairs; must not exceed INT_MAX
//                           (enforced with TORCH_CHECK).
//   descending            : selects SortPairsDescending over SortPairs.
//   begin_bit / end_bit   : forwarded unchanged to CUB as the key bit range.
//
// Values are handled as OpaqueType<value_size> blobs; only the keys take part
// in the comparison.
template <typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t* keys_in,
    key_t* keys_out,
    const OpaqueType<value_size>* values_in,
    OpaqueType<value_size>* values_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");

  // Key type actually handed to CUB, as mapped by detail::cuda_type.
  using cub_key_t = typename detail::cuda_type<key_t>::type;

  // CUB always writes sorted keys somewhere. When the caller did not supply
  // an output buffer, point keys_out at scratch memory from the CUDA caching
  // allocator; `scratch_keys` holds that allocation for the rest of the call.
  c10::DataPtr scratch_keys;
  if (!keys_out) {
    scratch_keys =
        c10::cuda::CUDACachingAllocator::get()->allocate(n * sizeof(key_t));
    keys_out = reinterpret_cast<key_t*>(scratch_keys.get());
  }

  const auto* cub_keys_in = reinterpret_cast<const cub_key_t*>(keys_in);
  auto* cub_keys_out = reinterpret_cast<cub_key_t*>(keys_out);
  const auto stream = c10::cuda::getCurrentCUDAStream();

  if (!descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
        cub_keys_in,
        cub_keys_out,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        stream);
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
        cub_keys_in,
        cub_keys_out,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        stream);
  }
}
57 
// Emits an explicit instantiation of radix_sort_pairs_impl for one
// (key type, value size) combination. The terminating semicolon is part of
// the macro, so the invocations below are written without one.
#define AT_INSTANTIATE_SORT_PAIRS(KeyT, ValueSize)      \
  template void radix_sort_pairs_impl<KeyT, ValueSize>( \
      const KeyT*,                                      \
      KeyT*,                                            \
      const OpaqueType<ValueSize>*,                     \
      OpaqueType<ValueSize>*,                           \
      int64_t,                                          \
      bool,                                             \
      int64_t,                                          \
      int64_t);

// int32_t / int64_t keys paired with value sizes 1, 2 and 4.
AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
75 
// Two-argument adapter so the instantiation macro can be handed to
// AT_FORALL_SCALAR_TYPES_AND2; the second argument is not used and the value
// size is fixed at 8.
#define AT_INSTANTIATE_SORT_PAIRS_8(ctype, unused_name) \
  AT_INSTANTIATE_SORT_PAIRS(ctype, 8)

// Value size 8 for every key type listed by the FORALL macro (plus Bool and
// Half), followed by the key types instantiated individually.
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8)
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
84 
85 } // namespace at::cuda::cub::detail
86