// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Pow.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/native/Pow.h>
3 #include <c10/core/Scalar.h>
4 
5 namespace at { namespace native {
6 
7 namespace {
8 
9 
// SFINAE doesn't work well with NVCC under Windows for math functions such as
// pow and sqrt, so these functions are defined with explicit signatures instead.
// For pow, the following signatures are provided as device functions:
//   pow(float, int)
//   pow(double, int)
//   pow(float, float)
//   pow(double, double)
17 #ifdef _MSC_VER
18 // Functions for pow
// pow for at::Half: both operands are widened to float, the power is computed
// with std::pow in float precision, and the result is narrowed back to at::Half.
static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) {
  const float base_f = static_cast<float>(base);
  const float exp_f = static_cast<float>(exp);
  return static_cast<at::Half>(std::pow(base_f, exp_f));
}
// pow for at::BFloat16: both operands are widened to float, the power is
// computed with std::pow in float precision, and the result is narrowed back
// to at::BFloat16.
static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) {
  const float base_f = static_cast<float>(base);
  const float exp_f = static_cast<float>(exp);
  return static_cast<at::BFloat16>(std::pow(base_f, exp_f));
}
// pow (floating, floating/int): selected when the base is a floating-point type
// and the exponent is either that same type or a plain int. The work is handed
// to std::pow and its result is converted back to Base_type.
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ typename std::enable_if<
    std::is_floating_point<Base_type>::value &&
        (std::is_same<Base_type, Exp_type>::value ||
         std::is_same<Exp_type, int>::value),
    Base_type>::type
pow_(Base_type base, Exp_type exp) {
  // std::pow may return a wider type (e.g. for an int exponent); convert back
  // to the base type explicitly.
  return static_cast<Base_type>(std::pow(base, exp));
}
// pow (Otherwise): fallback for type pairings where the base and exponent types
// differ and the exponent is not int. Both operands are widened to double,
// std::pow is evaluated in double precision, and the result is converted back
// to Base_type.
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ typename std::enable_if<
    !(std::is_same<Base_type, Exp_type>::value ||
      std::is_same<Exp_type, int>::value),
    Base_type>::type
pow_(Base_type base, Exp_type exp) {
  const double base_d = static_cast<double>(base);
  const double exp_d = static_cast<double>(exp);
  return static_cast<Base_type>(std::pow(base_d, exp_d));
}
39 #else
// Generic pow (non-MSVC builds): forwards both operands unchanged to the
// global ::pow and converts whatever it returns to Base_type.
template <typename Base_type, typename Exp_type>
static inline __host__ __device__ auto pow_(Base_type base, Exp_type exp) -> Base_type {
  const Base_type result = ::pow(base, exp);
  return result;
}
44 #endif
45 
// Integral pow: enabled only when base and exponent share the same integral
// type T. The computation is delegated to at::native::powi (presumably
// provided by the included ATen/native/Pow.h).
template <typename T>
static inline __host__ __device__ typename std::enable_if<std::is_integral<T>::value, T>::type
pow_(T base, T exp) {
  return at::native::powi(base, exp);
}
51 
// Complex pow: base and exponent are both c10::complex<T>; the result comes
// from c10_complex_math::pow.
template <typename T>
static inline __host__ __device__ c10::complex<T> pow_(c10::complex<T> base, c10::complex<T> exp) {
  const c10::complex<T> result = c10_complex_math::pow(base, exp);
  return result;
}
56 
57 } // namespace
58 }} // namespace at::native
59