#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <c10/util/accumulate.h>

#include <algorithm>
#include <iostream>
#include <numeric>

using namespace at;

#define TRY_CATCH_ELSE(fn, catc, els)                           \
  {                                                             \
    /* avoid mistakenly passing if els code throws exception */ \
    bool _passed = false;                                       \
    try {                                                       \
      fn;                                                       \
      _passed = true;                                           \
      els;                                                      \
    } catch (std::exception & e) {                              \
      ASSERT_FALSE(_passed);                                    \
      catc;                                                     \
    }                                                           \
  }
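
// Illustrative usage of TRY_CATCH_ELSE (a sketch, not one of the tests):
// `fn` runs first; if it throws, the `catc` assertions must hold, otherwise
// the `els` assertions must hold. For a hypothetical tensor `t`:
//
//   TRY_CATCH_ELSE(t.fill_(t.sum(0)),
//                  ASSERT_GT(t.dim(), 1),   // fill_ threw
//                  ASSERT_LE(t.dim(), 1));  // fill_ succeeded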

void require_equal_size_dim(const Tensor &lhs, const Tensor &rhs) {
  ASSERT_EQ(lhs.dim(), rhs.dim());
  ASSERT_TRUE(lhs.sizes().equals(rhs.sizes()));
}

bool should_expand(const IntArrayRef &from_size, const IntArrayRef &to_size) {
  if (from_size.size() > to_size.size()) {
    return false;
  }
  // Walk both shapes from the trailing dimension, pairwise, as broadcasting
  // does: every `from` dimension must either be 1 or match the corresponding
  // `to` dimension.
  auto to_dim_it = to_size.rbegin();
  for (auto from_dim_it = from_size.rbegin(); from_dim_it != from_size.rend();
       ++from_dim_it, ++to_dim_it) {
    if (*from_dim_it != 1 && *from_dim_it != *to_dim_it) {
      return false;
    }
  }
  return true;
}
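
// Illustrative expectations for should_expand (assuming standard broadcasting
// semantics; not part of the tests):
//   should_expand({},     {2})    -> true   (0-dim expands to anything)
//   should_expand({1},    {2})    -> true   (size-1 dims can grow)
//   should_expand({2},    {1, 1}) -> false  (2 cannot expand to 1)
//   should_expand({1, 1}, {2})    -> false  (cannot drop a dimension)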

void test(DeprecatedTypeProperties &T) {
  std::vector<std::vector<int64_t>> sizes = {{}, {0}, {1}, {1, 1}, {2}};

  // single-tensor/size tests
  for (auto s = sizes.begin(); s != sizes.end(); ++s) {
    // verify that the dim, sizes, strides, etc. match what was requested.
    auto t = ones(*s, T);
    ASSERT_EQ((size_t)t.dim(), s->size());
    ASSERT_EQ((size_t)t.ndimension(), s->size());
    ASSERT_TRUE(t.sizes().equals(*s));
    ASSERT_EQ(t.strides().size(), s->size());
    const auto numel = c10::multiply_integers(s->begin(), s->end());
    ASSERT_EQ(t.numel(), numel);
    // verify we can output
    std::stringstream ss;
    // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
    ASSERT_NO_THROW(ss << t << std::endl);

    // set_
    auto t2 = ones(*s, T);
    t2.set_();
    require_equal_size_dim(t2, ones({0}, T));

    // unsqueeze
    ASSERT_EQ(t.unsqueeze(0).dim(), t.dim() + 1);

    // unsqueeze_
    {
      auto t2 = ones(*s, T);
      auto r = t2.unsqueeze_(0);
      ASSERT_EQ(r.dim(), t.dim() + 1);
    }

    // squeeze (with dimension argument)
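    // (illustrative: ones({1, 1}, T).squeeze(0) has shape {1}, while
    //  ones({2}, T).squeeze(0) is left unchanged with shape {2})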
    if (t.dim() == 0 || t.sizes()[0] == 1) {
      ASSERT_EQ(t.squeeze(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
    } else {
      // In PyTorch, it is a no-op to try to squeeze a dimension that has size
      // != 1; in NumPy this is an error.
      ASSERT_EQ(t.squeeze(0).dim(), t.dim());
    }

    // squeeze (with no dimension argument)
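    // (illustrative: squeeze() drops every size-1 dimension, so ones({1, 1}, T)
    //  squeezes down to a 0-dim scalar, and ones({2}, T) is unchanged)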
    {
      std::vector<int64_t> size_without_ones;
      for (auto size : *s) {
        if (size != 1) {
          size_without_ones.push_back(size);
        }
      }
      auto result = t.squeeze();
      require_equal_size_dim(result, ones(size_without_ones, T));
    }

    {
      // squeeze_ (with dimension argument)
      auto t2 = ones(*s, T);
      if (t2.dim() == 0 || t2.sizes()[0] == 1) {
        ASSERT_EQ(t2.squeeze_(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
      } else {
        // In PyTorch, it is a no-op to try to squeeze a dimension that has size
        // != 1; in NumPy this is an error.
        ASSERT_EQ(t2.squeeze_(0).dim(), t.dim());
      }
    }

    // squeeze_ (with no dimension argument)
    {
      auto t2 = ones(*s, T);
      std::vector<int64_t> size_without_ones;
      for (auto size : *s) {
        if (size != 1) {
          size_without_ones.push_back(size);
        }
      }
      t2.squeeze_();
      require_equal_size_dim(t2, ones(size_without_ones, T));
    }

    // reduce (with dimension argument and with 1 return argument)
    if (t.numel() != 0) {
      ASSERT_EQ(t.sum(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
    } else {
      ASSERT_TRUE(t.sum(0).equal(at::zeros({}, T)));
    }

    // reduce (with dimension argument and with 2 return arguments)
    if (t.numel() != 0) {
      auto ret = t.min(0);
      ASSERT_EQ(std::get<0>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
      ASSERT_EQ(std::get<1>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
    } else {
      // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
      ASSERT_ANY_THROW(t.min(0));
    }

    // simple indexing
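    // (illustrative: for t = ones({2}, T), t[0] is a 0-dim scalar; for a
    //  0-dim or empty t, t[0] throws)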
    if (t.dim() > 0 && t.numel() != 0) {
      ASSERT_EQ(t[0].dim(), std::max<int64_t>(t.dim() - 1, 0));
    } else {
      // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
      ASSERT_ANY_THROW(t[0]);
    }

    // fill_ (argument to fill_ can only be a 0-dim tensor)
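    // (illustrative: t.sum(0) is 0-dim exactly when t.dim() <= 1, so fill_
    //  succeeds for the scalar and 1-D sizes here and throws for {1, 1})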
    TRY_CATCH_ELSE(
        t.fill_(t.sum(0)), ASSERT_GT(t.dim(), 1), ASSERT_LE(t.dim(), 1));
  }

  for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {
    // NOLINTNEXTLINE(modernize-loop-convert)
    for (auto rhs_it = sizes.begin(); rhs_it != sizes.end(); ++rhs_it) {
      // is_same_size should only match if they are the same shape
      {
        auto lhs = ones(*lhs_it, T);
        auto rhs = ones(*rhs_it, T);
        if (*lhs_it != *rhs_it) {
          ASSERT_FALSE(lhs.is_same_size(rhs));
          ASSERT_FALSE(rhs.is_same_size(lhs));
        }
      }
      // forced size functions (resize_, resize_as_, set_)
      // resize_
      {
        {
          auto lhs = ones(*lhs_it, T);
          auto rhs = ones(*rhs_it, T);
          lhs.resize_(*rhs_it);
          require_equal_size_dim(lhs, rhs);
        }
        // resize_as_
        {
          auto lhs = ones(*lhs_it, T);
          auto rhs = ones(*rhs_it, T);
          lhs.resize_as_(rhs);
          require_equal_size_dim(lhs, rhs);
        }
        // set_
        {
          {
            // with tensor
            auto lhs = ones(*lhs_it, T);
            auto rhs = ones(*rhs_it, T);
            lhs.set_(rhs);
            require_equal_size_dim(lhs, rhs);
          }
          {
            // with storage
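            // (illustrative: set_ with a raw storage rebinds the tensor as a
            //  1-D view over the whole storage, whatever its previous shape)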
            auto lhs = ones(*lhs_it, T);
            auto rhs = ones(*rhs_it, T);
            lhs.set_(rhs.storage());
            // should not be dim 0: an empty storage yields a 1-D tensor of
            // size {0}, and no other storage produces a scalar
            ASSERT_NE(lhs.dim(), 0);
          }
          {
            // with storage, offset, sizes, strides
            auto lhs = ones(*lhs_it, T);
            auto rhs = ones(*rhs_it, T);
            lhs.set_(rhs.storage(), rhs.storage_offset(), rhs.sizes(), rhs.strides());
            require_equal_size_dim(lhs, rhs);
          }
        }
      }

      // view
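      // (illustrative: view succeeds only when element counts match, so
      //  ones({1, 1}, T).view({}) works while ones({2}, T).view({}) throws)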
      {
        auto lhs = ones(*lhs_it, T);
        auto rhs = ones(*rhs_it, T);
        auto rhs_size = *rhs_it;
        TRY_CATCH_ELSE(auto result = lhs.view(rhs_size),
                       ASSERT_NE(lhs.numel(), rhs.numel()),
                       ASSERT_EQ(lhs.numel(), rhs.numel());
                       require_equal_size_dim(result, rhs););
      }

      // take
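      // (illustrative: take treats lhs as flattened and returns a tensor with
      //  the shape of the index tensor, which is what the assertion below
      //  checks)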
      {
        auto lhs = ones(*lhs_it, T);
        auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long);
        TRY_CATCH_ELSE(auto result = lhs.take(rhs),
                       ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs.numel(), 0),
                       require_equal_size_dim(result, rhs));
      }

      // put
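      // (illustrative: put writes rhs2's values into a copy of lhs at the
      //  flattened indices in rhs1, so the result keeps lhs's shape)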
      {
        auto lhs = ones(*lhs_it, T);
        auto rhs1 = zeros(*rhs_it, T).toType(ScalarType::Long);
        auto rhs2 = zeros(*rhs_it, T);
        TRY_CATCH_ELSE(auto result = lhs.put(rhs1, rhs2),
                       ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs1.numel(), 0),
                       require_equal_size_dim(result, lhs));
      }

      // ger
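      // (illustrative: ger is the outer product, so ones({2}, T) with itself
      //  yields shape {2, 2}; the check below treats 0-dim inputs as size 1)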
      {
        auto lhs = ones(*lhs_it, T);
        auto rhs = ones(*rhs_it, T);
        TRY_CATCH_ELSE(auto result = lhs.ger(rhs),
                       ASSERT_TRUE(
                           (lhs.numel() == 0 || rhs.numel() == 0 ||
                            lhs.dim() != 1 || rhs.dim() != 1)),
                       [&]() {
                         int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0);
                         int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0);
                         require_equal_size_dim(
                             result, at::empty({dim0, dim1}, result.options()));
                       }(););
      }

      // expand
      {
        auto lhs = ones(*lhs_it, T);
        auto lhs_size = *lhs_it;
        auto rhs = ones(*rhs_it, T);
        auto rhs_size = *rhs_it;
        bool should_pass = should_expand(lhs_size, rhs_size);
        TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size),
                       ASSERT_FALSE(should_pass),
                       ASSERT_TRUE(should_pass);
                       require_equal_size_dim(result, rhs););

        // in-place functions (it would be good to also test a non-broadcasting
        // one, because broadcasting functions always end up operating on
        // tensors of the same size; is there an example of this outside of
        // assign_?)
        {
          bool should_pass_inplace = should_expand(rhs_size, lhs_size);
          TRY_CATCH_ELSE(lhs.add_(rhs),
                         ASSERT_FALSE(should_pass_inplace),
                         ASSERT_TRUE(should_pass_inplace);
                         require_equal_size_dim(lhs, ones(*lhs_it, T)););
        }
      }
    }
  }
}

TEST(TestScalarTensor, TestScalarTensorCPU) {
  manual_seed(123);
  test(CPU(kFloat));
}

TEST(TestScalarTensor, TestScalarTensorCUDA) {
  manual_seed(123);

  if (at::hasCUDA()) {
    test(CUDA(kFloat));
  }
}

TEST(TestScalarTensor, TestScalarTensorMPS) {
  manual_seed(123);

  if (at::hasMPS()) {
    test(MPS(kFloat));
  }
}
306