#include <gtest/gtest.h>

#include <ATen/native/Pow.h>
#include <c10/util/irange.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>

using namespace at;

namespace {

const auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max();

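// Test inputs for each type: the extreme representable values, small values
// around zero, and (for the signed integral cases) +/-sqrt(max) so that
// squaring lands near the overflow boundary.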
const std::vector<int> ints {
  int_min,
  int_min + 1,
  int_min + 2,
  static_cast<int>(-std::sqrt(static_cast<double>(int_max))),
  -3, -2, -1, 0, 1, 2, 3,
  static_cast<int>(std::sqrt(static_cast<double>(int_max))),
  int_max - 2,
  int_max - 1,
  int_max
};
const std::vector<int> non_neg_ints {
  0, 1, 2, 3,
  static_cast<int>(std::sqrt(static_cast<double>(int_max))),
  int_max - 2,
  int_max - 1,
  int_max
};
const std::vector<int64_t> longs {
  long_min,
  long_min + 1,
  long_min + 2,
  static_cast<int64_t>(-std::sqrt(static_cast<double>(long_max))),
  -3, -2, -1, 0, 1, 2, 3,
  static_cast<int64_t>(std::sqrt(static_cast<double>(long_max))),
  long_max - 2,
  long_max - 1,
  long_max
};
const std::vector<int64_t> non_neg_longs {
  0, 1, 2, 3,
  static_cast<int64_t>(std::sqrt(static_cast<double>(long_max))),
  long_max - 2,
  long_max - 1,
  long_max
};
const std::vector<float> floats {
  float_lowest,
  -3.0f, -2.0f, -1.0f, -1.0f/2.0f, -1.0f/3.0f,
  -float_min,
  0.0f,
  float_min,
  1.0f/3.0f, 1.0f/2.0f, 1.0f, 2.0f, 3.0f,
  float_max,
};
const std::vector<double> doubles {
  double_lowest,
  -3.0, -2.0, -1.0, -1.0/2.0, -1.0/3.0,
  -double_min,
  0.0,
  double_min,
  1.0/3.0, 1.0/2.0, 1.0, 2.0, 3.0,
  double_max,
};

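// Comparison helpers: floating-point results are compared with
// ASSERT_FLOAT_EQ and skipped when either side is NaN; integral results
// skip cases where overflow makes the reference value unreliable.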
template <class T,
          typename std::enable_if_t<std::is_floating_point_v<T>, T>* = nullptr>
void assert_eq(T val, T act, T exp) {
  if (std::isnan(act) || std::isnan(exp)) {
    return;
  }
  ASSERT_FLOAT_EQ(act, exp);
}

template <class T,
          typename std::enable_if_t<std::is_integral_v<T>, T>* = nullptr>
void assert_eq(T val, T act, T exp) {
  if (val != 0 && act == 0) {
    return;
  }
  if (val != 0 && exp == 0) {
    return;
  }
  const auto min = std::numeric_limits<T>::min();
  if (exp == min && val != min) {
    return;
  }
  ASSERT_EQ(act, exp);
}

template <class T,
          typename std::enable_if_t<std::is_floating_point_v<T>, T>* = nullptr>
T typed_pow(T base, T exp) {
  return std::pow(base, exp);
}
template <class T,
          typename std::enable_if_t<std::is_integral_v<T>, T>* = nullptr>
T typed_pow(T base, T exp) {
  return native::powi(base, exp);
}

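// Raises a tensor built from `vals` to each scalar exponent in `pows`
// through Tensor::pow, Tensor::pow_, torch::pow_out and torch::pow, and
// checks all four results against a single-element reference computation.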
template<typename Vals, typename Pows>
void tensor_pow_scalar(const Vals vals, const Pows pows, const torch::ScalarType valsDtype, const torch::ScalarType dtype) {
  const auto tensor = torch::tensor(vals, valsDtype);

  for (const auto pow : pows) {
    // NOLINTNEXTLINE(clang-diagnostic-implicit-const-int-float-conversion)
    if (dtype == kInt && pow > static_cast<float>(std::numeric_limits<int>::max())) {
      // value cannot be converted to type int without overflow
      // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
      EXPECT_THROW(tensor.pow(pow), std::runtime_error);
      continue;
    }
    auto actual_pow = tensor.pow(pow);

    auto actual_pow_ = torch::empty_like(actual_pow);
    actual_pow_.copy_(tensor);
    actual_pow_.pow_(pow);

    auto actual_pow_out = torch::empty_like(actual_pow);
    torch::pow_out(actual_pow_out, tensor, pow);

    auto actual_torch_pow = torch::pow(tensor, pow);

    int i = 0;
    for (const auto val : vals) {
      const auto exp = torch::pow(torch::tensor({val}, dtype), torch::tensor(pow, dtype)).template item<double>();

      const auto act_pow = actual_pow[i].to(at::kDouble).template item<double>();
      assert_eq<long double>(val, act_pow, exp);

      const auto act_pow_ = actual_pow_[i].to(at::kDouble).template item<double>();
      assert_eq<long double>(val, act_pow_, exp);

      const auto act_pow_out = actual_pow_out[i].to(at::kDouble).template item<double>();
      assert_eq<long double>(val, act_pow_out, exp);

      const auto act_torch_pow = actual_torch_pow[i].to(at::kDouble).template item<double>();
      assert_eq<long double>(val, act_torch_pow, exp);

      i++;
    }
  }
}

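// Raises each scalar in `vals` to a tensor of exponents via torch::pow and
// torch::pow_out, comparing against typed_pow on the underlying C++ types.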
template<typename Vals, typename Pows>
void scalar_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, const Pows pows, c10::ScalarType pows_dtype) {
  using T = typename Pows::value_type;

  const auto pow_tensor = torch::tensor(pows, pows_dtype);

  for (const auto val : vals) {
    const auto actual_pow = torch::pow(val, pow_tensor);
    auto actual_pow_out1 = torch::empty_like(actual_pow);
    const auto actual_pow_out2 =
        torch::pow_out(actual_pow_out1, val, pow_tensor);

    int i = 0;
    for (const auto pow : pows) {
      const auto exp = typed_pow(static_cast<T>(val), T(pow));

      const auto act_pow = actual_pow[i].template item<T>();
      assert_eq<T>(val, act_pow, exp);

      const auto act_pow_out1 = actual_pow_out1[i].template item<T>();
      assert_eq<T>(val, act_pow_out1, exp);

      const auto act_pow_out2 = actual_pow_out2[i].template item<T>();
      assert_eq<T>(val, act_pow_out2, exp);

      i++;
    }
  }
}

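// Element-wise tensor ^ tensor: the exponent vector is rotated on every
// iteration so that each base is eventually paired with every exponent.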
template<typename Vals, typename Pows>
void tensor_pow_tensor(const Vals vals, c10::ScalarType vals_dtype, Pows pows, c10::ScalarType pows_dtype) {
  using T = typename Vals::value_type;

  typedef std::numeric_limits<double> dbl;
  std::cout.precision(dbl::max_digits10);

  const auto vals_tensor = torch::tensor(vals, vals_dtype);
  for ([[maybe_unused]] const auto shift : c10::irange(pows.size())) {
    const auto pows_tensor = torch::tensor(pows, pows_dtype);

    const auto actual_pow = vals_tensor.pow(pows_tensor);

    auto actual_pow_ = vals_tensor.clone();
    actual_pow_.pow_(pows_tensor);

    auto actual_pow_out = torch::empty_like(vals_tensor);
    torch::pow_out(actual_pow_out, vals_tensor, pows_tensor);

    auto actual_torch_pow = torch::pow(vals_tensor, pows_tensor);

    int i = 0;
    for (const auto val : vals) {
      const auto pow = pows[i];
      const auto exp = typed_pow(T(val), T(pow));

      const auto act_pow = actual_pow[i].template item<T>();
      assert_eq(val, act_pow, exp);

      const auto act_pow_ = actual_pow_[i].template item<T>();
      assert_eq(val, act_pow_, exp);

      const auto act_pow_out = actual_pow_out[i].template item<T>();
      assert_eq(val, act_pow_out, exp);

      const auto act_torch_pow = actual_torch_pow[i].template item<T>();
      assert_eq(val, act_torch_pow, exp);

      i++;
    }

    // Rotate the exponents so the next iteration tests new val/pow pairings.
    std::rotate(pows.begin(), pows.begin() + 1, pows.end());
  }
}

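// Direct checks of native::powi for small fixed exponents.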
template<typename T>
void test_pow_one(const std::vector<T> vals) {
  for (const auto val : vals) {
    ASSERT_EQ(native::powi(val, T(1)), val);
  }
}

template<typename T>
void test_squared(const std::vector<T> vals) {
  for (const auto val : vals) {
    ASSERT_EQ(native::powi(val, T(2)), val * val);
  }
}

template<typename T>
void test_cubed(const std::vector<T> vals) {
  for (const auto val : vals) {
    ASSERT_EQ(native::powi(val, T(3)), val * val * val);
  }
}

template<typename T>
void test_inverse(const std::vector<T> vals) {
  for (const auto val : vals) {
    // 1 and -1 have special checks below
    if (val != 1 && val != -1) {
      ASSERT_EQ(native::powi(val, T(-4)), 0);
      // val is neither 1 nor -1 here, so the integral inverse truncates to 0
      ASSERT_EQ(native::powi(val, T(-1)), val == 1);
    }
  }

  T neg1 = -1;
  ASSERT_EQ(native::powi(neg1, T(0)), 1);
  ASSERT_EQ(native::powi(neg1, T(-1)), -1);
  ASSERT_EQ(native::powi(neg1, T(-2)), 1);
  ASSERT_EQ(native::powi(neg1, T(-3)), -1);
  ASSERT_EQ(native::powi(neg1, T(-4)), 1);

  T one = 1;
  ASSERT_EQ(native::powi(one, T(0)), 1);
  ASSERT_EQ(native::powi(one, T(-1)), 1);
  ASSERT_EQ(native::powi(one, T(-2)), 1);
  ASSERT_EQ(native::powi(one, T(-3)), 1);
  ASSERT_EQ(native::powi(one, T(-4)), 1);
}

} // namespace

TEST(PowTest, IntTensorPowAllScalars) {
  tensor_pow_scalar(ints, non_neg_ints, kInt, kInt);
  tensor_pow_scalar(ints, non_neg_longs, kInt, kInt);
  tensor_pow_scalar(ints, floats, kInt, kFloat);
  tensor_pow_scalar(ints, doubles, kInt, kDouble);
}

TEST(PowTest, LongTensorPowAllScalars) {
  tensor_pow_scalar(longs, non_neg_ints, kLong, kLong);
  tensor_pow_scalar(longs, non_neg_longs, kLong, kLong);
  tensor_pow_scalar(longs, floats, kLong, kFloat);
  tensor_pow_scalar(longs, doubles, kLong, kDouble);
}

TEST(PowTest, FloatTensorPowAllScalars) {
  tensor_pow_scalar(floats, ints, kFloat, kDouble);
  tensor_pow_scalar(floats, longs, kFloat, kDouble);
  tensor_pow_scalar(floats, floats, kFloat, kFloat);
  tensor_pow_scalar(floats, doubles, kFloat, kDouble);
}

TEST(PowTest, DoubleTensorPowAllScalars) {
  tensor_pow_scalar(doubles, ints, kDouble, kDouble);
  tensor_pow_scalar(doubles, longs, kDouble, kDouble);
  tensor_pow_scalar(doubles, floats, kDouble, kDouble);
  tensor_pow_scalar(doubles, doubles, kDouble, kDouble);
}

TEST(PowTest, IntScalarPowAllTensors) {
  scalar_pow_tensor(ints, c10::kInt, ints, c10::kInt);
  scalar_pow_tensor(ints, c10::kInt, longs, c10::kLong);
  scalar_pow_tensor(ints, c10::kInt, floats, c10::kFloat);
  scalar_pow_tensor(ints, c10::kInt, doubles, c10::kDouble);
}

TEST(PowTest, LongScalarPowAllTensors) {
  scalar_pow_tensor(longs, c10::kLong, longs, c10::kLong);
  scalar_pow_tensor(longs, c10::kLong, floats, c10::kFloat);
  scalar_pow_tensor(longs, c10::kLong, doubles, c10::kDouble);
}

TEST(PowTest, FloatScalarPowAllTensors) {
  scalar_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
  scalar_pow_tensor(floats, c10::kFloat, doubles, c10::kDouble);
}

TEST(PowTest, DoubleScalarPowAllTensors) {
  scalar_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
}

TEST(PowTest, IntTensorPowIntTensor) {
  tensor_pow_tensor(ints, c10::kInt, ints, c10::kInt);
}

TEST(PowTest, LongTensorPowLongTensor) {
  tensor_pow_tensor(longs, c10::kLong, longs, c10::kLong);
}

TEST(PowTest, FloatTensorPowFloatTensor) {
  tensor_pow_tensor(floats, c10::kFloat, floats, c10::kFloat);
}

TEST(PowTest, DoubleTensorPowDoubleTensor) {
  tensor_pow_tensor(doubles, c10::kDouble, doubles, c10::kDouble);
}

TEST(PowTest, TestIntegralPow) {
  test_pow_one(longs);
  test_pow_one(ints);

  test_squared(longs);
  test_squared(ints);

  test_cubed(longs);
  test_cubed(ints);

  test_inverse(longs);
  test_inverse(ints);
}