xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/pow_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/native/Pow.h>
4 #include <c10/util/irange.h>
5 
6 #include <torch/types.h>
7 #include <torch/utils.h>
8 
9 #include <iostream>
10 #include <vector>
11 #include <type_traits>
12 
13 using namespace at;
14 
15 namespace {
16 
17 const auto int_min = std::numeric_limits<int>::min();
18 const auto int_max = std::numeric_limits<int>::max();
19 const auto long_min = std::numeric_limits<int64_t>::min();
20 const auto long_max = std::numeric_limits<int64_t>::max();
21 const auto float_lowest = std::numeric_limits<float>::lowest();
22 const auto float_min = std::numeric_limits<float>::min();
23 const auto float_max = std::numeric_limits<float>::max();
24 const auto double_lowest = std::numeric_limits<double>::lowest();
25 const auto double_min = std::numeric_limits<double>::min();
26 const auto double_max = std::numeric_limits<double>::max();
27 
28 const std::vector<int> ints {
29   int_min,
30   int_min + 1,
31   int_min + 2,
32   static_cast<int>(-sqrt(static_cast<double>(int_max))),
33   -3, -2, -1, 0, 1, 2, 3,
34   static_cast<int>(sqrt(static_cast<double>(int_max))),
35   int_max - 2,
36   int_max - 1,
37   int_max
38 };
39 const std::vector<int> non_neg_ints {
40   0, 1, 2, 3,
41   static_cast<int>(sqrt(static_cast<double>(int_max))),
42   int_max - 2,
43   int_max - 1,
44   int_max
45 };
46 const std::vector<int64_t> longs {
47   long_min,
48   long_min + 1,
49   long_min + 2,
50   static_cast<int64_t>(-sqrt(static_cast<double>(long_max))),
51   -3, -2, -1, 0, 1, 2, 3,
52   static_cast<int64_t>(sqrt(static_cast<double>(long_max))),
53   long_max - 2,
54   long_max - 1,
55   long_max
56 };
57 const std::vector<int64_t> non_neg_longs {
58   0, 1, 2, 3,
59   static_cast<int64_t>(sqrt(static_cast<double>(long_max))),
60   long_max - 2,
61   long_max - 1,
62   long_max
63 };
64 const std::vector<float> floats {
65   float_lowest,
66   -3.0f, -2.0f, -1.0f, -1.0f/2.0f, -1.0f/3.0f,
67   -float_min,
68   0.0,
69   float_min,
70   1.0f/3.0f, 1.0f/2.0f, 1.0f, 2.0f, 3.0f,
71   float_max,
72 };
73 const std::vector<double> doubles {
74   double_lowest,
75   -3.0, -2.0, -1.0, -1.0/2.0, -1.0/3.0,
76   -double_min,
77   0.0,
78   double_min,
79   1.0/3.0, 1.0/2.0, 1.0, 2.0, 3.0,
80   double_max,
81 };
82 
83 template <class T,
84   typename std::enable_if_t<std::is_floating_point_v<T>, T>* = nullptr>
assert_eq(T val,T act,T exp)85 void assert_eq(T val, T act, T exp) {
86   if (std::isnan(act) || std::isnan(exp)) {
87     return;
88   }
89   ASSERT_FLOAT_EQ(act, exp);
90 }
91 
92 template <class T,
93   typename std::enable_if_t<std::is_integral_v<T>, T>* = nullptr>
assert_eq(T val,T act,T exp)94 void assert_eq(T val, T act, T exp) {
95   if (val != 0 && act == 0) {
96     return;
97   }
98   if (val != 0 && exp == 0) {
99     return;
100   }
101   const auto min = std::numeric_limits<T>::min();
102   if (exp == min && val != min) {
103     return;
104   }
105   ASSERT_EQ(act, exp);
106 }
107 
108 template <class T,
109   typename std::enable_if_t<std::is_floating_point_v<T>, T>* = nullptr>
typed_pow(T base,T exp)110 T typed_pow(T base, T exp) {
111   return std::pow(base, exp);
112 }
113 template <class T,
114   typename std::enable_if_t<std::is_integral_v<T>, T>* = nullptr>
typed_pow(T base,T exp)115 T typed_pow(T base, T exp) {
116   return native::powi(base, exp);
117 }
118 
119 template<typename Vals, typename Pows>
tensor_pow_scalar(const Vals vals,const Pows pows,const torch::ScalarType valsDtype,const torch::ScalarType dtype)120 void tensor_pow_scalar(const Vals vals, const Pows pows, const torch::ScalarType valsDtype, const torch::ScalarType dtype) {
121   const auto tensor = torch::tensor(vals, valsDtype);
122 
123   for (const auto pow : pows) {
124     // NOLINTNEXTLINE(clang-diagnostic-implicit-const-int-float-conversion)
125     if ( dtype == kInt && pow > static_cast<float>(std::numeric_limits<int>::max())) {
126       // value cannot be converted to type int without overflow
127       // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
128       EXPECT_THROW(tensor.pow(pow), std::runtime_error);
129       continue;
130     }
131     auto actual_pow = tensor.pow(pow);
132 
133     auto actual_pow_ = torch::empty_like(actual_pow);
134     actual_pow_.copy_(tensor);
135     actual_pow_.pow_(pow);
136 
137     auto actual_pow_out = torch::empty_like(actual_pow);
138     torch::pow_out(actual_pow_out, tensor, pow);
139 
140     auto actual_torch_pow = torch::pow(tensor, pow);
141 
142     int i = 0;
143     for (const auto val : vals) {
144       const auto exp = torch::pow(torch::tensor({val}, dtype), torch::tensor(pow, dtype)).template item<double>();
145 
146       const auto act_pow = actual_pow[i].to(at::kDouble).template item<double>();
147       assert_eq<long double>(val, act_pow, exp);
148 
149       const auto act_pow_ = actual_pow_[i].to(at::kDouble).template item<double>();
150       assert_eq<long double>(val, act_pow_, exp);
151 
152       const auto act_pow_out = actual_pow_out[i].to(at::kDouble).template item<double>();
153       assert_eq<long double>(val, act_pow_out, exp);
154 
155       const auto act_torch_pow = actual_torch_pow[i].to(at::kDouble).template item<double>();
156       assert_eq<long double>(val, act_torch_pow, exp);
157 
158       i++;
159     }
160   }
161 }
162 
163 template<typename Vals, typename Pows>
scalar_pow_tensor(const Vals vals,c10::ScalarType vals_dtype,const Pows pows,c10::ScalarType pows_dtype)164 void scalar_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, const Pows pows, c10::ScalarType pows_dtype) {
165   using T = typename Pows::value_type;
166 
167   const auto pow_tensor = torch::tensor(pows, pows_dtype);
168 
169   for (const auto val : vals) {
170     const auto actual_pow = torch::pow(val, pow_tensor);
171     auto actual_pow_out1 = torch::empty_like(actual_pow);
172     const auto actual_pow_out2 =
173       torch::pow_out(actual_pow_out1, val, pow_tensor);
174 
175     int i = 0;
176     for (const auto pow : pows) {
177       const auto exp = typed_pow(static_cast<T>(val), T(pow));
178 
179       const auto act_pow = actual_pow[i].template item<T>();
180       assert_eq<T>(val, act_pow, exp);
181 
182       const auto act_pow_out1 = actual_pow_out1[i].template item<T>();
183       assert_eq<T>(val, act_pow_out1, exp);
184 
185       const auto act_pow_out2 = actual_pow_out2[i].template item<T>();
186       assert_eq<T>(val, act_pow_out2, exp);
187 
188       i++;
189     }
190   }
191 }
192 
193 template<typename Vals, typename Pows>
tensor_pow_tensor(const Vals vals,c10::ScalarType vals_dtype,Pows pows,c10::ScalarType pows_dtype)194 void tensor_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, Pows pows, c10::ScalarType pows_dtype) {
195   using T = typename Vals::value_type;
196 
197   typedef std::numeric_limits< double > dbl;
198   std::cout.precision(dbl::max_digits10);
199 
200   const auto vals_tensor = torch::tensor(vals, vals_dtype);
201   for ([[maybe_unused]] const auto shirt : c10::irange(pows.size())) {
202     const auto pows_tensor = torch::tensor(pows, pows_dtype);
203 
204     const auto actual_pow = vals_tensor.pow(pows_tensor);
205 
206     auto actual_pow_ = vals_tensor.clone();
207     actual_pow_.pow_(pows_tensor);
208 
209     auto actual_pow_out = torch::empty_like(vals_tensor);
210     torch::pow_out(actual_pow_out, vals_tensor, pows_tensor);
211 
212     auto actual_torch_pow = torch::pow(vals_tensor, pows_tensor);
213 
214     int i = 0;
215     for (const auto val : vals) {
216       const auto pow = pows[i];
217       const auto exp = typed_pow(T(val), T(pow));
218 
219       const auto act_pow = actual_pow[i].template item<T>();
220       assert_eq(val, act_pow, exp);
221 
222       const auto act_pow_ = actual_pow_[i].template item<T>();
223       assert_eq(val, act_pow_, exp);
224 
225       const auto act_pow_out = actual_pow_out[i].template item<T>();
226       assert_eq(val, act_pow_out, exp);
227 
228       const auto act_torch_pow = actual_torch_pow[i].template item<T>();
229       assert_eq(val, act_torch_pow, exp);
230 
231       i++;
232     }
233 
234     std::rotate(pows.begin(), pows.begin() + 1, pows.end());
235   }
236 }
237 
238 template<typename T>
test_pow_one(const std::vector<T> vals)239 void test_pow_one(const std::vector<T> vals) {
240   for (const auto val : vals) {
241     ASSERT_EQ(native::powi(val, T(1)), val);
242   }
243 }
244 
245 template<typename T>
test_squared(const std::vector<T> vals)246 void test_squared(const std::vector<T> vals) {
247   for (const auto val : vals) {
248     ASSERT_EQ(native::powi(val, T(2)), val * val);
249   }
250 }
251 
252 template<typename T>
test_cubed(const std::vector<T> vals)253 void test_cubed(const std::vector<T> vals) {
254   for (const auto val : vals) {
255     ASSERT_EQ(native::powi(val, T(3)), val * val * val);
256   }
257 }
258 template<typename T>
test_inverse(const std::vector<T> vals)259 void test_inverse(const std::vector<T> vals) {
260   for (const auto val : vals) {
261     // 1 has special checks below
262     if ( val != 1 && val != -1) {
263       ASSERT_EQ(native::powi(val, T(-4)), 0);
264       ASSERT_EQ(native::powi(val, T(-1)), val==1);
265     }
266   }
267   T neg1 = -1;
268   ASSERT_EQ(native::powi(neg1, T(0)), 1);
269   ASSERT_EQ(native::powi(neg1, T(-1)), -1);
270   ASSERT_EQ(native::powi(neg1, T(-2)), 1);
271   ASSERT_EQ(native::powi(neg1, T(-3)), -1);
272   ASSERT_EQ(native::powi(neg1, T(-4)), 1);
273 
274   T one = 1;
275   ASSERT_EQ(native::powi(one, T(0)), 1);
276   ASSERT_EQ(native::powi(one, T(-1)), 1);
277   ASSERT_EQ(native::powi(one, T(-2)), 1);
278   ASSERT_EQ(native::powi(one, T(-3)), 1);
279   ASSERT_EQ(native::powi(one, T(-4)), 1);
280 
281 }
282 
283 }
284 
TEST(PowTest,IntTensorPowAllScalars)285 TEST(PowTest, IntTensorPowAllScalars) {
286   tensor_pow_scalar(ints, non_neg_ints, kInt, kInt);
287   tensor_pow_scalar(ints, non_neg_longs, kInt, kInt);
288   tensor_pow_scalar(ints, floats, kInt, kFloat);
289   tensor_pow_scalar(ints, doubles, kInt, kDouble);
290 }
291 
TEST(PowTest,LongTensorPowAllScalars)292 TEST(PowTest, LongTensorPowAllScalars) {
293   tensor_pow_scalar(longs, non_neg_ints, kLong, kLong);
294   tensor_pow_scalar(longs, non_neg_longs, kLong, kLong);
295   tensor_pow_scalar(longs, floats, kLong, kFloat);
296   tensor_pow_scalar(longs, doubles, kLong, kDouble);
297 }
298 
TEST(PowTest,FloatTensorPowAllScalars)299 TEST(PowTest, FloatTensorPowAllScalars) {
300   tensor_pow_scalar(floats, ints, kFloat, kDouble);
301   tensor_pow_scalar(floats, longs, kFloat, kDouble);
302   tensor_pow_scalar(floats, floats, kFloat, kFloat);
303   tensor_pow_scalar(floats, doubles, kFloat, kDouble);
304 }
305 
TEST(PowTest,DoubleTensorPowAllScalars)306 TEST(PowTest, DoubleTensorPowAllScalars) {
307   tensor_pow_scalar(doubles, ints, kDouble, kDouble);
308   tensor_pow_scalar(doubles, longs, kDouble, kDouble);
309   tensor_pow_scalar(doubles, floats, kDouble, kDouble);
310   tensor_pow_scalar(doubles, doubles, kDouble, kDouble);
311 }
312 
TEST(PowTest,IntScalarPowAllTensors)313 TEST(PowTest, IntScalarPowAllTensors) {
314   scalar_pow_tensor(ints, c10::kInt, ints, c10::kInt);
315   scalar_pow_tensor(ints, c10::kInt, longs, c10::kLong);
316   scalar_pow_tensor(ints, c10::kInt, floats, c10::kFloat);
317   scalar_pow_tensor(ints, c10::kInt, doubles, c10::kDouble);
318 }
319 
TEST(PowTest,LongScalarPowAllTensors)320 TEST(PowTest, LongScalarPowAllTensors) {
321   scalar_pow_tensor(longs, c10::kLong, longs, c10::kLong);
322   scalar_pow_tensor(longs, c10::kLong, floats, c10::kFloat);
323   scalar_pow_tensor(longs, c10::kLong, doubles, c10::kDouble);
324 }
325 
TEST(PowTest,FloatScalarPowAllTensors)326 TEST(PowTest, FloatScalarPowAllTensors) {
327   scalar_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
328   scalar_pow_tensor(floats, c10::kFloat, doubles, c10::kDouble);
329 }
330 
TEST(PowTest,DoubleScalarPowAllTensors)331 TEST(PowTest, DoubleScalarPowAllTensors) {
332   scalar_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
333 }
334 
TEST(PowTest,IntTensorPowIntTensor)335 TEST(PowTest, IntTensorPowIntTensor) {
336   tensor_pow_tensor(ints, c10::kInt, ints, c10::kInt);
337 }
338 
TEST(PowTest,LongTensorPowLongTensor)339 TEST(PowTest, LongTensorPowLongTensor) {
340   tensor_pow_tensor(longs, c10::kLong, longs, c10::kLong);
341 }
342 
TEST(PowTest,FloatTensorPowFloatTensor)343 TEST(PowTest, FloatTensorPowFloatTensor) {
344   tensor_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
345 }
346 
TEST(PowTest,DoubleTensorPowDoubleTensor)347 TEST(PowTest, DoubleTensorPowDoubleTensor) {
348   tensor_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
349 }
350 
TEST(PowTest,TestIntegralPow)351 TEST(PowTest, TestIntegralPow) {
352   test_pow_one(longs);
353   test_pow_one(ints);
354 
355   test_squared(longs);
356   test_squared(ints);
357 
358   test_cubed(longs);
359   test_cubed(ints);
360 
361   test_inverse(longs);
362   test_inverse(ints);
363 }
364