// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/PowKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Context.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/cuda/JitLoops.cuh>
6 #include <ATen/native/cuda/Pow.cuh>
7 #include <ATen/native/DispatchStub.h>
8 #include <ATen/native/TensorIterator.h>
9 #include <ATen/native/Pow.h>
10 #include <c10/core/Scalar.h>
11 
12 namespace at::native {
13 
// Forward declare some unary kernels
// (defined in their own .cu files). They serve as fast paths in
// pow_tensor_scalar_kernel for exponents 0.5, -0.5 and -1.0.
void rsqrt_kernel_cuda(TensorIteratorBase& iter);
void sqrt_kernel_cuda(TensorIteratorBase& iter);
void reciprocal_kernel_cuda(TensorIteratorBase& iter);
18 
19 namespace {
20 
// Forward declaration: defined below; needed by pow_tensor_tensor_kernel's
// cpu-scalar-exponent fast path.
void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar);
22 
// Elementwise pow with a host-side scalar base: out[i] = pow_(base, exp[i]),
// where the exponents come from the iterator's input tensor.
template <typename scalar_t>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, scalar_t base) {
  auto scalar_base_pow = [=] GPU_LAMBDA(scalar_t exponent) -> scalar_t {
    return pow_(base, exponent);
  };
  gpu_kernel(iter, scalar_base_pow);
}
29 
// Complex-scalar-base overload. For complex, thrust::pow uses the identity
// pow(a, b) = exp(log(a) * b); log(base) does not depend on the exponent,
// so it is computed once on the host and captured by the kernel.
template <typename value_t>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<value_t> base) {
  using complex_t = c10::complex<value_t>;
  const auto log_base = std::log(base);
  gpu_kernel(iter, [=] GPU_LAMBDA(complex_t exponent) -> complex_t {
    return std::exp(log_base * exponent);
  });
}
39 
/* complex<Half> support impl */
// Kernel name handed to the jiterator; must match the function name inside
// the stringified template below.
CONSTEXPR_EXCEPT_WIN_CUDA char pow_scalar_base_name[] = "pow_scalar_base_kernel";
template <>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<at::Half> base) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  // For complex, thrust::pow uses the identity
  // pow(a, b) = exp(log(a) * b)
  // log(base) is computed once on the host in the wider opmath type.
  const auto fct = std::log(opmath_t{base});
#if AT_USE_JITERATOR()
  // JIT path: `fct` is forwarded as an extra kernel argument rather than a
  // lambda capture, since the jiterator generates the kernel from a string.
  static const auto pow_kernel_string =
      jiterator_stringify(template <typename T> T pow_scalar_base_kernel(T exp, T fct) {
        return std::exp(fct * exp);
      });
  jitted_gpu_kernel<pow_scalar_base_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(fct));
#else
  // Fallback without the jiterator: promote each exponent to opmath precision.
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t {
    return std::exp(fct * opmath_t{exp});
  });
#endif
}
66 
namespace {

#if AT_USE_JITERATOR()
/* complex<Half> support impl */
// Shared jiterator kernel string for pow; also reused by
// pow_tensor_tensor_kernel for the complex<Half> tensor-tensor case.
CONSTEXPR_EXCEPT_WIN_CUDA char pow_name[] = "pow_kernel";
static const auto pow_kernel_string =
    jiterator_stringify(template <typename T> T pow_kernel(T base, T exp) {
      return std::pow(base, exp);
    });
#endif

/* complex<Half> support impl */
// Computes base[i] ** exp for a complex<Half> tensor and a host-side Scalar
// exponent, converted once to the wider opmath type.
void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  auto exp = exp_scalar.to<opmath_t>();
#if AT_USE_JITERATOR()
  // JIT path: the exponent travels as an extra kernel argument.
  jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(exp));
#else
  // Fallback: promote each base element to opmath precision.
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t {
    return std::pow(opmath_t{base}, exp);
  });
