xref: /aosp_15_r20/external/pytorch/test/cpp/api/dispatch.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/native/Pow.h>
4 #include <c10/util/irange.h>
5 #include <test/cpp/api/support.h>
6 #include <torch/torch.h>
7 #include <torch/types.h>
8 #include <torch/utils.h>
9 #include <cstdlib>
10 #include <iostream>
11 #include <type_traits>
12 #include <vector>
13 
14 struct DispatchTest : torch::test::SeedingFixture {};
15 
// Checks element-wise integer pow with ATEN_CPU_CAPABILITY set to "avx2".
//
// Raises the tensor {1, 2, 3, 4} to itself element-wise and expects
// {1, 4, 27, 256}. The environment variable is set after the input tensors
// are built and immediately before pow() is called; it is left set when the
// test returns.
// NOTE(review): how and when ATen reads ATEN_CPU_CAPABILITY is not visible in
// this file -- confirm the variable is still honored at this point.
TEST_F(DispatchTest, TestAVX2) {
  const std::vector<int> ints{1, 2, 3, 4};
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both setters return 0 on success. Assert on it so a failed update cannot
  // let the test pass without the requested capability being in effect.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=avx2"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "avx2", 1));
#endif
  const auto actual_pow_avx2 = vals_tensor.pow(pows_tensor);
  // Tie the loop bound to the expected values and verify the output size
  // before indexing, instead of relying on a hard-coded element count.
  const int expected_count = static_cast<int>(result.size());
  ASSERT_EQ(expected_count, actual_pow_avx2.numel());
  for (const auto i : c10::irange(expected_count)) {
    ASSERT_EQ(result[i], actual_pow_avx2[i].item<int>());
  }
}
31 
// Checks element-wise integer pow with ATEN_CPU_CAPABILITY set to "avx512".
//
// Raises the tensor {1, 2, 3, 4} to itself element-wise and expects
// {1, 4, 27, 256}. The environment variable is set after the input tensors
// are built and immediately before pow() is called; it is left set when the
// test returns.
// NOTE(review): how and when ATen reads ATEN_CPU_CAPABILITY is not visible in
// this file -- confirm the variable is still honored at this point.
TEST_F(DispatchTest, TestAVX512) {
  const std::vector<int> ints{1, 2, 3, 4};
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both setters return 0 on success. Assert on it so a failed update cannot
  // let the test pass without the requested capability being in effect.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=avx512"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "avx512", 1));
#endif
  const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor);
  // Tie the loop bound to the expected values and verify the output size
  // before indexing, instead of relying on a hard-coded element count.
  const int expected_count = static_cast<int>(result.size());
  ASSERT_EQ(expected_count, actual_pow_avx512.numel());
  for (const auto i : c10::irange(expected_count)) {
    ASSERT_EQ(result[i], actual_pow_avx512[i].item<int>());
  }
}
47 
// Checks element-wise integer pow with ATEN_CPU_CAPABILITY set to "default".
//
// Raises the tensor {1, 2, 3, 4} to itself element-wise and expects
// {1, 4, 27, 256}. The environment variable is set after the input tensors
// are built and immediately before pow() is called; it is left set when the
// test returns.
// NOTE(review): how and when ATen reads ATEN_CPU_CAPABILITY is not visible in
// this file -- confirm the variable is still honored at this point.
TEST_F(DispatchTest, TestDefault) {
  const std::vector<int> ints{1, 2, 3, 4};
  const std::vector<int> result{1, 4, 27, 256};
  const auto vals_tensor = torch::tensor(ints);
  const auto pows_tensor = torch::tensor(ints);
  // Both setters return 0 on success. Assert on it so a failed update cannot
  // let the test pass without the requested capability being in effect.
#ifdef _WIN32
  ASSERT_EQ(0, _putenv("ATEN_CPU_CAPABILITY=default"));
#else
  ASSERT_EQ(0, setenv("ATEN_CPU_CAPABILITY", "default", 1));
#endif
  const auto actual_pow_default = vals_tensor.pow(pows_tensor);
  // Tie the loop bound to the expected values and verify the output size
  // before indexing, instead of relying on a hard-coded element count.
  const int expected_count = static_cast<int>(result.size());
  ASSERT_EQ(expected_count, actual_pow_default.numel());
  for (const auto i : c10::irange(expected_count)) {
    ASSERT_EQ(result[i], actual_pow_default[i].item<int>());
  }
}
63