xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Pow.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 
5 namespace c10 {
6 class Scalar;
7 }
8 
9 namespace at {
10 
11 struct TensorIterator;
12 struct TensorIteratorBase;
13 
14 namespace native {
15 
// HOST_DEVICE decorates the helpers below so they can be compiled for both the
// host and the device when this header is built by a CUDA (__CUDACC__) or HIP
// (__HIPCC__) compiler. For every other compiler it expands to nothing.
#if !defined(__CUDACC__) && !defined(__HIPCC__)
#define HOST_DEVICE
#else
#define HOST_DEVICE __host__ __device__
#endif
21 
// Integral power in PyTorch allows negative exponents and yields the result truncated to an
// integer, e.g. since 2**-1 == 0.5, the truncated integral result is zero. With a negative
// exponent, only a base of 1 or -1 gives a non-zero result (1, or +/-1 depending on parity).
25 template <class T,
26   typename std::enable_if<std::is_integral<T>::value, T>::type* = nullptr>
powi_impl(T a,T b)27 inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
28   T result = 1;
29   while (b) {
30     if (b & 1) {
31        result *= a;
32     }
33     b /= 2;
34     a *= a;
35   }
36   return result;
37 }
38 
39 template <class T,
40   typename std::enable_if<std::is_integral<T>::value && !std::is_signed<T>::value, T>::type* = nullptr>
powi(T a,T b)41 inline HOST_DEVICE T powi(T a, T b) {
42   return powi_impl(a, b);
43 }
44 
45 template <class T,
46   typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value, T>::type* = nullptr>
powi(T a,T b)47 inline HOST_DEVICE T powi(T a, T b) {
48   if ( b < 0 ) {
49       if ( a == 1 ) {
50           return 1;
51       } else if ( a == -1 ) {
52           auto negative = (-b) % static_cast<T>(2);
53           return negative ? -1 : 1;
54       } else {
55           return 0;
56       }
57   }
58   return powi_impl(a, b);
59 }
60 
61 using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
62 using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
63 
64 DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub);
65 DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub);
66 
67 } // namespace native
68 
69 } // namespace at
70