1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/native/Pow.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
7*da0073e9SAndroid Build Coastguard Worker #include <torch/types.h>
8*da0073e9SAndroid Build Coastguard Worker #include <torch/utils.h>
9*da0073e9SAndroid Build Coastguard Worker #include <cstdlib>
10*da0073e9SAndroid Build Coastguard Worker #include <iostream>
11*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
12*da0073e9SAndroid Build Coastguard Worker #include <vector>
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker struct DispatchTest : torch::test::SeedingFixture {};
15*da0073e9SAndroid Build Coastguard Worker
// Sets ATEN_CPU_CAPABILITY to "avx2" and checks that element-wise integer
// pow() of {1, 2, 3, 4} raised to itself yields {1, 4, 27, 256}.
// NOTE(review): presumably ATen consults this variable when it selects a CPU
// kernel; that lookup is not visible in this file — confirm that it is still
// honored when set this late in the process.
// NOTE(review): the variable stays set after the test returns.
TEST_F(DispatchTest, TestAVX2) {
  const std::vector<int> ints{1, 2, 3, 4};
  // result[i] == ints[i] raised to the power ints[i].
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both calls return 0 on success. Abort the test if the environment could
  // not be updated, instead of silently running with whatever was set before.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=avx2"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "avx2", 1));
#endif
  const auto actual_pow_avx2 = vals_tensor.pow(pows_tensor);
  // Verify the output length up front so the indexing below cannot run past
  // the end of the tensor.
  ASSERT_EQ(static_cast<int64_t>(result.size()), actual_pow_avx2.numel());
  // Iterate over the expected values rather than a hard-coded count.
  for (const auto i : c10::irange(result.size())) {
    ASSERT_EQ(result[i], actual_pow_avx2[static_cast<int64_t>(i)].item<int>());
  }
}
31*da0073e9SAndroid Build Coastguard Worker
// Sets ATEN_CPU_CAPABILITY to "avx512" and checks that element-wise integer
// pow() of {1, 2, 3, 4} raised to itself yields {1, 4, 27, 256}.
// NOTE(review): presumably ATen consults this variable when it selects a CPU
// kernel; that lookup is not visible in this file — confirm that it is still
// honored when set this late in the process.
// NOTE(review): the variable stays set after the test returns.
TEST_F(DispatchTest, TestAVX512) {
  const std::vector<int> ints{1, 2, 3, 4};
  // result[i] == ints[i] raised to the power ints[i].
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both calls return 0 on success. Abort the test if the environment could
  // not be updated, instead of silently running with whatever was set before.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=avx512"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "avx512", 1));
#endif
  const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor);
  // Verify the output length up front so the indexing below cannot run past
  // the end of the tensor.
  ASSERT_EQ(static_cast<int64_t>(result.size()), actual_pow_avx512.numel());
  // Iterate over the expected values rather than a hard-coded count.
  for (const auto i : c10::irange(result.size())) {
    ASSERT_EQ(
        result[i], actual_pow_avx512[static_cast<int64_t>(i)].item<int>());
  }
}
47*da0073e9SAndroid Build Coastguard Worker
// Sets ATEN_CPU_CAPABILITY to "default" and checks that element-wise integer
// pow() of {1, 2, 3, 4} raised to itself yields {1, 4, 27, 256}.
// NOTE(review): presumably ATen consults this variable when it selects a CPU
// kernel; that lookup is not visible in this file — confirm that it is still
// honored when set this late in the process.
// NOTE(review): the variable stays set after the test returns.
TEST_F(DispatchTest, TestDefault) {
  const std::vector<int> ints{1, 2, 3, 4};
  // result[i] == ints[i] raised to the power ints[i].
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both calls return 0 on success. Abort the test if the environment could
  // not be updated, instead of silently running with whatever was set before.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=default"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "default", 1));
#endif
  const auto actual_pow_default = vals_tensor.pow(pows_tensor);
  // Verify the output length up front so the indexing below cannot run past
  // the end of the tensor.
  ASSERT_EQ(static_cast<int64_t>(result.size()), actual_pow_default.numel());
  // Iterate over the expected values rather than a hard-coded count.
  for (const auto i : c10::irange(result.size())) {
    ASSERT_EQ(
        result[i], actual_pow_default[static_cast<int64_t>(i)].item<int>());
  }
}
63