1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/cuda/cub.cuh>
3 #include <ATen/cuda/CUDAConfig.h>
4
5 namespace at::cuda::cub {
6
namespace {

// Binary addition functor used as the scan operator by the prefix-sum
// helpers in this file. It is device-only (`__device__`); both operands and
// the result have type `scalar_t` (the sum is converted back to `scalar_t`).
template <typename scalar_t>
struct SumOp {
  __device__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    const scalar_t total = lhs + rhs;
    return total;
  }
};

} // anonymous namespace
15
// Inclusive prefix sum of `num_items` elements from `input` into `output`,
// delegating to `inclusive_scan` with the CUB `Sum` functor named by
// `NO_ROCM(at_cuda_detail)::cub::Sum`.
// NOTE(review): the "truncating" suffix suggests partial sums may be formed in
// a type narrower than output_t (e.g. the int32 -> int64 instantiation below);
// confirm against inclusive_scan's accumulator-type rules.
template <typename input_t, typename output_t>
void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t num_items) {
  inclusive_scan(input, output, NO_ROCM(at_cuda_detail)::cub::Sum{}, num_items);
}

// Explicit instantiations: the (input_t, output_t) pairs this translation unit provides.
template void inclusive_sum_truncating<int32_t, int32_t>(const int32_t *, int32_t *, int64_t);
template void inclusive_sum_truncating<int64_t, int64_t>(const int64_t *, int64_t *, int64_t);
template void inclusive_sum_truncating<int32_t, int64_t>(const int32_t *, int64_t *, int64_t);
25
// Exclusive prefix sum of `num_items` elements from `input` into `output`.
// Elements are combined with SumOp over std::common_type_t<input_t, output_t>,
// and the scan is seeded with zero of that common type.
template <typename input_t, typename output_t>
void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t num_items) {
  using accum_t = std::common_type_t<input_t, output_t>;
  exclusive_scan(input, output, SumOp<accum_t>{}, static_cast<accum_t>(0), num_items);
}

// Explicit instantiations: the (input_t, output_t) pairs this translation unit provides.
template void exclusive_sum_in_common_type<int32_t, int32_t>(const int32_t *, int32_t *, int64_t);
template void exclusive_sum_in_common_type<int64_t, int64_t>(const int64_t *, int64_t *, int64_t);
34
namespace {

// Device-only functor mapping a mask byte to 1 when it is non-zero and to 0
// otherwise, so that summing the mapped values counts the set mask entries.
struct CountMaskOp {
  __device__ int64_t operator()(const uint8_t &byte) const {
    return byte != 0 ? int64_t{1} : int64_t{0};
  }
};

} // anonymous namespace
42
// Exclusive scan over the `n` bytes of `mask`, written to `output_idx`.
// Each byte is passed through CountMaskOp (non-zero -> 1, zero -> 0) by a CUB
// TransformInputIterator whose value type is `bool`. The results are combined
// with SumOp<int64_t>, starting from 0. Assuming the standard exclusive-scan
// contract for exclusive_scan, output_idx[i] is the number of j < i with
// mask[j] != 0.
void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n) {
  using MaskFlagIter = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<
      bool, CountMaskOp, const uint8_t *>;
  MaskFlagIter is_set(mask, CountMaskOp{});
  exclusive_scan(is_set, output_idx, SumOp<int64_t>{}, int64_t{0}, n);
}
49
50 } // namespace at::cuda::cub
51