// aten/src/ATen/native/cuda/ReduceMomentKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>

namespace at::native {

template <typename scalar_t, typename out_t=scalar_t>
void std_var_kernel_impl(TensorIterator& iter, double correction, bool take_sqrt) {
  // Use an unrolling factor of 2 for the Welford kernel; this lowers register
  // usage and avoids register spills.
  using accscalar_t = at::acc_type<scalar_t, true>;
  using ops_t = WelfordOps<scalar_t, accscalar_t, int32_t, thrust::pair<out_t, out_t>>;
  ops_t ops(static_cast<accscalar_t>(correction), take_sqrt);
  gpu_reduce_kernel<scalar_t, out_t, 2>(iter, ops, typename ops_t::acc_t{});
}

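// The WelfordOps used above implement Welford's single-pass mean/variance update;
// roughly, for each reduced element x:
//   n    <- n + 1
//   mean <- mean + (x - mean) / n
//   m2   <- m2 + (x - old_mean) * (x - mean)
// and the final state is projected to var = m2 / (n - correction), with a square
// root applied when take_sqrt is true. (Sketch only; the exact reduce/combine/
// project steps are defined in ATen/native/SharedReduceOps.h.)
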
static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool take_sqrt) {
  const auto input_dtype = iter.input_dtype();
  if (input_dtype == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    std_var_kernel_impl<at::Half, float>(iter, correction, take_sqrt);
  } else if (input_dtype == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    std_var_kernel_impl<at::BFloat16, float>(iter, correction, take_sqrt);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                    iter.dtype(), "std_cuda", [&]() {
      std_var_kernel_impl<scalar_t>(iter, correction, take_sqrt);
    });
  }
}

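// A note on the two promoted branches above: reading at::Half / at::BFloat16
// inputs while accumulating and writing in float fuses the dtype cast into the
// reduction itself, so no intermediate float copy of the input needs to be
// materialized by a separate cast kernel. mean_kernel_cuda below uses the same idea.
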
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
  // scalar_value_type yields acc_t for all non-complex dtypes and T for c10::complex<T>
  using factor_t = typename c10::scalar_value_type<acc_t>::type;
  factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
  gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}

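// Worked example (hypothetical shapes, for illustration): taking the mean of a
// [4, 8] tensor over dim 1 gives num_output_elements() == 4 and numel() == 32,
// so factor == 4.0f / 32 == 0.125f. MeanOps accumulates a sum per output element
// and scales it by this factor, i.e. divides by the 8 reduced elements, so the
// mean falls out of a single reduction pass.
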
static void mean_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == kHalf) {
    mean_kernel_impl<at::Half, float>(iter);
  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    mean_kernel_impl<at::Half, float, float>(iter);
  } else if (iter.dtype() == kBFloat16) {
    mean_kernel_impl<at::BFloat16, float>(iter);
  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    mean_kernel_impl<at::BFloat16, float, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "mean_cuda", [&]() {
      mean_kernel_impl<scalar_t>(iter);
    });
  }
}

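// For illustration (assuming the TensorIterator comes from the standard mean
// reduction path, where dtype() is the output and dtype(1) the input):
// torch.mean(x) on a half tensor keeps a half output and takes the first branch
// (half in, float accumulator, half out), while torch.mean(x, dtype=torch.float)
// produces a float output and takes the fused cast+reduce branch
// <at::Half, float, float>. Everything else goes through AT_DISPATCH_ALL_TYPES_AND_COMPLEX.
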
REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);

} // namespace at::native