xref: /aosp_15_r20/external/eigen/unsupported/test/cxx11_tensor_broadcasting.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template <int DataLayout>
test_simple_broadcasting()17 static void test_simple_broadcasting()
18 {
19   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20   tensor.setRandom();
21   array<ptrdiff_t, 4> broadcasts;
22   broadcasts[0] = 1;
23   broadcasts[1] = 1;
24   broadcasts[2] = 1;
25   broadcasts[3] = 1;
26 
27   Tensor<float, 4, DataLayout> no_broadcast;
28   no_broadcast = tensor.broadcast(broadcasts);
29 
30   VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
31   VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
32   VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
33   VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);
34 
35   for (int i = 0; i < 2; ++i) {
36     for (int j = 0; j < 3; ++j) {
37       for (int k = 0; k < 5; ++k) {
38         for (int l = 0; l < 7; ++l) {
39           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
40         }
41       }
42     }
43   }
44 
45   broadcasts[0] = 2;
46   broadcasts[1] = 3;
47   broadcasts[2] = 1;
48   broadcasts[3] = 4;
49   Tensor<float, 4, DataLayout> broadcast;
50   broadcast = tensor.broadcast(broadcasts);
51 
52   VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
53   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
54   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
55   VERIFY_IS_EQUAL(broadcast.dimension(3), 28);
56 
57   for (int i = 0; i < 4; ++i) {
58     for (int j = 0; j < 9; ++j) {
59       for (int k = 0; k < 5; ++k) {
60         for (int l = 0; l < 28; ++l) {
61           VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
62         }
63       }
64     }
65   }
66 }
67 
68 
69 template <int DataLayout>
test_vectorized_broadcasting()70 static void test_vectorized_broadcasting()
71 {
72   Tensor<float, 3, DataLayout> tensor(8,3,5);
73   tensor.setRandom();
74   array<ptrdiff_t, 3> broadcasts;
75   broadcasts[0] = 2;
76   broadcasts[1] = 3;
77   broadcasts[2] = 4;
78 
79   Tensor<float, 3, DataLayout> broadcast;
80   broadcast = tensor.broadcast(broadcasts);
81 
82   VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
83   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
84   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
85 
86   for (int i = 0; i < 16; ++i) {
87     for (int j = 0; j < 9; ++j) {
88       for (int k = 0; k < 20; ++k) {
89         VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
90       }
91     }
92   }
93 
94 #if EIGEN_HAS_VARIADIC_TEMPLATES
95   tensor.resize(11,3,5);
96 #else
97   array<Index, 3> new_dims;
98   new_dims[0] = 11;
99   new_dims[1] = 3;
100   new_dims[2] = 5;
101   tensor.resize(new_dims);
102 #endif
103 
104   tensor.setRandom();
105   broadcast = tensor.broadcast(broadcasts);
106 
107   VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
108   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
109   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
110 
111   for (int i = 0; i < 22; ++i) {
112     for (int j = 0; j < 9; ++j) {
113       for (int k = 0; k < 20; ++k) {
114         VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
115       }
116     }
117   }
118 }
119 
120 
121 template <int DataLayout>
test_static_broadcasting()122 static void test_static_broadcasting()
123 {
124   Tensor<float, 3, DataLayout> tensor(8,3,5);
125   tensor.setRandom();
126 
127 #if defined(EIGEN_HAS_INDEX_LIST)
128   Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
129 #else
130   Eigen::array<int, 3> broadcasts;
131   broadcasts[0] = 2;
132   broadcasts[1] = 3;
133   broadcasts[2] = 4;
134 #endif
135 
136   Tensor<float, 3, DataLayout> broadcast;
137   broadcast = tensor.broadcast(broadcasts);
138 
139   VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
140   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
141   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
142 
143   for (int i = 0; i < 16; ++i) {
144     for (int j = 0; j < 9; ++j) {
145       for (int k = 0; k < 20; ++k) {
146         VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
147       }
148     }
149   }
150 
151 #if EIGEN_HAS_VARIADIC_TEMPLATES
152   tensor.resize(11,3,5);
153 #else
154   array<Index, 3> new_dims;
155   new_dims[0] = 11;
156   new_dims[1] = 3;
157   new_dims[2] = 5;
158   tensor.resize(new_dims);
159 #endif
160 
161   tensor.setRandom();
162   broadcast = tensor.broadcast(broadcasts);
163 
164   VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
165   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
166   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
167 
168   for (int i = 0; i < 22; ++i) {
169     for (int j = 0; j < 9; ++j) {
170       for (int k = 0; k < 20; ++k) {
171         VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
172       }
173     }
174   }
175 }
176 
177 
template <int DataLayout>
static void test_fixed_size_broadcasting()
{
  // Intended to check broadcasting of a one-element TensorFixedSize, both
  // directly and through a TensorMap, added to a dynamically sized rank-1
  // tensor.
  //
  // The entire body below is compiled out by `#if 0`, so this subtest
  // currently performs no checks and always passes.
  //
  // Need to add a [] operator to the Size class for this to work
  // NOTE(review): presumably this refers to Eigen::Sizes<...>; confirm before
  // re-enabling.
#if 0
  Tensor<float, 1, DataLayout> t1(10);
  t1.setRandom();
  // One-element fixed-size tensor filled with 20.0f.
  TensorFixedSize<float, Sizes<1>, DataLayout> t2;
  t2 = t2.constant(20.0f);

  // Broadcasting t2 by 10 should add t2(0) to every element of t1.
  Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
  }

  // Same check, but broadcasting through a TensorMap over t2's storage.
  TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
  Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
  for (int i = 0; i < 10; ++i) {
    VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
  }
