xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/CumsumKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/core/TensorBase.h>
3 #include <ATen/Dispatch.h>
4 
5 #include <ATen/native/cuda/ScanKernels.h>
6 #include <ATen/native/cuda/ScanUtils.cuh>
7 
namespace at::native {

// Host-side launcher for the CUDA cumsum scan.
//
// Selects a concrete element type from `self.scalar_type()` and covers:
//   * the standard integral and floating types,
//   * the complex types,
//   * additionally Half and BFloat16.
// It then forwards to `scan_dim`, using addition as the combining operator and
// zero as the initial value.
//
// NOTE(review): `result` is presumably preallocated by the caller, and `dim` is
// presumably already validated/wrapped. Neither is checked here; confirm
// against the caller.
void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self.scalar_type(),
      "cumsum_cuda",
      [&]() {
        // Zero is the additive identity; it is passed as scan_dim's init value.
        scalar_t zero = 0;
        scan_dim<scalar_t>(self, result, dim, zero, std::plus<scalar_t>{});
      });
}

} // namespace at::native
26