// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

#include <Eigen/CXX11/Tensor>

using Eigen::Tensor;

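// Broadcast a 2x3x5x7 tensor, first with all factors set to 1 (a no-op that
// must leave dimensions and coefficients unchanged), then by (2,3,1,4), and
// check that each output coefficient maps back to the input via modulo indexing.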
template <int DataLayout>
static void test_simple_broadcasting()
{
  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 1;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 1;

  Tensor<float, 4, DataLayout> no_broadcast;
  no_broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
  VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
  VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);

  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
        }
      }
    }
  }

  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 1;
  broadcasts[3] = 4;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 28);

  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 28; ++l) {
          VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
        }
      }
    }
  }
}

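// Broadcast a 3D tensor whose inner dimension (8) allows the vectorized
// evaluation path, then resize it to 11x3x5 so the same broadcast also
// covers the non-packet-aligned case.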
template <int DataLayout>
static void test_vectorized_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;

  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

#if EIGEN_HAS_VARIADIC_TEMPLATES
  tensor.resize(11,3,5);
#else
  array<Index, 3> new_dims;
  new_dims[0] = 11;
  new_dims[1] = 3;
  new_dims[2] = 5;
  tensor.resize(new_dims);
#endif

  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}

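// Same checks as the vectorized test, but the broadcast factors are supplied
// as compile-time constants through Eigen::IndexList when
// EIGEN_HAS_INDEX_LIST is available.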
template <int DataLayout>
static void test_static_broadcasting()
{
  Tensor<float, 3, DataLayout> tensor(8,3,5);
  tensor.setRandom();

#if defined(EIGEN_HAS_INDEX_LIST)
  Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
#else
  Eigen::array<int, 3> broadcasts;
  broadcasts[0] = 2;
  broadcasts[1] = 3;
  broadcasts[2] = 4;
#endif

  Tensor<float, 3, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 16; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
      }
    }
  }

#if EIGEN_HAS_VARIADIC_TEMPLATES
  tensor.resize(11,3,5);
#else
  array<Index, 3> new_dims;
  new_dims[0] = 11;
  new_dims[1] = 3;
  new_dims[2] = 5;
  tensor.resize(new_dims);
#endif

  tensor.setRandom();
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 20);

  for (int i = 0; i < 22; ++i) {
    for (int j = 0; j < 9; ++j) {
      for (int k = 0; k < 20; ++k) {
        VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
      }
    }
  }
}

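// Broadcasting of fixed-size tensors; the body is disabled (#if 0) until the
// Size class gains an operator[].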
template <int DataLayout>
static void test_fixed_size_broadcasting()
{
  // Need to add a [] operator to the Size class for this to work
#if 0
  Tensor<float, 1, DataLayout> t1(10);
  t1.setRandom();
  TensorFixedSize<float, Sizes<1>, DataLayout> t2;
  t2 = t2.constant(20.0f);

  Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
  }

  TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
  Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
  }
#endif
}

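// Broadcast a tensor whose first dimension is 1 along that dimension only.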
template <int DataLayout>
static void test_simple_broadcasting_one_by_n()
{
  Tensor<float, 4, DataLayout> tensor(1,13,5,7);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 9;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 1;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 9);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 13);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 7);

  for (int i = 0; i < 9; ++i) {
    for (int j = 0; j < 13; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 7; ++l) {
          VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l));
        }
      }
    }
  }
}

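// Broadcast a tensor whose last dimension is 1 along that dimension only.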
template <int DataLayout>
static void test_simple_broadcasting_n_by_one()
{
  Tensor<float, 4, DataLayout> tensor(7,3,5,1);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 1;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 19;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 7);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 3);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);

  for (int i = 0; i < 7; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 5; ++k) {
        for (int l = 0; l < 19; ++l) {
          VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l));
        }
      }
    }
  }
}

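// Broadcast a 1x7x1 tensor along both size-1 dimensions at once.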
template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_1d()
{
  Tensor<float, 3, DataLayout> tensor(1,7,1);
  tensor.setRandom();
  array<ptrdiff_t, 3> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 13;
  Tensor<float, 3, DataLayout> broadcasted;
  broadcasted = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
      }
    }
  }
}

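// Broadcast a 1x7x13x1 tensor along both outer size-1 dimensions at once.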
template <int DataLayout>
static void test_simple_broadcasting_one_by_n_by_one_2d()
{
  Tensor<float, 4, DataLayout> tensor(1,7,13,1);
  tensor.setRandom();
  array<ptrdiff_t, 4> broadcasts;
  broadcasts[0] = 5;
  broadcasts[1] = 1;
  broadcasts[2] = 1;
  broadcasts[3] = 19;
  Tensor<float, 4, DataLayout> broadcast;
  broadcast = tensor.broadcast(broadcasts);

  VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
  VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
  VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
  VERIFY_IS_EQUAL(broadcast.dimension(3), 19);

  for (int i = 0; i < 5; ++i) {
    for (int j = 0; j < 7; ++j) {
      for (int k = 0; k < 13; ++k) {
        for (int l = 0; l < 19; ++l) {
          VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
        }
      }
    }
  }
}

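// Run every broadcasting test for both column-major and row-major layouts.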
EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting)
{
  CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
  CALL_SUBTEST(test_static_broadcasting<ColMajor>());
  CALL_SUBTEST(test_static_broadcasting<RowMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
}