xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Concat.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 #include "OnnxParserTestUtils.hpp"
9 
10 TEST_SUITE("OnnxParser_Concat")
11 {
12 
13 struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
ConcatFixtureConcatFixture15     ConcatFixture(const std::string& axis,
16                   const std::vector<int>& input0Shape,
17                   const std::vector<int>& input1Shape,
18                   const std::vector<int>& outputShape)
19     {
20         m_Prototext = R"(
21                     ir_version: 8
22                     producer_name: "onnx-example"
23                     graph {
24                       node {
25                         input: "Input0"
26                         input: "Input1"
27                         output: "Output"
28                         op_type: "Concat"
29                         attribute {
30                           name: "axis"
31                           i: )" + axis + R"(
32                           type: INT
33                         }
34                       }
35                       name: "concat-model"
36                       input {
37                         name: "Input0"
38                         type {
39                           tensor_type {
40                             elem_type: 1
41                             shape {
42                               )" + armnnUtils::ConstructTensorShapeString(input0Shape) + R"(
43                             }
44                           }
45                         }
46                       }
47                       input {
48                         name: "Input1"
49                         type {
50                           tensor_type {
51                             elem_type: 1
52                             shape {
53                               )" + armnnUtils::ConstructTensorShapeString(input1Shape) + R"(
54                             }
55                           }
56                         }
57                       }
58                       output {
59                         name: "Output"
60                         type {
61                           tensor_type {
62                             elem_type: 1
63                             shape {
64                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
65                             }
66                           }
67                         }
68                       }
69                     })";
70         Setup();
71     }
72 };
73 
74 struct ConcatAxis0Fixture : ConcatFixture
75 {
ConcatAxis0FixtureConcatAxis0Fixture76     ConcatAxis0Fixture() : ConcatFixture("0", { 1, 3, 2, 5 }, { 1, 3, 2, 5 }, { 2, 3, 2, 5 }) {}
77 };
78 
79 struct ConcatAxis1Fixture : ConcatFixture
80 {
ConcatAxis1FixtureConcatAxis1Fixture81     ConcatAxis1Fixture() : ConcatFixture("1", { 2, 2, 1, 3 }, { 2, 1, 1, 3 }, { 2, 3, 1, 3 }) {}
82 };
83 
84 struct ConcatAxis2Fixture : ConcatFixture
85 {
ConcatAxis2FixtureConcatAxis2Fixture86     ConcatAxis2Fixture() : ConcatFixture("2", { 2, 3, 1, 1 }, { 2, 3, 2, 1 }, { 2, 3, 3, 1 }) {}
87 };
88 
89 struct ConcatAxis3Fixture : ConcatFixture
90 {
ConcatAxis3FixtureConcatAxis3Fixture91     ConcatAxis3Fixture() : ConcatFixture("3", { 1, 3, 2, 2 }, { 1, 3, 2, 2 }, { 1, 3, 2, 4 }) {}
92 };
93 
94 struct ConcatNegativeAxisFixture : ConcatFixture
95 {
ConcatNegativeAxisFixtureConcatNegativeAxisFixture96     ConcatNegativeAxisFixture() : ConcatFixture("-1", { 1, 2, 5 }, { 1, 2, 3 }, { 1, 2, 8 }) {}
97 };
98 
99 TEST_CASE_FIXTURE(ConcatAxis0Fixture, "ConcatAxis0Test")
100 {
101     RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
102                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
103                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
104                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
105                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
106                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f }},
107                        {"Input1", { 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
108                                     36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
109                                     41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
110                                     46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
111                                     51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
112                                     56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}},
113                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
114                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
115                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
116                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
117                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
118                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
119                                     31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
120                                     36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
121                                     41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
122                                     46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
123                                     51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
124                                     56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}});
125 }
126 
127 TEST_CASE_FIXTURE(ConcatAxis1Fixture, "ConcatAxis1est")
128 {
129     RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
130                        {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
131                       {{"Output", { 1.0f, 2.0f, 3.0f,
132                                     4.0f, 5.0f, 6.0f,
133                                     13.0f, 14.0f, 15.0f,
134                                     7.0f, 8.0f, 9.0f,
135                                     10.0f, 11.0f, 12.0f,
136                                     16.0f, 17.0f, 18.0f }}});
137 }
138 
139 TEST_CASE_FIXTURE(ConcatAxis2Fixture, "ConcatAxis2Test")
140 {
141     RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
142                        {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
143                       {{"Output", { 1.0f, 7.0f, 8.0f,
144                                     2.0f, 9.0f, 10.0f,
145                                     3.0f, 11.0f, 12.0f,
146                                     4.0f, 13.0f, 14.0f,
147                                     5.0f, 15.0f, 16.0f,
148                                     6.0f, 17.0f, 18.0f }}});
149 }
150 
151 TEST_CASE_FIXTURE(ConcatAxis3Fixture, "ConcatAxis3Test")
152 {
153     RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
154                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
155                        {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
156                                     19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
157                       {{"Output", { 1.0f, 2.0f, 13.0f, 14.0f,
158                                     3.0f, 4.0f, 15.0f, 16.0f,
159                                     5.0f, 6.0f, 17.0f, 18.0f,
160                                     7.0f, 8.0f, 19.0f, 20.0f,
161                                     9.0f, 10.0f, 21.0f, 22.0f,
162                                     11.0f, 12.0f, 23.0f, 24.0f }}});
163 }
164 
165 TEST_CASE_FIXTURE(ConcatNegativeAxisFixture, "ConcatNegativeAxisTest")
166 {
167     RunTest<3, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f }},
169                        {"Input1", { 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f }}},
170                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f, 12.0f, 13.0f,
171                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 14.0f, 15.0f, 16.0f }}});
172 }
173 
174 struct ConcatMultipleInputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
175 {
ConcatMultipleInputsFixtureConcatMultipleInputsFixture176     ConcatMultipleInputsFixture()
177     {
178         m_Prototext = R"(
179                     ir_version: 8
180                     producer_name: "onnx-example"
181                     graph {
182                       node {
183                         input: "Input0"
184                         input: "Input1"
185                         input: "Input2"
186                         output: "Output"
187                         op_type: "Concat"
188                         attribute {
189                           name: "axis"
190                           i: 1
191                           type: INT
192                         }
193                       }
194                       name: "concat-model"
195                       input {
196                         name: "Input0"
197                         type {
198                           tensor_type {
199                             elem_type: 1
200                             shape {
201                               dim {
202                                 dim_value: 3
203                               }
204                               dim {
205                                 dim_value: 2
206                               }
207                             }
208                           }
209                         }
210                       }
211                       input {
212                         name: "Input1"
213                         type {
214                           tensor_type {
215                             elem_type: 1
216                             shape {
217                               dim {
218                                 dim_value: 3
219                               }
220                               dim {
221                                 dim_value: 3
222                               }
223                             }
224                           }
225                         }
226                       }
227                       input {
228                         name: "Input2"
229                         type {
230                           tensor_type {
231                             elem_type: 1
232                             shape {
233                               dim {
234                                 dim_value: 3
235                               }
236                               dim {
237                                 dim_value: 1
238                               }
239                             }
240                           }
241                         }
242                       }
243                       output {
244                         name: "Output"
245                         type {
246                           tensor_type {
247                             elem_type: 1
248                             shape {
249                               dim {
250                                 dim_value: 3
251                               }
252                               dim {
253                                 dim_value: 6
254                               }
255                             }
256                           }
257                         }
258                       }
259                     })";
260         Setup();
261     }
262 };
263 
264 TEST_CASE_FIXTURE(ConcatMultipleInputsFixture, "ConcatMultipleInputsTest")
265 {
266     RunTest<2, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
267                        {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f }},
268                        {"Input2", { 16.0f, 17.0f, 18.0f }}},
269                       {{"Output", { 1.0f, 2.0f, 7.0f, 8.0f, 9.0f, 16.0f,
270                                     3.0f, 4.0f, 10.0f, 11.0f, 12.0f, 17.0f,
271                                     5.0f, 6.0f, 13.0f, 14.0f, 15.0f, 18.0f }}});
272 }
273 
274 }
275