#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>

namespace at::native {

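// A sketch of what the Welford reduction below computes (WelfordOps lives in
// SharedReduceOps.h): each element updates a running mean and M2, the sum of
// squared deviations from the current mean:
//   mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
//   M2_n   = M2_{n-1}   + (x_n - mean_{n-1}) * (x_n - mean_n)
// The projection step then yields roughly var = M2 / (n - correction), and
// its square root when take_sqrt is set. The pair result type lets the same
// kernel also emit the mean for the var_mean/std_mean variants.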
template <typename scalar_t, typename out_t=scalar_t>
void std_var_kernel_impl(TensorIterator& iter, double correction, bool take_sqrt) {
  // Reduce the unrolling factor to 2 for the Welford kernel; this lowers
  // register usage, which would otherwise lead to register spills.
  using accscalar_t = at::acc_type<scalar_t, true>;
  using ops_t = WelfordOps<scalar_t, accscalar_t, int32_t, thrust::pair<out_t, out_t>>;
  ops_t ops(static_cast<accscalar_t>(correction), take_sqrt);
  gpu_reduce_kernel<scalar_t, out_t, 2>(iter, ops, typename ops_t::acc_t{});
}

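// Dispatch on the (input dtype, output dtype) pair. As an illustration
// (hypothetical call, not part of this file): requesting a kFloat result for
// a kHalf input takes the first branch, which fuses the half->float cast into
// the Welford reduction rather than launching a separate cast kernel followed
// by a float reduction.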
static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool take_sqrt) {
  const auto input_dtype = iter.input_dtype();
  if (input_dtype == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    std_var_kernel_impl<at::Half, float>(iter, correction, take_sqrt);
  } else if (input_dtype == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    std_var_kernel_impl<at::BFloat16, float>(iter, correction, take_sqrt);
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
                                    iter.dtype(), "std_cuda", [&]() {
      std_var_kernel_impl<scalar_t>(iter, correction, take_sqrt);
    });
  }
}

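// mean_kernel_impl folds the division into the reduction: MeanOps scales the
// accumulated sum by factor = num_output_elements / numel, i.e. the
// reciprocal of the number of elements reduced per output. Worked example
// (illustrative shapes): reducing a [4, 256] tensor over dim 1 gives
// num_output_elements() == 4 and numel() == 1024, so factor == 1/256 and
// each output is sum * (1/256).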
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
  // returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
  using factor_t = typename c10::scalar_value_type<acc_t>::type;
  factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
  gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}

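// iter.dtype(1) is the dtype of the first input (operand 0 is the output),
// so the kFloat-output branches below fuse the cast into the reduction,
// mirroring the promotion fast path in std_var_kernel_cuda above.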
static void mean_kernel_cuda(TensorIterator& iter) {
  if (iter.dtype() == kHalf) {
    mean_kernel_impl<at::Half, float>(iter);
  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    mean_kernel_impl<at::Half, float, float>(iter);
  } else if (iter.dtype() == kBFloat16) {
    mean_kernel_impl<at::BFloat16, float>(iter);
  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    mean_kernel_impl<at::BFloat16, float, float>(iter);
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "mean_cuda", [&]() {
      mean_kernel_impl<scalar_t>(iter);
    });
  }
}

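// REGISTER_DISPATCH fills in the CUDA entries of the DispatchStub tables
// declared in ReduceOps.h, so reductions such as at::mean and at::var route
// to these kernels for CUDA tensors.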
REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);

} // namespace at::native