// aten/src/ATen/native/cuda/UnaryOpsKernel.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
#include <ATen/OpMathType.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/core/Scalar.h>
#include <c10/util/complex.h>

namespace at::native {

void bitwise_not_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    gpu_kernel(iter, []GPU_LAMBDA(bool a) {
      return !a;
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ~a;
      });
    });
  }
}
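// A hedged worked example of what the two branches above compute: the Bool path applies
// logical negation (!true == false), while the integral path applies the bitwise
// complement, e.g. ~int8_t{5} == int8_t{-6} under two's complement.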

CONSTEXPR_EXCEPT_WIN_CUDA char exp_name[] = "exp_kernel";
void exp_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    #if AT_USE_JITERATOR()
      static const auto exp_string = jiterator_stringify(
          template <typename T>
          T exp_kernel(T x) {
            return std::exp(x);
      }); // exp_string
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/exp_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, exp_string);
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          return std::exp(static_cast<opmath_t>(a));
        });
      });
    #endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "exp_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::exp(a);
      });
    });
  }
}
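// A hedged note on the complex path above: std::exp follows the identity
// exp(a + b*i) = exp(a) * (cos(b) + i*sin(b)). In the non-jiterator fallback the value is
// first promoted to opmath_t (e.g. c10::complex<Half> is expected to promote to
// c10::complex<float>) so the exponential is evaluated at higher precision; when
// AT_USE_JITERATOR() is enabled, the stringified kernel is instead expected to be
// JIT-compiled at runtime for the dispatched complex dtype.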

void expm1_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.common_dtype(), "expm1_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return ::expm1(a);
        });
      });
}

// We manually overload rsqrt because std::rsqrt does not work with complex types.
template<typename scalar_t>
C10_HOST_DEVICE static inline scalar_t rsqrt_wrapper(scalar_t v) {
  return ::rsqrt(v);
}

template<typename T>
C10_HOST_DEVICE static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
  const c10::complex<T> one = c10::complex<T>(1.0, 0);
  // std::sqrt for c10::complex is overloaded in c10/util/complex_math.h
  return one / ::sqrt(v);
}
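// A hedged sanity check for the complex overload: rsqrt_wrapper(c10::complex<float>(4.0f, 0.0f))
// computes 1 / sqrt(4 + 0i) and should evaluate to approximately (0.5, 0.0).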

CONSTEXPR_EXCEPT_WIN_CUDA char rsqrt_name[] = "rsqrt_kernel";
void rsqrt_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    #if AT_USE_JITERATOR()
      static const auto rsqrt_string = jiterator_stringify(
          template <typename T>
          T rsqrt_kernel(T x) {
            const T one = T{1};
            return one / std::sqrt(x);
      }); // rsqrt_string
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/rsqrt_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, rsqrt_string);
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          return rsqrt_wrapper(static_cast<opmath_t>(a));
        });
      });
    #endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.common_dtype(), "rsqrt_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float.
          return rsqrt_wrapper(a);
        });
      });
  }
}

CONSTEXPR_EXCEPT_WIN_CUDA char sqrt_name[] = "sqrt_kernel";
void sqrt_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    #if AT_USE_JITERATOR()
      static const auto sqrt_string = jiterator_stringify(
          template <typename T>
          T sqrt_kernel(T x) {
            return std::sqrt(x);
      }); // sqrt_string
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/sqrt_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, sqrt_string);
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          return ::sqrt(static_cast<opmath_t>(a));
        });
      });
    #endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "sqrt_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::sqrt(a);
      });
    });
  }
}

void clamp_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value, const Scalar& max_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() {
    auto lower = min_value.to<scalar_t>();
    auto upper = max_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::min(::max(v, lower), upper);
      }
    });
  });
}
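// Hedged examples of the clamp lambda above: clamp(5, /*min=*/0, /*max=*/1) yields 1 and
// clamp(-3, 0, 1) yields 0, while a NaN input is returned unchanged thanks to the
// explicit _isnan check (ROCm's ::min/::max would not otherwise propagate it).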

void clamp_min_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_min_cuda", [&]() {
    auto lower = min_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::max(v, lower);
      }
    });
  });
}

void clamp_max_kernel_cuda(TensorIteratorBase& iter, const Scalar& max_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_max_cuda", [&]() {
    auto upper = max_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::min(v, upper);
      }
    });
  });
}

template<typename scalar_t>
C10_HOST_DEVICE static inline scalar_t _nan_to_num_replace(scalar_t a, scalar_t nan_replacement, scalar_t pos_inf_replacement, scalar_t neg_inf_replacement) {
  return at::_isnan(a)
    ? nan_replacement
    : (a == std::numeric_limits<scalar_t>::infinity()
      ? pos_inf_replacement
      : (a == -std::numeric_limits<scalar_t>::infinity()
        ? neg_inf_replacement
        : a));
}
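// Hedged examples for _nan_to_num_replace with the default replacements used below
// (nan -> 0, +inf -> numeric max, -inf -> numeric lowest): NaN maps to 0,
// +infinity maps to std::numeric_limits<scalar_t>::max(), -infinity maps to
// std::numeric_limits<scalar_t>::lowest(), and any finite value passes through unchanged.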

void nan_to_num_kernel_cuda(
    TensorIteratorBase& iter,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "nan_to_num", [&]() {
      using value_t = scalar_t::value_type;
      value_t nan_replacement = static_cast<value_t>(nan.value_or(0.));
      value_t pos_inf_replacement = pos_inf.has_value()
          ? static_cast<value_t>(pos_inf.value())
          : std::numeric_limits<value_t>::max();
      value_t neg_inf_replacement = neg_inf.has_value()
          ? static_cast<value_t>(neg_inf.value())
          : std::numeric_limits<value_t>::lowest();

      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
        value_t res_real = _nan_to_num_replace(
          a.real(), nan_replacement, pos_inf_replacement, neg_inf_replacement);
        value_t res_imag = _nan_to_num_replace(
          a.imag(), nan_replacement, pos_inf_replacement, neg_inf_replacement);
        return scalar_t(res_real, res_imag);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "nan_to_num_cuda", [&]() {
      scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
      scalar_t pos_inf_replacement = pos_inf.has_value()
          ? static_cast<scalar_t>(pos_inf.value())
          : std::numeric_limits<scalar_t>::max();
      scalar_t neg_inf_replacement = neg_inf.has_value()
          ? static_cast<scalar_t>(neg_inf.value())
          : std::numeric_limits<scalar_t>::lowest();

      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
          return _nan_to_num_replace(
            a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
      });
    });
  }
}

void frexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
    // The iter.dtype() here is the dtype of mantissa output.
    // It's a floating point type and must be the same as the input's dtype.
    iter.dtype(),
    "frexp_cuda", [&]() {
      gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t a) -> thrust::tuple<scalar_t, int32_t> {
        int32_t exponent;
        scalar_t mantissa = std::frexp(a, &exponent);
        return {mantissa, exponent};
      });
  });
}
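// A hedged worked example of the decomposition above: std::frexp(8.0) yields a mantissa
// of 0.5 and an exponent of 4, since 8.0 == 0.5 * 2^4; the mantissa output keeps the
// input's floating dtype while the exponent output is int32_t.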

REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda);
REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda);
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
REGISTER_DISPATCH(frexp_stub, &frexp_kernel_cuda);

} // namespace at::native