xref: /aosp_15_r20/external/eigen/unsupported/test/cxx11_tensor_concatenation.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
test_dimension_failures()17 static void test_dimension_failures()
18 {
19   Tensor<int, 3, DataLayout> left(2, 3, 1);
20   Tensor<int, 3, DataLayout> right(3, 3, 1);
21   left.setRandom();
22   right.setRandom();
23 
24   // Okay; other dimensions are equal.
25   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26 
27   // Dimension mismatches.
28   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30 
31   // Axis > NumDims or < 0.
32   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34 }
35 
36 template<int DataLayout>
test_static_dimension_failure()37 static void test_static_dimension_failure()
38 {
39   Tensor<int, 2, DataLayout> left(2, 3);
40   Tensor<int, 3, DataLayout> right(2, 3, 1);
41 
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43   // Technically compatible, but we static assert that the inputs have same
44   // NumDims.
45   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47 
48   // This can be worked around in this case.
49   Tensor<int, 3, DataLayout> concatenation = left
50       .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51       .concatenate(right, 0);
52   Tensor<int, 2, DataLayout> alternative = left
53    // Clang compiler break with {{{}}} with an ambiguous error on copy constructor
54   // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
55   // Solution:
56   // either the code should change to
57   //  Tensor<int, 2>::Dimensions{{2, 3}}
58   // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
59       .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
60 }
61 
62 template<int DataLayout>
test_simple_concatenation()63 static void test_simple_concatenation()
64 {
65   Tensor<int, 3, DataLayout> left(2, 3, 1);
66   Tensor<int, 3, DataLayout> right(2, 3, 1);
67   left.setRandom();
68   right.setRandom();
69 
70   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
71   VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
72   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
73   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
74   for (int j = 0; j < 3; ++j) {
75     for (int i = 0; i < 2; ++i) {
76       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
77     }
78     for (int i = 2; i < 4; ++i) {
79       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
80     }
81   }
82 
83   concatenation = left.concatenate(right, 1);
84   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
85   VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
86   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
87   for (int i = 0; i < 2; ++i) {
88     for (int j = 0; j < 3; ++j) {
89       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
90     }
91     for (int j = 3; j < 6; ++j) {
92       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
93     }
94   }
95 
96   concatenation = left.concatenate(right, 2);
97   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
98   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
99   VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
100   for (int i = 0; i < 2; ++i) {
101     for (int j = 0; j < 3; ++j) {
102       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
103       VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
104     }
105   }
106 }
107 
108 
109 // TODO(phli): Add test once we have a real vectorized implementation.
110 // static void test_vectorized_concatenation() {}
111 
test_concatenation_as_lvalue()112 static void test_concatenation_as_lvalue()
113 {
114   Tensor<int, 2> t1(2, 3);
115   Tensor<int, 2> t2(2, 3);
116   t1.setRandom();
117   t2.setRandom();
118 
119   Tensor<int, 2> result(4, 3);
120   result.setRandom();
121   t1.concatenate(t2, 0) = result;
122 
123   for (int i = 0; i < 2; ++i) {
124     for (int j = 0; j < 3; ++j) {
125       VERIFY_IS_EQUAL(t1(i, j), result(i, j));
126       VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
127     }
128   }
129 }
130 
131 
// Test entry point. Subtests templated on DataLayout run once per layout.
EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
{
  // Mismatched extents and out-of-range axes.
  CALL_SUBTEST(test_dimension_failures<ColMajor>());
  CALL_SUBTEST(test_dimension_failures<RowMajor>());

  // Rank mismatch and its reshape-based workarounds.
  CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
  CALL_SUBTEST(test_static_dimension_failure<RowMajor>());

  // Concatenation along each axis.
  CALL_SUBTEST(test_simple_concatenation<ColMajor>());
  CALL_SUBTEST(test_simple_concatenation<RowMajor>());

  // Disabled: test_vectorized_concatenation is not implemented yet.
  // CALL_SUBTEST(test_vectorized_concatenation());

  // Concatenation used as an assignment target.
  CALL_SUBTEST(test_concatenation_as_lvalue());
}
143 }
144