1 #pragma once
2
3 #include <ATen/native/DispatchStub.h>
4
5 namespace c10 {
6 class Scalar;
7 }
8
9 namespace at {
10
11 struct TensorIterator;
12 struct TensorIteratorBase;
13
14 namespace native {
15
// HOST_DEVICE expands to the `__host__ __device__` qualifiers when this header
// is compiled with __CUDACC__ or __HIPCC__ defined; with any other compiler it
// expands to nothing, so annotated functions are ordinary host functions.
#if !defined(__CUDACC__) && !defined(__HIPCC__)
#define HOST_DEVICE
#else
#define HOST_DEVICE __host__ __device__
#endif
21
22 // integral power in pytorch allows for negative exponents, giving truncated integral results.
23 // e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the
24 // only non-zero result.
// Integer exponentiation by repeated squaring.
//
// Consumes the exponent `b` one bit at a time, least-significant bit first:
// whenever the current bit is set, the accumulator is multiplied by the
// current repeated square of `a`. Returns 1 when `b` is 0.
//
// The squaring step also runs after the last set bit has been consumed, so the
// intermediate value of `a` can exceed the range of T even when the result
// fits; the function is annotated with __ubsan_ignore_signed_int_overflow__.
// Negative exponents are not special-cased here.
template <class T,
  typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
  T acc = 1;
  // Step order matters: halve the exponent first, then square the base.
  for (; b; b /= 2, a *= a) {
    if (b & 1) {
      acc *= a;
    }
  }
  return acc;
}
38
39 template <class T,
40 typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
powi(T a,T b)41 inline HOST_DEVICE T powi(T a, T b) {
42 return powi_impl(a, b);
43 }
44
45 template <class T,
46 typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
powi(T a,T b)47 inline HOST_DEVICE T powi(T a, T b) {
48 if ( b < 0 ) {
49 if ( a == 1 ) {
50 return 1;
51 } else if ( a == -1 ) {
52 auto negative = (-b) % static_cast<T>(2);
53 return negative ? -1 : 1;
54 } else {
55 return 0;
56 }
57 }
58 return powi_impl(a, b);
59 }
60
61 using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
62 using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
63
64 DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
65 DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
66
67 } // namespace native
68
69 } // namespace at
70