#endif
}

}  // anonymous namespace
98 
// Elementwise pow where both base and exponent are tensors. If one operand is
// a CPU scalar it is peeled off the iterator and routed to the faster
// scalar/tensor paths. complex<Half> is special-cased because the generic
// dispatch macros below do not cover it.
void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (common_dtype == kComplexHalf) {
    using scalar_t = c10::complex<at::Half>;
    if (iter.is_cpu_scalar(1)) {
      // Base is a CPU scalar: reduce to scalar ** tensor.
      const auto base = iter.scalar_value<scalar_t>(1);
      iter.remove_operand(1);
      pow_scalar_tensor_impl(iter, base);
    } else if (iter.is_cpu_scalar(2)) {
      // Exponent is a CPU scalar: reduce to tensor ** scalar.
      const auto exp = iter.scalar_value<scalar_t>(2);
      iter.remove_operand(2);
      pow_chalf_tensor_scalar_impl(iter, exp);
    } else {
      TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2));
#if AT_USE_JITERATOR()
      jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 2>(
          iter, pow_kernel_string);
#else
      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
            // Do the arithmetic in the wider opmath type for accuracy.
            using opmath_t = at::opmath_type<scalar_t>;
            return pow_(opmath_t{base}, opmath_t{exp});
          });
#endif
    }
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kHalf, kBFloat16, common_dtype, "pow_cuda", [&] {
      if (iter.is_cpu_scalar(1)) {
        const auto base = iter.scalar_value<scalar_t>(1);
        iter.remove_operand(1);
        pow_scalar_tensor_impl(iter, base);
      } else if (iter.is_cpu_scalar(2)) {
        const auto exp = iter.scalar_value<scalar_t>(2);
        iter.remove_operand(2);
        pow_tensor_scalar_kernel(iter, exp);
      } else {
        gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
          return pow_(base, exp);
        });
      }
    });
  }
}
143 
144 
// Tensor ** scalar with small-integer-exponent specializations: 2, 3 and -2
// are expanded into plain multiplications; anything else defers to pow_().
// .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled
// in pow_tensor_scalar_kernel.
template<typename Base_type, typename Exp_type>
void pow_tensor_scalar_kernel_impl(TensorIteratorBase& iter,
                                                 Exp_type exp) {
  // Compare in double so integral and floating exponent types share one test.
  const auto exp_as_double = static_cast<double>(exp);
  if (exp_as_double == 2) {
    gpu_kernel(iter, [=] GPU_LAMBDA(Base_type base) -> Base_type {
      return base * base;
    });
    return;
  }
  if (exp_as_double == 3) {
    gpu_kernel(iter, [=] GPU_LAMBDA(Base_type base) -> Base_type {
      return base * base * base;
    });
    return;
  }
  if (exp_as_double == -2) {
    gpu_kernel(iter, [=] GPU_LAMBDA(Base_type base) -> Base_type {
      return 1.0 / (base * base);
    });
    return;
  }
  // General case: keep the exponent's original value/type for pow_.
  gpu_kernel(iter, [=] GPU_LAMBDA(Base_type base) -> Base_type {
    return pow_(base, exp);
  });
}
169 
// Dispatches tensor ** scalar to the best available kernel:
//   * exponents 0.5 / -0.5 / -1.0 route to sqrt/rsqrt/reciprocal kernels,
//   * complex dtypes (or a complex exponent) use the complex paths,
//   * floating/integral combinations use pow_tensor_scalar_kernel_impl.
// Asserts on unsupported combinations (integral tensor, non-integral exp).
void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) {
  // Dispatch to fast specialization for sqrt, rsqrt and reciprocal
  if (!exp_scalar.isComplex()) {
    if (exp_scalar.equal(.5)) {
      return sqrt_kernel_cuda(iter);
    } else if (exp_scalar.equal(-0.5)) {
      return rsqrt_kernel_cuda(iter);
    } else if (exp_scalar.equal(-1.0)) {
      return reciprocal_kernel_cuda(iter);
    }
  }
  if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) {
    if (iter.common_dtype() == kComplexHalf) {
      // complex<Half> is not covered by AT_DISPATCH_COMPLEX_TYPES.
      pow_chalf_tensor_scalar_impl(iter, exp_scalar);
      return;
    }
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() {
      const auto exp = exp_scalar.to<scalar_t>();
      gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {
        return pow_(base, exp);
      });
    });
  } else if (isFloatingType(iter.common_dtype()) || exp_scalar.isIntegral(false)) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&]() {
      const auto exp = exp_scalar.to<scalar_t>();
      pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "invalid combination of type in Pow function, common dtype: ", iter.common_dtype(),
                                 ", exp is integral? ", exp_scalar.isIntegral(false));
  }
}
203 
204 } // anonymous namespace
205 
206 REGISTER_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel);
207 REGISTER_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel);
208 
209 } // namespace at::native
210