1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/CUDAConfig.h>
3 #include <ATen/cuda/cub.cuh>
4
5 namespace at::cuda::cub::detail {
6
// Sorts (key, value) pairs by key with CUB's DeviceRadixSort, issuing the work
// on the stream returned by c10::cuda::getCurrentCUDAStream().
//
//   keys_in / values_in   inputs forwarded to CUB (keys are reinterpreted as
//                         detail::cuda_type<key_t>::type).
//   keys_out              destination for the sorted keys; when null, a
//                         buffer of n keys is taken from the CUDA caching
//                         allocator and owned by a function-local
//                         c10::DataPtr.
//   values_out            destination for the values forwarded to CUB.
//   n                     number of pairs; checked against INT_MAX.
//   descending            true selects SortPairsDescending, false SortPairs.
//   begin_bit / end_bit   bit range forwarded to CUB unchanged.
template <typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t* keys_in,
    key_t* keys_out,
    const OpaqueType<value_size>* values_in,
    OpaqueType<value_size>* values_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");

  // Key type handed to CUB in place of key_t.
  using cub_key_t = typename detail::cuda_type<key_t>::type;

  auto caching_allocator = c10::cuda::CUDACachingAllocator::get();

  // Stays empty unless the caller did not supply a key output buffer.
  c10::DataPtr scratch_keys;
  if (!keys_out) {
    scratch_keys = caching_allocator->allocate(n * sizeof(key_t));
    keys_out = static_cast<key_t*>(scratch_keys.get());
  }

  const auto* cub_keys_in = reinterpret_cast<const cub_key_t*>(keys_in);
  auto* cub_keys_out = reinterpret_cast<cub_key_t*>(keys_out);

  if (!descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
        cub_keys_in,
        cub_keys_out,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
    return;
  }

  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
      cub_keys_in,
      cub_keys_out,
      values_in,
      values_out,
      n,
      begin_bit,
      end_bit,
      c10::cuda::getCurrentCUDAStream());
}
57
// Emits an explicit instantiation of radix_sort_pairs_impl for one
// (key type, value_size) combination. The expansion ends with its own
// semicolon, so invocations are written without one.
#define AT_INSTANTIATE_SORT_PAIRS(key_type, val_size)        \
  template void radix_sort_pairs_impl<key_type, val_size>(   \
      const key_type*,                                       \
      key_type*,                                             \
      const OpaqueType<val_size>*,                           \
      OpaqueType<val_size>*,                                 \
      int64_t,                                               \
      bool,                                                  \
      int64_t,                                               \
      int64_t);
68
// int32_t and int64_t keys, each instantiated with value_size 1, 2 and 4.
// The helper macro exists only for these two lines and is removed afterwards.
#define AT_INSTANTIATE_SORT_PAIRS_1_2_4(key_type) \
  AT_INSTANTIATE_SORT_PAIRS(key_type, 1)          \
  AT_INSTANTIATE_SORT_PAIRS(key_type, 2)          \
  AT_INSTANTIATE_SORT_PAIRS(key_type, 4)

AT_INSTANTIATE_SORT_PAIRS_1_2_4(int32_t)
AT_INSTANTIATE_SORT_PAIRS_1_2_4(int64_t)

#undef AT_INSTANTIATE_SORT_PAIRS_1_2_4
75
// Adapter with the two-argument shape that AT_FORALL_SCALAR_TYPES_AND2 passes
// to its callback; only the first argument (the key type) is used, and the
// value_size is fixed at 8.
#define AT_INSTANTIATE_SORT_PAIRS_8(key_type, unused_name) \
  AT_INSTANTIATE_SORT_PAIRS(key_type, 8)

// value_size 8 instantiations: every type enumerated by
// AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ...), followed by the key types
// listed explicitly below.
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8)
AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8)
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
84
85 } // namespace at::cuda::cub::detail
86