1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15
16 template<int DataLayout>
test_dimension_failures()17 static void test_dimension_failures()
18 {
19 Tensor<int, 3, DataLayout> left(2, 3, 1);
20 Tensor<int, 3, DataLayout> right(3, 3, 1);
21 left.setRandom();
22 right.setRandom();
23
24 // Okay; other dimensions are equal.
25 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26
27 // Dimension mismatches.
28 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30
31 // Axis > NumDims or < 0.
32 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34 }
35
36 template<int DataLayout>
test_static_dimension_failure()37 static void test_static_dimension_failure()
38 {
39 Tensor<int, 2, DataLayout> left(2, 3);
40 Tensor<int, 3, DataLayout> right(2, 3, 1);
41
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43 // Technically compatible, but we static assert that the inputs have same
44 // NumDims.
45 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47
48 // This can be worked around in this case.
49 Tensor<int, 3, DataLayout> concatenation = left
50 .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51 .concatenate(right, 0);
52 Tensor<int, 2, DataLayout> alternative = left
53 // Clang compiler break with {{{}}} with an ambiguous error on copy constructor
54 // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
55 // Solution:
56 // either the code should change to
57 // Tensor<int, 2>::Dimensions{{2, 3}}
58 // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
59 .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
60 }
61
62 template<int DataLayout>
test_simple_concatenation()63 static void test_simple_concatenation()
64 {
65 Tensor<int, 3, DataLayout> left(2, 3, 1);
66 Tensor<int, 3, DataLayout> right(2, 3, 1);
67 left.setRandom();
68 right.setRandom();
69
70 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
71 VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
72 VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
73 VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
74 for (int j = 0; j < 3; ++j) {
75 for (int i = 0; i < 2; ++i) {
76 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
77 }
78 for (int i = 2; i < 4; ++i) {
79 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
80 }
81 }
82
83 concatenation = left.concatenate(right, 1);
84 VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
85 VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
86 VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
87 for (int i = 0; i < 2; ++i) {
88 for (int j = 0; j < 3; ++j) {
89 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
90 }
91 for (int j = 3; j < 6; ++j) {
92 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
93 }
94 }
95
96 concatenation = left.concatenate(right, 2);
97 VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
98 VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
99 VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
100 for (int i = 0; i < 2; ++i) {
101 for (int j = 0; j < 3; ++j) {
102 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
103 VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
104 }
105 }
106 }
107
108
109 // TODO(phli): Add test once we have a real vectorized implementation.
110 // static void test_vectorized_concatenation() {}
111
test_concatenation_as_lvalue()112 static void test_concatenation_as_lvalue()
113 {
114 Tensor<int, 2> t1(2, 3);
115 Tensor<int, 2> t2(2, 3);
116 t1.setRandom();
117 t2.setRandom();
118
119 Tensor<int, 2> result(4, 3);
120 result.setRandom();
121 t1.concatenate(t2, 0) = result;
122
123 for (int i = 0; i < 2; ++i) {
124 for (int j = 0; j < 3; ++j) {
125 VERIFY_IS_EQUAL(t1(i, j), result(i, j));
126 VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
127 }
128 }
129 }
130
131
EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
{
  // Invalid shapes and axes must be rejected, in both storage orders.
  CALL_SUBTEST(test_dimension_failures<ColMajor>());
  CALL_SUBTEST(test_dimension_failures<RowMajor>());

  // Operands of different rank, concatenated via reshape workarounds.
  CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
  CALL_SUBTEST(test_static_dimension_failure<RowMajor>());

  // Concatenation along every axis, checked coefficient by coefficient.
  CALL_SUBTEST(test_simple_concatenation<ColMajor>());
  CALL_SUBTEST(test_simple_concatenation<RowMajor>());

  // Disabled until a real vectorized implementation exists:
  // CALL_SUBTEST(test_vectorized_concatenation());

  // Assignment through a concatenation expression (default layout only).
  CALL_SUBTEST(test_concatenation_as_lvalue());
}
144