#endif
}
200 
201 template <int DataLayout>
test_simple_broadcasting_one_by_n()202 static void test_simple_broadcasting_one_by_n()
203 {
204   Tensor<float, 4, DataLayout> tensor(1,13,5,7);
205   tensor.setRandom();
206   array<ptrdiff_t, 4> broadcasts;
207   broadcasts[0] = 9;
208   broadcasts[1] = 1;
209   broadcasts[2] = 1;
210   broadcasts[3] = 1;
211   Tensor<float, 4, DataLayout> broadcast;
212   broadcast = tensor.broadcast(broadcasts);
213 
214   VERIFY_IS_EQUAL(broadcast.dimension(0), 9);
215   VERIFY_IS_EQUAL(broadcast.dimension(1), 13);
216   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
217   VERIFY_IS_EQUAL(broadcast.dimension(3), 7);
218 
219   for (int i = 0; i < 9; ++i) {
220     for (int j = 0; j < 13; ++j) {
221       for (int k = 0; k < 5; ++k) {
222         for (int l = 0; l < 7; ++l) {
223           VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l));
224         }
225       }
226     }
227   }
228 }
229 
230 template <int DataLayout>
test_simple_broadcasting_n_by_one()231 static void test_simple_broadcasting_n_by_one()
232 {
233   Tensor<float, 4, DataLayout> tensor(7,3,5,1);
234   tensor.setRandom();
235   array<ptrdiff_t, 4> broadcasts;
236   broadcasts[0] = 1;
237   broadcasts[1] = 1;
238   broadcasts[2] = 1;
239   broadcasts[3] = 19;
240   Tensor<float, 4, DataLayout> broadcast;
241   broadcast = tensor.broadcast(broadcasts);
242 
243   VERIFY_IS_EQUAL(broadcast.dimension(0), 7);
244   VERIFY_IS_EQUAL(broadcast.dimension(1), 3);
245   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
246   VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
247 
248   for (int i = 0; i < 7; ++i) {
249     for (int j = 0; j < 3; ++j) {
250       for (int k = 0; k < 5; ++k) {
251         for (int l = 0; l < 19; ++l) {
252           VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l));
253         }
254       }
255     }
256   }
257 }
258 
259 template <int DataLayout>
test_simple_broadcasting_one_by_n_by_one_1d()260 static void test_simple_broadcasting_one_by_n_by_one_1d()
261 {
262   Tensor<float, 3, DataLayout> tensor(1,7,1);
263   tensor.setRandom();
264   array<ptrdiff_t, 3> broadcasts;
265   broadcasts[0] = 5;
266   broadcasts[1] = 1;
267   broadcasts[2] = 13;
268   Tensor<float, 3, DataLayout> broadcasted;
269   broadcasted = tensor.broadcast(broadcasts);
270 
271   VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
272   VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
273   VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);
274 
275   for (int i = 0; i < 5; ++i) {
276     for (int j = 0; j < 7; ++j) {
277       for (int k = 0; k < 13; ++k) {
278         VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
279       }
280     }
281   }
282 }
283 
284 template <int DataLayout>
test_simple_broadcasting_one_by_n_by_one_2d()285 static void test_simple_broadcasting_one_by_n_by_one_2d()
286 {
287   Tensor<float, 4, DataLayout> tensor(1,7,13,1);
288   tensor.setRandom();
289   array<ptrdiff_t, 4> broadcasts;
290   broadcasts[0] = 5;
291   broadcasts[1] = 1;
292   broadcasts[2] = 1;
293   broadcasts[3] = 19;
294   Tensor<float, 4, DataLayout> broadcast;
295   broadcast = tensor.broadcast(broadcasts);
296 
297   VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
298   VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
299   VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
300   VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
301 
302   for (int i = 0; i < 5; ++i) {
303     for (int j = 0; j < 7; ++j) {
304       for (int k = 0; k < 13; ++k) {
305         for (int l = 0; l < 19; ++l) {
306           VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
307         }
308       }
309     }
310   }
311 }
312 
// Entry point for the tensor broadcasting tests. Each scenario is run for
// both ColMajor and RowMajor layouts. The subtests fill their inputs with
// setRandom(), so their relative order is kept as-is.
EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting)
{
  // Generic rank-4 broadcast, including the all-ones (identity) case.
  CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
  // Rank-3 broadcast with leading extents 8 and 11 and dynamic factors.
  CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
  CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
  // Same as above, with factors given as an IndexList when available.
  CALL_SUBTEST(test_static_broadcasting<ColMajor>());
  CALL_SUBTEST(test_static_broadcasting<RowMajor>());
  // Currently compiled out (#if 0 body); runs no checks.
  CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
  CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
  // Broadcasting a single size-1 axis at the front or back.
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
  // Broadcasting size-1 axes at both the front and the back.
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
  CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
}
332