// aten/src/ATen/native/cuda/ReduceSumProdKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/jit_macros.h>
#include <ATen/OpMathType.h>

namespace at::native {

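// Plain sum reduction: `scalar_t` is the input element type, `acc_t` the
// accumulation type and `out_t` the output type. The combine step is a simple
// `a + b`, wrapped via `func_wrapper` and handed to `gpu_reduce_kernel`.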
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a + b;
        }));
  }
};

// jiterated specialization for `complex<Half>`
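// With the jiterator the combine function is kept as a source string
// (`jiterator_stringify`) and compiled at run time by
// `jitted_gpu_reduce_kernel`, which avoids instantiating the `complex<Half>`
// reduction at build time. The trailing `0.` is the identity element of the
// sum.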
CONSTEXPR_EXCEPT_WIN_CUDA char sum_name[] = "sum";
template <>
struct sum_functor<c10::complex<at::Half>> {
// jiterator reduction fails on windows
// Ref: https://github.com/pytorch/pytorch/issues/77305
#if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    std::string func = jiterator_stringify(
    arg_t combine(arg_t a, arg_t b) {
      return a + b;
    }
    );
    jitted_gpu_reduce_kernel<sum_name, scalar_t, scalar_t>(
        iter, func, 0.);
  }
#else
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, scalar_t>(
        iter, func_wrapper<scalar_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a + b;
        }), acc_t{0.});
  }
#endif
};

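// nansum: like sum, but NaN inputs are treated as zero. The non-complex path
// relies on `NanSumOps`; the complex path below spells out the same rule in a
// jiterated combine function.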
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct nansum_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, NanSumOps<acc_t, out_t>{});
  }
};

CONSTEXPR_EXCEPT_WIN_CUDA char nansum_name[] = "nansum";
template <typename scalar_t>
struct nansum_functor_complex {
#if AT_USE_JITERATOR()
  void operator()(TensorIterator& iter) {
    std::string func = jiterator_stringify(
        arg_t combine(arg_t a, scalar_t b) {
          return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
        }
    );
    jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
        iter, func, 0.);
  }
#else
  void operator()(TensorIterator& iter) {
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, acc_t>(
        iter, NanSumOps<acc_t, acc_t>{});
  }
#endif
};

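// Product reduction; structured like `sum_functor` but combining with `a * b`
// and using `1.` as the identity element.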
CONSTEXPR_EXCEPT_WIN_CUDA char prod_name[] = "prod";
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct prod_functor {
  // jiterator reduction fails on windows
  // Ref: https://github.com/pytorch/pytorch/issues/77305
  #if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    std::string func = jiterator_stringify(
    arg_t combine(arg_t a, arg_t b) {
      return a * b;
    }
    );
    jitted_gpu_reduce_kernel<prod_name, scalar_t, out_t>(
        iter, func, 1.);
  }
  #else
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a * b;
        }), 1.);
  }
  #endif
};

// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
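// For a bool tensor the product is a logical "all": the identity is true and
// any false factor makes the result false, so `&&` is the natural combine.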
template <>
struct prod_functor<bool> {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<bool, bool>(
        iter, func_wrapper<bool>([] GPU_LAMBDA(bool a, bool b) -> bool {
          return a && b;
        }), 1);
  }
};

// jiterated specialization for `complex<Half>`
template <>
struct prod_functor<c10::complex<at::Half>> {
// jiterator reduction fails on windows
// Ref: https://github.com/pytorch/pytorch/issues/77305
#if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    std::string func =
        jiterator_stringify(arg_t combine(arg_t a, arg_t b) { return a * b; });
    jitted_gpu_reduce_kernel<prod_name, scalar_t, scalar_t>(iter, func, 1.);
  }
#else
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, scalar_t>(
        iter,
        func_wrapper<scalar_t>(
            [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a * b; }),
        acc_t{1.});
  }
#endif
};

// The function `reduce_dispatch` below dispatches to the kernel based
// on the dtype of `iter`. It takes care of the common logic for the
// reduced-precision floating types (`at::Half` and `at::BFloat16`),
// which are accumulated in `float`. For all other dtypes the functor
// `op` is called to dispatch to the kernel of the relevant type.
//
// Note: Functor `op` should take care of all the types to be supported
//       except for `at::Half` and `at::BFloat16`.
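//
// For example, when the output dtype is kHalf the reduction runs as
// OpFunctor<at::Half, float> (accumulate in float, write Half); when the
// input is kHalf and the output kFloat it runs as
// OpFunctor<at::Half, float, float>, so the cast and the reduction happen in
// a single kernel.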
template <
    template <
        typename scalar_t,
        typename acc_t = scalar_t,
        typename out_t = scalar_t>
    typename OpFunctor,
    typename GeneralDispatcher>
static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
  if (iter.dtype() == kHalf) {
    return OpFunctor<at::Half, float>{}(iter);
  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return OpFunctor<at::Half, float, float>{}(iter);
  } else if (iter.dtype() == kBFloat16) {
    return OpFunctor<at::BFloat16, float>{}(iter);
  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return OpFunctor<at::BFloat16, float, float>{}(iter);
  }
  op(iter);
}

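// Entry points for the sum/nansum/prod stubs below. Each one pairs the
// matching functor template with a "general dispatcher" lambda that covers
// every dtype other than Half/BFloat16 (those are intercepted by
// reduce_dispatch above).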
static void sum_kernel_cuda(TensorIterator& iter){
  auto general_dispatcher = [](TensorIterator& iter) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kBool, kComplexHalf, iter.dtype(), "sum_cuda", [&]() {
          sum_functor<scalar_t>{}(iter);
        });
  };

  reduce_dispatch<sum_functor>(iter, general_dispatcher);
}

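// Within the general dispatcher, nansum is only defined for floating and
// complex dtypes; complex inputs (including kComplexHalf) go through the
// jiterated nansum_functor_complex.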
static void nansum_kernel_cuda(TensorIterator& iter) {
  auto general_dispatcher = [](TensorIterator& iter) {
    auto dtype = iter.dtype();
    if (at::isComplexType(dtype)) {
        AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "nansum_cuda", [&]() {
          nansum_functor_complex<scalar_t>{}(iter);
        });
    } else {
        AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nansum_cuda", [&]() {
          nansum_functor<scalar_t>{}(iter);
        });
    }
  };

  reduce_dispatch<nansum_functor>(iter, general_dispatcher);
}

static void prod_kernel_cuda(TensorIterator& iter) {
  auto general_dispatcher = [](TensorIterator& iter) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kComplexHalf, kBool, iter.dtype(), "prod_cuda", [&]() {
      prod_functor<scalar_t>{}(iter);
    });
  };

  reduce_dispatch<prod_functor>(iter, general_dispatcher);
}

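// Hook the CUDA kernels into the per-device dispatch stubs (declared in
// ReduceOps.h) so that sum/nansum/prod reductions on CUDA tensors route here.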
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda);
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);

} // namespace at::native