1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <c10/util/accumulate.h>
6
7 #include <algorithm>
8 #include <iostream>
9 #include <numeric>
10
11 using namespace at;
12
// Runs the statement `fn`. If it throws a std::exception, runs `catc`;
// otherwise runs `els` (in the same scope as `fn`, so `els` may refer to a
// variable declared by `fn`, e.g. `auto result = ...`).
//
// If `els` itself throws after `fn` succeeded, the catch handler would run;
// `_passed` guards against that so the test fails instead of mistakenly
// running the `catc` expectations.
//
// Wrapped in do { } while (0) so an invocation followed by `;` is a single
// statement and is safe inside an unbraced if/else.
#define TRY_CATCH_ELSE(fn, catc, els)                          \
  do {                                                         \
    /* avoid mistakenly passing if els code throws exception*/ \
    bool _passed = false;                                      \
    try {                                                      \
      fn;                                                      \
      _passed = true;                                          \
      els;                                                     \
    } catch (const std::exception&) {                          \
      ASSERT_FALSE(_passed);                                   \
      catc;                                                    \
    }                                                          \
  } while (0)
26
require_equal_size_dim(const Tensor & lhs,const Tensor & rhs)27 void require_equal_size_dim(const Tensor &lhs, const Tensor &rhs) {
28 ASSERT_EQ(lhs.dim(), rhs.dim());
29 ASSERT_TRUE(lhs.sizes().equals(rhs.sizes()));
30 }
31
should_expand(const IntArrayRef & from_size,const IntArrayRef & to_size)32 bool should_expand(const IntArrayRef &from_size, const IntArrayRef &to_size) {
33 if (from_size.size() > to_size.size()) {
34 return false;
35 }
36 for (auto from_dim_it = from_size.rbegin(); from_dim_it != from_size.rend();
37 ++from_dim_it) {
38 for (auto to_dim_it = to_size.rbegin(); to_dim_it != to_size.rend();
39 ++to_dim_it) {
40 if (*from_dim_it != 1 && *from_dim_it != *to_dim_it) {
41 return false;
42 }
43 }
44 }
45 return true;
46 }
47
test(DeprecatedTypeProperties & T)48 void test(DeprecatedTypeProperties &T) {
49 std::vector<std::vector<int64_t>> sizes = {{}, {0}, {1}, {1, 1}, {2}};
50
51 // single-tensor/size tests
52 for (auto s = sizes.begin(); s != sizes.end(); ++s) {
53 // verify that the dim, sizes, strides, etc match what was requested.
54 auto t = ones(*s, T);
55 ASSERT_EQ((size_t)t.dim(), s->size());
56 ASSERT_EQ((size_t)t.ndimension(), s->size());
57 ASSERT_TRUE(t.sizes().equals(*s));
58 ASSERT_EQ(t.strides().size(), s->size());
59 const auto numel = c10::multiply_integers(s->begin(), s->end());
60 ASSERT_EQ(t.numel(), numel);
61 // verify we can output
62 std::stringstream ss;
63 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
64 ASSERT_NO_THROW(ss << t << std::endl);
65
66 // set_
67 auto t2 = ones(*s, T);
68 t2.set_();
69 require_equal_size_dim(t2, ones({0}, T));
70
71 // unsqueeze
72 ASSERT_EQ(t.unsqueeze(0).dim(), t.dim() + 1);
73
74 // unsqueeze_
75 {
76 auto t2 = ones(*s, T);
77 auto r = t2.unsqueeze_(0);
78 ASSERT_EQ(r.dim(), t.dim() + 1);
79 }
80
81 // squeeze (with dimension argument)
82 if (t.dim() == 0 || t.sizes()[0] == 1) {
83 ASSERT_EQ(t.squeeze(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
84 } else {
85 // In PyTorch, it is a no-op to try to squeeze a dimension that has size
86 // != 1; in NumPy this is an error.
87 ASSERT_EQ(t.squeeze(0).dim(), t.dim());
88 }
89
90 // squeeze (with no dimension argument)
91 {
92 std::vector<int64_t> size_without_ones;
93 for (auto size : *s) {
94 if (size != 1) {
95 size_without_ones.push_back(size);
96 }
97 }
98 auto result = t.squeeze();
99 require_equal_size_dim(result, ones(size_without_ones, T));
100 }
101
102 {
103 // squeeze_ (with dimension argument)
104 auto t2 = ones(*s, T);
105 if (t2.dim() == 0 || t2.sizes()[0] == 1) {
106 ASSERT_EQ(t2.squeeze_(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
107 } else {
108 // In PyTorch, it is a no-op to try to squeeze a dimension that has size
109 // != 1; in NumPy this is an error.
110 ASSERT_EQ(t2.squeeze_(0).dim(), t.dim());
111 }
112 }
113
114 // squeeze_ (with no dimension argument)
115 {
116 auto t2 = ones(*s, T);
117 std::vector<int64_t> size_without_ones;
118 for (auto size : *s) {
119 if (size != 1) {
120 size_without_ones.push_back(size);
121 }
122 }
123 auto r = t2.squeeze_();
124 require_equal_size_dim(t2, ones(size_without_ones, T));
125 }
126
127 // reduce (with dimension argument and with 1 return argument)
128 if (t.numel() != 0) {
129 ASSERT_EQ(t.sum(0).dim(), std::max<int64_t>(t.dim() - 1, 0));
130 } else {
131 ASSERT_TRUE(t.sum(0).equal(at::zeros({}, T)));
132 }
133
134 // reduce (with dimension argument and with 2 return arguments)
135 if (t.numel() != 0) {
136 auto ret = t.min(0);
137 ASSERT_EQ(std::get<0>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
138 ASSERT_EQ(std::get<1>(ret).dim(), std::max<int64_t>(t.dim() - 1, 0));
139 } else {
140 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
141 ASSERT_ANY_THROW(t.min(0));
142 }
143
144 // simple indexing
145 if (t.dim() > 0 && t.numel() != 0) {
146 ASSERT_EQ(t[0].dim(), std::max<int64_t>(t.dim() - 1, 0));
147 } else {
148 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
149 ASSERT_ANY_THROW(t[0]);
150 }
151
152 // fill_ (argument to fill_ can only be a 0-dim tensor)
153 TRY_CATCH_ELSE(
154 t.fill_(t.sum(0)), ASSERT_GT(t.dim(), 1), ASSERT_LE(t.dim(), 1));
155 }
156
157 for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {
158 // NOLINTNEXTLINE(modernize-loop-convert)
159 for (auto rhs_it = sizes.begin(); rhs_it != sizes.end(); ++rhs_it) {
160 // is_same_size should only match if they are the same shape
161 {
162 auto lhs = ones(*lhs_it, T);
163 auto rhs = ones(*rhs_it, T);
164 if (*lhs_it != *rhs_it) {
165 ASSERT_FALSE(lhs.is_same_size(rhs));
166 ASSERT_FALSE(rhs.is_same_size(lhs));
167 }
168 }
169 // forced size functions (resize_, resize_as, set_)
170 // resize_
171 {
172 {
173 auto lhs = ones(*lhs_it, T);
174 auto rhs = ones(*rhs_it, T);
175 lhs.resize_(*rhs_it);
176 require_equal_size_dim(lhs, rhs);
177 }
178 // resize_as_
179 {
180 auto lhs = ones(*lhs_it, T);
181 auto rhs = ones(*rhs_it, T);
182 lhs.resize_as_(rhs);
183 require_equal_size_dim(lhs, rhs);
184 }
185 // set_
186 {
187 {
188 // with tensor
189 auto lhs = ones(*lhs_it, T);
190 auto rhs = ones(*rhs_it, T);
191 lhs.set_(rhs);
192 require_equal_size_dim(lhs, rhs);
193 }
194 {
195 // with storage
196 auto lhs = ones(*lhs_it, T);
197 auto rhs = ones(*rhs_it, T);
198 lhs.set_(rhs.storage());
199 // should not be dim 0 because an empty storage is dim 1; all other
200 // storages aren't scalars
201 ASSERT_NE(lhs.dim(), 0);
202 }
203 {
204 // with storage, offset, sizes, strides
205 auto lhs = ones(*lhs_it, T);
206 auto rhs = ones(*rhs_it, T);
207 lhs.set_(rhs.storage(), rhs.storage_offset(), rhs.sizes(), rhs.strides());
208 require_equal_size_dim(lhs, rhs);
209 }
210 }
211 }
212
213 // view
214 {
215 auto lhs = ones(*lhs_it, T);
216 auto rhs = ones(*rhs_it, T);
217 auto rhs_size = *rhs_it;
218 TRY_CATCH_ELSE(auto result = lhs.view(rhs_size),
219 ASSERT_NE(lhs.numel(), rhs.numel()),
220 ASSERT_EQ(lhs.numel(), rhs.numel());
221 require_equal_size_dim(result, rhs););
222 }
223
224 // take
225 {
226 auto lhs = ones(*lhs_it, T);
227 auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long);
228 TRY_CATCH_ELSE(auto result = lhs.take(rhs),
229 ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs.numel(), 0),
230 require_equal_size_dim(result, rhs));
231 }
232
233 // put
234 {
235 auto lhs = ones(*lhs_it, T);
236 auto rhs1 = zeros(*rhs_it, T).toType(ScalarType::Long);
237 auto rhs2 = zeros(*rhs_it, T);
238 TRY_CATCH_ELSE(auto result = lhs.put(rhs1, rhs2),
239 ASSERT_EQ(lhs.numel(), 0); ASSERT_NE(rhs1.numel(), 0),
240 require_equal_size_dim(result, lhs));
241 }
242
243 // ger
244 {
245 auto lhs = ones(*lhs_it, T);
246 auto rhs = ones(*rhs_it, T);
247 TRY_CATCH_ELSE(auto result = lhs.ger(rhs),
248 ASSERT_TRUE(
249 (lhs.numel() == 0 || rhs.numel() == 0 ||
250 lhs.dim() != 1 || rhs.dim() != 1)),
251 [&]() {
252 int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0);
253 int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0);
254 require_equal_size_dim(
255 result, at::empty({dim0, dim1}, result.options()));
256 }(););
257 }
258
259 // expand
260 {
261 auto lhs = ones(*lhs_it, T);
262 auto lhs_size = *lhs_it;
263 auto rhs = ones(*rhs_it, T);
264 auto rhs_size = *rhs_it;
265 bool should_pass = should_expand(lhs_size, rhs_size);
266 TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size),
267 ASSERT_FALSE(should_pass),
268 ASSERT_TRUE(should_pass);
269 require_equal_size_dim(result, rhs););
270
271 // in-place functions (would be good if we can also do a non-broadcasting
272 // one, b/c broadcasting functions will always end up operating on tensors
273 // of same size; is there an example of this outside of assign_ ?)
274 {
275 bool should_pass_inplace = should_expand(rhs_size, lhs_size);
276 TRY_CATCH_ELSE(lhs.add_(rhs),
277 ASSERT_FALSE(should_pass_inplace),
278 ASSERT_TRUE(should_pass_inplace);
279 require_equal_size_dim(lhs, ones(*lhs_it, T)););
280 }
281 }
282 }
283 }
284 }
285
// Runs the shape/scalar-tensor checks on CPU float tensors.
TEST(TestScalarTensor, TestScalarTensorCPU) {
  manual_seed(123);
  DeprecatedTypeProperties &cpu_float = CPU(kFloat);
  test(cpu_float);
}
290
// Runs the shape/scalar-tensor checks on CUDA float tensors. When no CUDA
// device is available the test is reported as skipped rather than silently
// passing without having run anything.
TEST(TestScalarTensor, TestScalarTensorCUDA) {
  manual_seed(123);

  if (!at::hasCUDA()) {
    GTEST_SKIP() << "CUDA is not available";
  }
  test(CUDA(kFloat));
}
298
// Runs the shape/scalar-tensor checks on MPS float tensors. When MPS is not
// available the test is reported as skipped rather than silently passing
// without having run anything.
TEST(TestScalarTensor, TestScalarTensorMPS) {
  manual_seed(123);

  if (!at::hasMPS()) {
    GTEST_SKIP() << "MPS is not available";
  }
  test(MPS(kFloat));
}
306