xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/cub.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/cub.cuh>
3 #include <ATen/cuda/CUDAConfig.h>
4 
5 namespace at::cuda::cub {
6 
namespace {

// Stateless device-side addition functor. Both operands and the result share
// the single template type, so a scan driven by SumOp<T> combines values as T.
template <typename scalar_t>
struct SumOp {
  __device__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    const scalar_t total = lhs + rhs;
    return total;
  }
};

}  // namespace
15 
// Inclusive prefix sum of `num_items` elements of `input`, written to `output`,
// using CUB's generic `Sum` functor as the scan operator (unlike
// exclusive_sum_in_common_type, which forces a common accumulation type).
// NOTE(review): the "truncating" name suggests intermediate sums may be kept in
// the input type rather than output_t (relevant to the int32 -> int64
// instantiation); confirm against CUB's accumulator-type rules.
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) {
  using SumFunctor = NO_ROCM(at_cuda_detail)::cub::Sum;
  inclusive_scan(input, output, SumFunctor{}, num_items);
}

// Explicit instantiations for the (input, output) type pairs that are exported.
template void inclusive_sum_truncating<int32_t, int32_t>(const int32_t *, int32_t *, int64_t);
template void inclusive_sum_truncating<int64_t, int64_t>(const int64_t *, int64_t *, int64_t);
template void inclusive_sum_truncating<int32_t, int64_t>(const int32_t *, int64_t *, int64_t);
25 
// Exclusive prefix sum of `num_items` elements of `input`, written to `output`.
// The scan operator and initial value both use
// std::common_type_t<input_t, output_t>, so combination happens in that type
// and the scan starts from zero.
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t num_items) {
  using acc_t = std::common_type_t<input_t, output_t>;
  exclusive_scan(input, output, SumOp<acc_t>{}, static_cast<acc_t>(0), num_items);
}

// Explicit instantiations for the (input, output) type pairs that are exported.
template void exclusive_sum_in_common_type<int32_t, int32_t>(const int32_t *, int32_t *, int64_t);
template void exclusive_sum_in_common_type<int64_t, int64_t>(const int64_t *, int64_t *, int64_t);
34 
namespace {

// Device functor mapping a mask byte to a count contribution:
// 0 for a zero byte, 1 for any non-zero byte, returned as int64_t.
struct CountMaskOp {
  __device__ int64_t operator()(const uint8_t &flag) const {
    return (flag == 0) ? int64_t{0} : int64_t{1};
  }
};

}  // namespace
42 
// Exclusive prefix count of the non-zero bytes in `mask[0, n)`, written to
// `output_idx` as int64_t. Each byte is mapped to 0/1 by CountMaskOp through a
// transform iterator, then scanned with SumOp<int64_t> starting from 0.
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) {
  using count_t = int64_t;
  CountMaskOp op{};
  // The iterator's value type matches CountMaskOp's return type (int64_t).
  // It was previously declared as bool, which narrowed every count to bool
  // before SumOp<int64_t> widened it again. Any code that uses the iterator's
  // value_type as an intermediate/carry type would then truncate running sums
  // to 0/1.
  // NOTE(review): exclusive_scan's large-input chunking path presumably keys
  // its carry on this value_type; confirm in ATen/cuda/cub.cuh.
  auto iter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<
      count_t, decltype(op), decltype(mask)>(mask, op);
  exclusive_scan(iter, output_idx, SumOp<count_t>{}, count_t{0}, n);
}
49 
50 }  // namespace at::cuda::cub
51