#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Pow.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/Pow.h>
#include <c10/core/Scalar.h>

namespace at::native {

// Forward declare some unary kernels
void rsqrt_kernel_cuda(TensorIteratorBase& iter);
void sqrt_kernel_cuda(TensorIteratorBase& iter);
void reciprocal_kernel_cuda(TensorIteratorBase& iter);

namespace {

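// Forward declaration: pow_tensor_tensor_kernel below falls back to this
// tensor ^ scalar kernel whenever the exponent comes in as a CPU scalar.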
void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar);

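// scalar ^ tensor: the base is captured by value in the lambda, so each
// element-wise invocation on the GPU computes pow_(base, exp) directly.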
template <typename scalar_t>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, scalar_t base) {
  gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t exp) -> scalar_t {
    return pow_(base, exp);
  });
}

template <typename value_t>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<value_t> base) {
  // For complex, thrust::pow uses the identity
  // pow(a, b) = exp(log(a) * b)
  const auto fct = std::log(base);
  gpu_kernel(iter, [=]GPU_LAMBDA(c10::complex<value_t> exp) -> c10::complex<value_t> {
    return std::exp(fct * exp);
  });
}

/* complex<Half> support impl */
CONSTEXPR_EXCEPT_WIN_CUDA char pow_scalar_base_name[] = "pow_scalar_base_kernel";
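// ComplexHalf specialization: the math is carried out in opmath_t
// (c10::complex<float>) to limit precision loss, via a jiterator kernel when
// available and a plain gpu_kernel fallback otherwise.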
template <>
void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex<at::Half> base) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  // For complex, thrust::pow uses the identity
  // pow(a, b) = exp(log(a) * b)
  const auto fct = std::log(opmath_t{base});
#if AT_USE_JITERATOR()
  static const auto pow_kernel_string =
      jiterator_stringify(template <typename T> T pow_scalar_base_kernel(T exp, T fct) {
        return std::exp(fct * exp);
      });
  jitted_gpu_kernel<pow_scalar_base_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(fct));
#else
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t exp) -> scalar_t {
    return std::exp(fct * opmath_t{exp});
  });
#endif
}

namespace {

#if AT_USE_JITERATOR()
/* complex<Half> support impl */
CONSTEXPR_EXCEPT_WIN_CUDA char pow_name[] = "pow_kernel";
static const auto pow_kernel_string =
    jiterator_stringify(template <typename T> T pow_kernel(T base, T exp) {
      return std::pow(base, exp);
    });
#endif

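// Note: the exponent below travels through /*extra_args=*/ instead of being
// baked into the kernel source, so one compiled jiterator kernel can serve
// every exponent value.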
/* complex<Half> support impl */
void pow_chalf_tensor_scalar_impl(TensorIteratorBase& iter, const Scalar& exp_scalar) {
  using scalar_t = c10::complex<at::Half>;
  using opmath_t = at::opmath_type<scalar_t>;
  auto exp = exp_scalar.to<opmath_t>();
#if AT_USE_JITERATOR()
  jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 1>(
      iter,
      pow_kernel_string,
      /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
      /*scalar_val=*/0,
      /*extra_args=*/std::make_tuple(exp));
#else
  gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base) -> scalar_t {
    return std::pow(opmath_t{base}, exp);
  });
#endif
}

} // anonymous namespace

void pow_tensor_tensor_kernel(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (common_dtype == kComplexHalf) {
    using scalar_t = c10::complex<at::Half>;
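    // If either operand is a scalar sitting on the CPU, read its value on
    // the host, drop that operand from the iterator, and run the matching
    // unary kernel with the value captured in the lambda instead.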
    if (iter.is_cpu_scalar(1)) {
      const auto base = iter.scalar_value<scalar_t>(1);
      iter.remove_operand(1);
      pow_scalar_tensor_impl(iter, base);
    } else if (iter.is_cpu_scalar(2)) {
      const auto exp = iter.scalar_value<scalar_t>(2);
      iter.remove_operand(2);
      pow_chalf_tensor_scalar_impl(iter, exp);
    } else {
      TORCH_INTERNAL_ASSERT(!iter.is_cpu_scalar(1) && !iter.is_cpu_scalar(2));
#if AT_USE_JITERATOR()
      jitted_gpu_kernel<pow_name, scalar_t, scalar_t, 2>(
          iter, pow_kernel_string);
#else
      using opmath_t = at::opmath_type<scalar_t>;
      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
        return pow_(opmath_t{base}, opmath_t{exp});
      });
#endif
    }
  } else {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
        kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&] {
          if (iter.is_cpu_scalar(1)) {
            const auto base = iter.scalar_value<scalar_t>(1);
            iter.remove_operand(1);
            pow_scalar_tensor_impl(iter, base);
          } else if (iter.is_cpu_scalar(2)) {
            const auto exp = iter.scalar_value<scalar_t>(2);
            iter.remove_operand(2);
            pow_tensor_scalar_kernel(iter, exp);
          } else {
            gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t {
              return pow_(base, exp);
            });
          }
        });
  }
}

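// tensor ^ scalar kernel for real types; the exponent has already been read
// out of the Scalar by the dispatching caller below.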
template <typename Base_type, typename Exp_type>
void pow_tensor_scalar_kernel_impl(TensorIteratorBase& iter,
                                   Exp_type exp) {
  const auto d_exp = static_cast<double>(exp);
  // .5 (sqrt), -.5 (rsqrt) and -1 (reciprocal) specializations are handled
  // in pow_tensor_scalar_kernel
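  // Strength-reduce the remaining small integer exponents to plain
  // multiplications; everything else goes through the generic pow_.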
  if (d_exp == 2) {
    gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type {
      return base * base;
    });
  } else if (d_exp == 3) {
    gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type {
      return base * base * base;
    });
  } else if (d_exp == -2) {
    gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type {
      return 1.0 / (base * base);
    });
  } else {
    gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type {
      return pow_(base, exp);
    });
  }
}

void pow_tensor_scalar_kernel(TensorIteratorBase& iter, const Scalar& exp_scalar) {
  // Dispatch to fast specializations for sqrt, rsqrt and reciprocal
  if (!exp_scalar.isComplex()) {
    if (exp_scalar.equal(.5)) {
      return sqrt_kernel_cuda(iter);
    } else if (exp_scalar.equal(-0.5)) {
      return rsqrt_kernel_cuda(iter);
    } else if (exp_scalar.equal(-1.0)) {
      return reciprocal_kernel_cuda(iter);
    }
  }
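  // Complex bases or complex exponents take the complex path; ComplexHalf is
  // routed through the jiterator-backed chalf implementation above.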
  if (isComplexType(iter.common_dtype()) || exp_scalar.isComplex()) {
    if (iter.common_dtype() == kComplexHalf) {
      pow_chalf_tensor_scalar_impl(iter, exp_scalar);
      return;
    }
    AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "pow_cuda", [&]() {
      const auto exp = exp_scalar.to<scalar_t>();
      gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t base) -> scalar_t {
        return pow_(base, exp);
      });
    });
  } else if (isFloatingType(iter.common_dtype()) || exp_scalar.isIntegral(false)) {
    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "pow_cuda", [&]() {
      const auto exp = exp_scalar.to<scalar_t>();
      pow_tensor_scalar_kernel_impl<scalar_t>(iter, exp);
    });
  } else {
    TORCH_INTERNAL_ASSERT(false, "invalid combination of types in Pow function, common dtype: ",
                          iter.common_dtype(), ", exp is integral? ", exp_scalar.isIntegral(false));
  }
}


} // anonymous namespace

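// Register the CUDA kernels with the dispatch stubs declared in
// ATen/native/Pow.h.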
REGISTER_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel);
REGISTER_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel);

} // namespace at::native