xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/atest.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <c10/util/irange.h>
5 
6 #include <iostream>
7 using namespace std;
8 using namespace at;
9 
10 class atest : public ::testing::Test {
11  protected:
SetUp()12   void SetUp() override {
13     x_tensor = tensor({10, -1, 0, 1, -10});
14     y_tensor = tensor({-10, 1, 0, -1, 10});
15     x_logical = tensor({1, 1, 0, 1, 0});
16     y_logical = tensor({0, 1, 0, 1, 1});
17     x_float = tensor({2.0, 2.4, 5.6, 7.0, 36.0});
18     y_float = tensor({1.0, 1.1, 8.7, 10.0, 24.0});
19   }
20 
21   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
22   Tensor x_tensor, y_tensor;
23   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
24   Tensor x_logical, y_logical;
25   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
26   Tensor x_float, y_float;
27   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
28   const int INT = 1;
29   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
30   const int FLOAT = 2;
31   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
32   const int INTFLOAT = 3;
33   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
34   const int INTBOOL = 5;
35   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
36   const int INTBOOLFLOAT = 7;
37 };
38 
// Bit flags for the `option` argument of run_binary_ops_test; OR them together
// to test an operator over several dtypes.
namespace BinaryOpsKernel {
constexpr int IntMask = 1 << 0; // test dtype = kInt
constexpr int FloatMask = 1 << 1; // test dtype = kFloat
constexpr int BoolMask = 1 << 2; // test dtype = kBool
} // namespace BinaryOpsKernel
44 
45 template <typename T, typename... Args>
unit_binary_ops_test(T func,const Tensor & x_tensor,const Tensor & y_tensor,const Tensor & exp,ScalarType dtype,Args...args)46 void unit_binary_ops_test(
47     T func,
48     const Tensor& x_tensor,
49     const Tensor& y_tensor,
50     const Tensor& exp,
51     ScalarType dtype,
52     Args... args) {
53   auto out_tensor = empty({5}, dtype);
54   func(out_tensor, x_tensor.to(dtype), y_tensor.to(dtype), args...);
55   ASSERT_EQ(out_tensor.dtype(), dtype);
56   if (dtype == kFloat) {
57     ASSERT_TRUE(exp.to(dtype).allclose(out_tensor));
58   } else {
59     ASSERT_TRUE(exp.to(dtype).equal(out_tensor));
60   }
61 }
62 
63 /*
64   template function for running binary operator test
65   - exp: expected output
66   - func: function to be tested
67   - option: 3 bits,
68     - 1st bit: Test op over integer tensors
69     - 2nd bit: Test op over float tensors
70     - 3rd bit: Test op over boolean tensors
71     For example, if function should be tested over integer/boolean but not for
72     float, option will be 1 * 1 + 0 * 2 + 1 * 4 = 5. If tested over all the
73     type, option should be 7.
74 */
75 template <typename T, typename... Args>
run_binary_ops_test(T func,const Tensor & x_tensor,const Tensor & y_tensor,const Tensor & exp,int option,Args...args)76 void run_binary_ops_test(
77     T func,
78     const Tensor& x_tensor,
79     const Tensor& y_tensor,
80     const Tensor& exp,
81     int option,
82     Args... args) {
83   // Test op over integer tensors
84   if (option & BinaryOpsKernel::IntMask) {
85     unit_binary_ops_test(func, x_tensor, y_tensor, exp, kInt, args...);
86   }
87 
88   // Test op over float tensors
89   if (option & BinaryOpsKernel::FloatMask) {
90     unit_binary_ops_test(func, x_tensor, y_tensor, exp, kFloat, args...);
91   }
92 
93   // Test op over boolean tensors
94   if (option & BinaryOpsKernel::BoolMask) {
95     unit_binary_ops_test(func, x_tensor, y_tensor, exp, kBool, args...);
96   }
97 }
98 
trace()99 void trace() {
100   Tensor foo = rand({12, 12});
101 
102   // ASSERT foo is 2-dimensional and holds floats.
103   auto foo_a = foo.accessor<float, 2>();
104   float trace = 0;
105 
106   for (const auto i : c10::irange(foo_a.size(0))) {
107     trace += foo_a[i][i];
108   }
109 
110   ASSERT_FLOAT_EQ(foo.trace().item<float>(), trace);
111 }
112 
// Bitwise ~, |, & and ^ on single-element int tensors must agree with the same
// operators applied to the underlying C++ ints.
TEST_F(atest, operators) {
  const int lhs = 0b10101011;
  const int rhs = 0b01111011;

  const auto lhs_t = tensor({lhs});
  const auto rhs_t = tensor({rhs});

  ASSERT_TRUE(tensor({~lhs}).equal(~lhs_t));
  ASSERT_TRUE(tensor({lhs | rhs}).equal(lhs_t | rhs_t));
  ASSERT_TRUE(tensor({lhs & rhs}).equal(lhs_t & rhs_t));
  ASSERT_TRUE(tensor({lhs ^ rhs}).equal(lhs_t ^ rhs_t));
}
125 
// logical_and over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, logical_and_operators) {
  const auto expected = tensor({0, 1, 0, 1, 0});
  run_binary_ops_test(
      logical_and_out, x_logical, y_logical, expected, INTBOOL);
}
131 
// logical_or over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, logical_or_operators) {
  const auto expected = tensor({1, 1, 0, 1, 1});
  run_binary_ops_test(
      logical_or_out, x_logical, y_logical, expected, INTBOOL);
}
137 
// logical_xor over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, logical_xor_operators) {
  const auto expected = tensor({1, 0, 0, 0, 1});
  run_binary_ops_test(
      logical_xor_out, x_logical, y_logical, expected, INTBOOL);
}
143 
// Elementwise x < y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, lt_operators) {
  // Selects the (out, self, other) tensor overload of lt_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({0, 0, 0, 0, 1});
  run_binary_ops_test<BinaryOutFn>(
      lt_out, x_logical, y_logical, expected, INTBOOL);
}
150 
// Elementwise x <= y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, le_operators) {
  // Selects the (out, self, other) tensor overload of le_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({0, 1, 1, 1, 1});
  run_binary_ops_test<BinaryOutFn>(
      le_out, x_logical, y_logical, expected, INTBOOL);
}
157 
// Elementwise x > y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, gt_operators) {
  // Selects the (out, self, other) tensor overload of gt_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({1, 0, 0, 0, 0});
  run_binary_ops_test<BinaryOutFn>(
      gt_out, x_logical, y_logical, expected, INTBOOL);
}
164 
// Elementwise x >= y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, ge_operators) {
  // Selects the (out, self, other) tensor overload of ge_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({1, 1, 1, 1, 0});
  run_binary_ops_test<BinaryOutFn>(
      ge_out, x_logical, y_logical, expected, INTBOOL);
}
171 
// Elementwise x == y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, eq_operators) {
  // Selects the (out, self, other) tensor overload of eq_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({0, 1, 1, 1, 0});
  run_binary_ops_test<BinaryOutFn>(
      eq_out, x_logical, y_logical, expected, INTBOOL);
}
178 
// Elementwise x != y over the 0/1 fixture, checked for kInt and kBool.
TEST_F(atest, ne_operators) {
  // Selects the (out, self, other) tensor overload of ne_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({1, 0, 0, 0, 1});
  run_binary_ops_test<BinaryOutFn>(
      ne_out, x_logical, y_logical, expected, INTBOOL);
}
185 
// add_out with alpha = 2 (expected values are x + 2 * y), checked for kInt
// and kBool.
TEST_F(atest, add_operators) {
  // Selects the (out, self, other, alpha) overload of add_out.
  using AddOutFn = at::Tensor& (*)(at::Tensor&,
                                   const at::Tensor&,
                                   const at::Tensor&,
                                   const at::Scalar&);
  const auto expected = tensor({-10, 1, 0, -1, 10});
  run_binary_ops_test<AddOutFn>(
      add_out, x_tensor, y_tensor, expected, INTBOOL, 2);
}
192 
// Elementwise maximum of x and y, checked for kInt, kFloat and kBool.
TEST_F(atest, max_operators) {
  // Selects the binary (out, self, other) overload of max_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({10, 1, 0, 1, 10});
  run_binary_ops_test<BinaryOutFn>(
      max_out, x_tensor, y_tensor, expected, INTBOOLFLOAT);
}
199 
// Elementwise minimum of x and y, checked for kInt, kFloat and kBool.
TEST_F(atest, min_operators) {
  // Selects the binary (out, self, other) overload of min_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({-10, -1, 0, -1, -10});
  run_binary_ops_test<BinaryOutFn>(
      min_out, x_tensor, y_tensor, expected, INTBOOLFLOAT);
}
206 
// sigmoid_backward_out with grad_output = x and output = y; the expected
// values equal x * (1 - y) * y elementwise.
TEST_F(atest, sigmoid_backward_operator) {
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({-1100, 0, 0, -2, 900});
  // Only exercised with dtype kFloat.
  run_binary_ops_test<BinaryOutFn>(
      sigmoid_backward_out, x_tensor, y_tensor, expected, FLOAT);
}
214 
// Tensor-tensor fmod, checked for kInt and kFloat. For kInt the inputs and the
// expected values are truncated by the cast, giving {2,2,5,7,36} fmod
// {1,1,8,10,24} == {0,0,5,7,12}.
TEST_F(atest, fmod_tensor_operators) {
  // Selects the (out, self, other) tensor overload of fmod_out.
  using BinaryOutFn =
      at::Tensor& (*)(at::Tensor&, const at::Tensor&, const at::Tensor&);
  const auto expected = tensor({0.0, 0.2, 5.6, 7.0, 12.0});
  run_binary_ops_test<BinaryOutFn>(
      fmod_out, x_float, y_float, expected, INTFLOAT);
}
221 
222 // TEST_CASE( "atest", "[]" ) {
TEST_F(atest,atest)223 TEST_F(atest, atest) {
224   manual_seed(123);
225 
226   auto foo = rand({12, 6});
227 
228   ASSERT_EQ(foo.size(0), 12);
229   ASSERT_EQ(foo.size(1), 6);
230 
231   foo = foo + foo * 3;
232   foo -= 4;
233 
234   Scalar a = 4;
235   float b = a.to<float>();
236   ASSERT_EQ(b, 4);
237 
238   foo = ((foo * foo) == (foo.pow(3))).to(kByte);
239   foo = 2 + (foo + 1);
240   // foo = foo[3];
241   auto foo_v = foo.accessor<uint8_t, 2>();
242 
243   for (const auto i : c10::irange(foo_v.size(0))) {
244     for (const auto j : c10::irange(foo_v.size(1))) {
245       foo_v[i][j]++;
246     }
247   }
248 
249   ASSERT_TRUE(foo.equal(4 * ones({12, 6}, kByte)));
250 
251   trace();
252 
253   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
254   float data[] = {1, 2, 3, 4, 5, 6};
255 
256   auto f = from_blob(data, {1, 2, 3});
257   auto f_a = f.accessor<float, 3>();
258 
259   ASSERT_EQ(f_a[0][0][0], 1.0);
260   ASSERT_EQ(f_a[0][1][1], 5.0);
261 
262   ASSERT_EQ(f.strides()[0], 6);
263   ASSERT_EQ(f.strides()[1], 3);
264   ASSERT_EQ(f.strides()[2], 1);
265   ASSERT_EQ(f.sizes()[0], 1);
266   ASSERT_EQ(f.sizes()[1], 2);
267   ASSERT_EQ(f.sizes()[2], 3);
268 
269   // TODO(ezyang): maybe do a more precise exception type.
270   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
271   ASSERT_THROW(f.resize_({3, 4, 5}), std::exception);
272   {
273     int isgone = 0;
274     {
275       auto f2 = from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
276     }
277     ASSERT_EQ(isgone, 1);
278   }
279   {
280     int isgone = 0;
281     Tensor a_view;
282     {
283       auto f2 = from_blob(data, {1, 2, 3}, [&](void*) { isgone++; });
284       a_view = f2.view({3, 2, 1});
285     }
286     ASSERT_EQ(isgone, 0);
287     a_view.reset();
288     ASSERT_EQ(isgone, 1);
289   }
290 
291   if (at::hasCUDA()) {
292     int isgone = 0;
293     {
294       auto base = at::empty({1, 2, 3}, TensorOptions(kCUDA));
295       auto f2 = from_blob(base.mutable_data_ptr(), {1, 2, 3}, [&](void*) { isgone++; });
296     }
297     ASSERT_EQ(isgone, 1);
298 
299     // Attempt to specify the wrong device in from_blob
300     auto t = at::empty({1, 2, 3}, TensorOptions(kCUDA, 0));
301     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
302     EXPECT_ANY_THROW(from_blob(t.data_ptr(), {1, 2, 3}, at::Device(kCUDA, 1)));
303 
304     // Infers the correct device
305     auto t_ = from_blob(t.data_ptr(), {1, 2, 3}, kCUDA);
306     ASSERT_EQ(t_.device(), at::Device(kCUDA, 0));
307   }
308 }
309