#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/Dispatch.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/jit_macros.h>
#include <ATen/OpMathType.h>

namespace at::native {

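// Each functor below wraps a combine operation for gpu_reduce_kernel:
// `scalar_t` is the input element type, `acc_t` the accumulation type used
// inside the reduction, and `out_t` the output type (all three default to
// the same type).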
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a + b;
        }));
  }
};

// jiterated specialization for `complex<Half>`
CONSTEXPR_EXCEPT_WIN_CUDA char sum_name[] = "sum";
template <>
struct sum_functor<c10::complex<at::Half>> {
  // jiterator reduction fails on windows
  // Ref: https://github.com/pytorch/pytorch/issues/77305
#if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    std::string func = jiterator_stringify(
        arg_t combine(arg_t a, arg_t b) {
          return a + b;
        }
    );
    jitted_gpu_reduce_kernel<sum_name, scalar_t, scalar_t>(
        iter, func, 0.);
  }
#else
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, scalar_t>(
        iter, func_wrapper<scalar_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a + b;
        }), acc_t{0.});
  }
#endif
};

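// nansum: like sum, but NaN inputs contribute zero. The combine/identity
// logic lives in NanSumOps (SharedReduceOps.h).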
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct nansum_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, NanSumOps<acc_t, out_t>{});
  }
};

CONSTEXPR_EXCEPT_WIN_CUDA char nansum_name[] = "nansum";
template <typename scalar_t>
struct nansum_functor_complex {
#if AT_USE_JITERATOR()
  void operator()(TensorIterator& iter) {
    std::string func = jiterator_stringify(
        arg_t combine(arg_t a, scalar_t b) {
          return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
        }
    );
    jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
        iter, func, 0.);
  }
#else
  void operator()(TensorIterator& iter) {
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, acc_t>(
        iter, NanSumOps<acc_t, acc_t>{});
  }
#endif
};

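// prod: product reduction; the combine multiplies and the identity is 1.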
CONSTEXPR_EXCEPT_WIN_CUDA char prod_name[] = "prod";
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct prod_functor {
  // jiterator reduction fails on windows
  // Ref: https://github.com/pytorch/pytorch/issues/77305
#if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    std::string func = jiterator_stringify(
        arg_t combine(arg_t a, arg_t b) {
          return a * b;
        }
    );
    jitted_gpu_reduce_kernel<prod_name, scalar_t, out_t>(
        iter, func, 1.);
  }
#else
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
          return a * b;
        }), 1.);
  }
#endif
};

// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
template <>
struct prod_functor<bool> {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<bool, bool>(
        iter, func_wrapper<bool>([] GPU_LAMBDA(bool a, bool b) -> bool {
          return a && b;
        }), 1);
  }
};

// jiterated specialization for `complex<Half>`
template <>
struct prod_functor<c10::complex<at::Half>> {
  // jiterator reduction fails on windows
  // Ref: https://github.com/pytorch/pytorch/issues/77305
#if AT_USE_JITERATOR() && !defined(_MSC_VER)
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    std::string func =
        jiterator_stringify(arg_t combine(arg_t a, arg_t b) { return a * b; });
    jitted_gpu_reduce_kernel<prod_name, scalar_t, scalar_t>(iter, func, 1.);
  }
#else
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, scalar_t>(
        iter,
        func_wrapper<scalar_t>(
            [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { return a * b; }),
        acc_t{1.});
  }
#endif
};

// `reduce_dispatch` below dispatches to the kernel based on the dtypes of
// `iter`. It handles the common logic for the half-precision floating-point
// types (`at::Half` and `at::BFloat16`), accumulating them in `float`.
// For every other dtype it calls the functor `op`, which dispatches to the
// kernel of the relevant type.
//
// Note: the functor `op` should handle all supported types except
// `at::Half` and `at::BFloat16`.
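// For example: reducing a Half tensor to a Half output runs
// OpFunctor<at::Half, float>, so accumulation happens in float, while
// reducing a Half input to a Float output (e.g. an explicit dtype=torch.float)
// runs OpFunctor<at::Half, float, float>, which fuses the cast and the
// reduction into a single kernel.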
template <
    template <
        typename scalar_t,
        typename acc_t = scalar_t,
        typename out_t = scalar_t>
    typename OpFunctor,
    typename GeneralDispatcher>
static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
  if (iter.dtype() == kHalf) {
    return OpFunctor<at::Half, float>{}(iter);
  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return OpFunctor<at::Half, float, float>{}(iter);
  } else if (iter.dtype() == kBFloat16) {
    return OpFunctor<at::BFloat16, float>{}(iter);
  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
    // type promotion that does cast and reduction in a single kernel
    return OpFunctor<at::BFloat16, float, float>{}(iter);
  }
  op(iter);
}

static void sum_kernel_cuda(TensorIterator& iter) {
  auto general_dispatcher = [](TensorIterator& iter) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kBool, kComplexHalf, iter.dtype(), "sum_cuda", [&]() {
          sum_functor<scalar_t>{}(iter);
        });
  };

  reduce_dispatch<sum_functor>(iter, general_dispatcher);
}

static void nansum_kernel_cuda(TensorIterator& iter) {
  auto general_dispatcher = [](TensorIterator& iter) {
    auto dtype = iter.dtype();
    if (at::isComplexType(dtype)) {
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "nansum_cuda", [&]() {
        nansum_functor_complex<scalar_t>{}(iter);
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nansum_cuda", [&]() {
        nansum_functor<scalar_t>{}(iter);
      });
    }
  };

  reduce_dispatch<nansum_functor>(iter, general_dispatcher);
}

static void prod_kernel_cuda(TensorIterator& iter) {
  auto general_dispatcher = [](TensorIterator& iter) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kComplexHalf, kBool, iter.dtype(), "prod_cuda", [&]() {
      prod_functor<scalar_t>{}(iter);
    });
  };

  reduce_dispatch<prod_functor>(iter, general_dispatcher);
}

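// Register these kernels as the CUDA implementations of the sum, nansum,
// and prod reduction stubs.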
REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_cuda);
REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);

} // namespace at::native