1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/core/TensorBase.h>
3 #include <ATen/Dispatch.h>
4
5 #include <ATen/native/cuda/ScanKernels.h>
6 #include <ATen/native/cuda/ScanUtils.cuh>
7
namespace at::native {

// Host-side launcher for the CUDA cumulative product of `self` along `dim`,
// with the output written to `result`.
//
// Supported dtypes: everything covered by
// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2, plus Half and BFloat16. For each
// dtype, the actual work is delegated to scan_dim with multiplication as the
// combining operator and 1 (the multiplicative identity) as the initial value.
//
// Argument order differs between the two functions. This launcher takes
// (result, self, dim), while scan_dim is called as (self, result, dim, ...).
//
// NOTE(review): scan_dim (ScanUtils.cuh) decides the scan details, such as
// inclusive vs. exclusive, launch stream, and how `result` must be
// pre-shaped. Confirm there rather than assuming here.
void launch_cumprod_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
  const auto input_dtype = self.scalar_type();

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half, ScalarType::BFloat16, input_dtype, "cumprod_cuda", [&]() {
        // 1 is the neutral element of multiplication for every dispatched type.
        scalar_t identity = 1;
        scan_dim<scalar_t>(self, result, dim, identity, std::multiplies<scalar_t>{});
      });
}

} // namespace at::native
24