// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/cub.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <cstdint>
#include <type_traits>

#include <ATen/cuda/CUDAConfig.h>
#include <c10/core/ScalarType.h>
5 
// NOTE: These templates are intentionally not defined in this header,
// which avoids re-compiling them for each translation unit. If you get
// a link error, you need to add an explicit instantiation for your
// types in cub.cu
10 
11 namespace at::cuda::cub {
12 
// Number of binary digits needed to represent `max_key`, with a floor of
// one bit (so both 0 and 1 report a single bit). Used to bound the bit
// range handed to the radix sort below.
inline int get_num_bits(uint64_t max_key) {
  int bits = 1;
  for (; max_key > 1; max_key >>= 1) {
    ++bits;
  }
  return bits;
}
21 
namespace detail {

// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// alignas(N) keeps the blob's alignment equal to its size, matching the
// size-aligned value types this is substituted for (see the static_assert
// in radix_sort_pairs).
template <int N> struct alignas(N) OpaqueType { char data[N]; };

// Type-erased backend for radix_sort_pairs: sorts `n` (key, value) pairs
// by key, where each value is an opaque blob of `value_size` bytes.
// Intentionally not defined in this header; explicit instantiations live
// in cub.cu (see NOTE at the top of this file).
template<typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t *keys_in, key_t *keys_out,
    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);

}  // namespace detail
36 
37 template<typename key_t, typename value_t>
38 void radix_sort_pairs(
39     const key_t *keys_in, key_t *keys_out,
40     const value_t *values_in, value_t *values_out,
41     int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
42   static_assert(std::is_trivially_copyable_v<value_t> ||
43                 AT_ROCM_ENABLED(),  // ROCm incorrectly fails this check for vector types
44                 "radix_sort_pairs value type must be trivially copyable");
45   // Make value type opaque, so all inputs of a certain size use the same template instantiation
46   using opaque_t = detail::OpaqueType<sizeof(value_t)>;
47   static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
48                 "This size of value_t is not instantiated. Please instantiate it in cub.cu"
49                 " and modify this check.");
50   static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
51   detail::radix_sort_pairs_impl(
52       keys_in, keys_out,
53       reinterpret_cast<const opaque_t*>(values_in),
54       reinterpret_cast<opaque_t*>(values_out),
55       n, descending, begin_bit, end_bit);
56 }
57 
// Sorts `n` keys from keys_in into keys_out (ascending unless `descending`).
// `begin_bit`/`end_bit` restrict the sort to a sub-range of the key's bits;
// the defaults cover the full key. Not defined in this header -- add an
// explicit instantiation in cub.cu (see NOTE at the top of this file).
template<typename key_t>
void radix_sort_keys(
    const key_t *keys_in, key_t *keys_out,
    int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);
62 
// Inclusive prefix sum of input[0..n) written to output[0..n).
// NOTE: Intermediate sums will be truncated to input_t precision
// Not defined in this header -- instantiated in cub.cu (see NOTE at top).
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
66 
// Inclusive prefix sum where input and output share one scalar type. With
// matching types, truncating intermediates to input_t loses nothing, so
// simply forward to the truncating variant.
template <typename scalar_t>
void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  inclusive_sum_truncating(input, output, n);
}
71 
// Exclusive prefix sum of input[0..n) written to output[0..n).
// NOTE: Sums are done in common_type<input_t, output_t>
// Not defined in this header -- instantiated in cub.cu (see NOTE at top).
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
75 
// Exclusive prefix sum where input and output share one scalar type; the
// common type of (scalar_t, scalar_t) is scalar_t itself, so forward to
// the generic common-type variant.
template <typename scalar_t>
void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  exclusive_sum_in_common_type(input, output, n);
}
80 
// Exclusive prefix sum over a byte mask, producing output indices
// (e.g. for stream compaction). Defined in cub.cu.
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
// bool overload: assumes bool is one byte with the same 0/1 object
// representation as uint8_t (true on supported platforms), so the mask can
// be reinterpreted and the uint8_t instantiation above reused.
inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
  return mask_exclusive_sum(
      reinterpret_cast<const uint8_t*>(mask), output_idx, n);
}
86 
87 }  // namespace at::cuda::cub
88