xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Concatenation.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Concatenation")
10 {
11 struct ConcatenationFixture : public ParserFlatbuffersFixture
12 {
ConcatenationFixtureConcatenationFixture13     explicit ConcatenationFixture(const std::string & inputShape1,
14                                   const std::string & inputShape2,
15                                   const std::string & outputShape,
16                                   const std::string & axis,
17                                   const std::string & activation="NONE")
18     {
19         m_JsonString = R"(
20             {
21                 "version": 3,
22                 "operator_codes": [ { "builtin_code": "CONCATENATION" } ],
23                 "subgraphs": [ {
24                     "tensors": [
25                         {
26                             "shape": )" + inputShape1 + R"(,
27                             "type": "UINT8",
28                             "buffer": 0,
29                             "name": "inputTensor1",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ],
35                             }
36                         },
37                         {
38                             "shape": )" + inputShape2 + R"(,
39                             "type": "UINT8",
40                             "buffer": 1,
41                             "name": "inputTensor2",
42                             "quantization": {
43                                 "min": [ 0.0 ],
44                                 "max": [ 255.0 ],
45                                 "scale": [ 1.0 ],
46                                 "zero_point": [ 0 ],
47                             }
48                         },
49                         {
50                             "shape": )" + outputShape + R"( ,
51                             "type": "UINT8",
52                             "buffer": 2,
53                             "name": "outputTensor",
54                             "quantization": {
55                                 "min": [ 0.0 ],
56                                 "max": [ 255.0 ],
57                                 "scale": [ 1.0 ],
58                                 "zero_point": [ 0 ],
59                             }
60                         }
61                     ],
62                     "inputs": [ 0, 1 ],
63                     "outputs": [ 2 ],
64                     "operators": [
65                         {
66                             "opcode_index": 0,
67                             "inputs": [ 0, 1 ],
68                             "outputs": [ 2 ],
69                             "builtin_options_type": "ConcatenationOptions",
70                             "builtin_options": {
71                                 "axis": )" + axis + R"(,
72                                 "fused_activation_function": )" + activation + R"(
73                             },
74                             "custom_options_format": "FLEXBUFFERS"
75                         }
76                     ],
77                 } ],
78                 "buffers" : [
79                     { },
80                     { }
81                 ]
82             }
83         )";
84         Setup();
85     }
86 };
87 
88 
89 struct ConcatenationFixtureNegativeDim : ConcatenationFixture
90 {
ConcatenationFixtureNegativeDimConcatenationFixtureNegativeDim91     ConcatenationFixtureNegativeDim() : ConcatenationFixture("[ 1, 1, 2, 2 ]",
92                                                              "[ 1, 1, 2, 2 ]",
93                                                              "[ 1, 2, 2, 2 ]",
94                                                              "-3" ) {}
95 };
96 
97 TEST_CASE_FIXTURE(ConcatenationFixtureNegativeDim, "ParseConcatenationNegativeDim")
98 {
99     RunTest<4, armnn::DataType::QAsymmU8>(
100         0,
101         {{"inputTensor1", { 0, 1, 2, 3 }},
102         {"inputTensor2", { 4, 5, 6, 7 }}},
103         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
104 }
105 
106 struct ConcatenationFixtureNCHW : ConcatenationFixture
107 {
ConcatenationFixtureNCHWConcatenationFixtureNCHW108     ConcatenationFixtureNCHW() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 2, 2, 2 ]", "1" ) {}
109 };
110 
111 TEST_CASE_FIXTURE(ConcatenationFixtureNCHW, "ParseConcatenationNCHW")
112 {
113     RunTest<4, armnn::DataType::QAsymmU8>(
114         0,
115         {{"inputTensor1", { 0, 1, 2, 3 }},
116         {"inputTensor2", { 4, 5, 6, 7 }}},
117         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
118 }
119 
120 struct ConcatenationFixtureNHWC : ConcatenationFixture
121 {
ConcatenationFixtureNHWCConcatenationFixtureNHWC122     ConcatenationFixtureNHWC() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 4 ]", "3" ) {}
123 };
124 
125 TEST_CASE_FIXTURE(ConcatenationFixtureNHWC, "ParseConcatenationNHWC")
126 {
127     RunTest<4, armnn::DataType::QAsymmU8>(
128         0,
129         {{"inputTensor1", { 0, 1, 2, 3 }},
130         {"inputTensor2", { 4, 5, 6, 7 }}},
131         {{"outputTensor", { 0, 1, 4, 5, 2, 3, 6, 7 }}});
132 }
133 
134 struct ConcatenationFixtureDim1 : ConcatenationFixture
135 {
ConcatenationFixtureDim1ConcatenationFixtureDim1136     ConcatenationFixtureDim1() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 4, 3, 4 ]", "1" ) {}
137 };
138 
139 TEST_CASE_FIXTURE(ConcatenationFixtureDim1, "ParseConcatenationDim1")
140 {
141     RunTest<4, armnn::DataType::QAsymmU8>(
142         0,
143         { { "inputTensor1", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
144                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 } },
145         { "inputTensor2", {  50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
146                              62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } },
147         { { "outputTensor", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
148                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
149                                50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
150                                62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } });
151 }
152 
153 struct ConcatenationFixtureDim3 : ConcatenationFixture
154 {
ConcatenationFixtureDim3ConcatenationFixtureDim3155     ConcatenationFixtureDim3() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 8 ]", "3" ) {}
156 };
157 
158 TEST_CASE_FIXTURE(ConcatenationFixtureDim3, "ParseConcatenationDim3")
159 {
160     RunTest<4, armnn::DataType::QAsymmU8>(
161         0,
162         { { "inputTensor1", {  0,  1,  2,  3,
163                                4,  5,  6,  7,
164                                8,  9, 10, 11,
165                                12, 13, 14, 15,
166                                16, 17, 18, 19,
167                                20, 21, 22, 23 } },
168         { "inputTensor2", {  50, 51, 52, 53,
169                              54, 55, 56, 57,
170                              58, 59, 60, 61,
171                              62, 63, 64, 65,
172                              66, 67, 68, 69,
173                              70, 71, 72, 73 } } },
174         { { "outputTensor", {  0,  1,  2,  3,
175                                50, 51, 52, 53,
176                                4,  5,  6,  7,
177                                54, 55, 56, 57,
178                                8,  9,  10, 11,
179                                58, 59, 60, 61,
180                                12, 13, 14, 15,
181                                62, 63, 64, 65,
182                                16, 17, 18, 19,
183                                66, 67, 68, 69,
184                                20, 21, 22, 23,
185                                70, 71, 72, 73 } } });
186 }
187 
188 struct ConcatenationFixture3DDim0 : ConcatenationFixture
189 {
ConcatenationFixture3DDim0ConcatenationFixture3DDim0190     ConcatenationFixture3DDim0() : ConcatenationFixture("[ 1, 2, 3]", "[ 2, 2, 3]", "[ 3, 2, 3]", "0" ) {}
191 };
192 
193 TEST_CASE_FIXTURE(ConcatenationFixture3DDim0, "ParseConcatenation3DDim0")
194 {
195     RunTest<3, armnn::DataType::QAsymmU8>(
196         0,
197         { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
198           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
199                              12, 13, 14, 15, 16, 17 } } },
200         { { "outputTensor", { 0,  1,  2,  3,  4,  5,
201                               6,  7,  8,  9, 10, 11,
202                              12, 13, 14, 15, 16, 17 } } });
203 }
204 
205 struct ConcatenationFixture3DDim1 : ConcatenationFixture
206 {
ConcatenationFixture3DDim1ConcatenationFixture3DDim1207     ConcatenationFixture3DDim1() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 4, 3]", "[ 1, 6, 3]", "1" ) {}
208 };
209 
210 TEST_CASE_FIXTURE(ConcatenationFixture3DDim1, "ParseConcatenation3DDim1")
211 {
212     RunTest<3, armnn::DataType::QAsymmU8>(
213         0,
214         { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
215           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
216                              12, 13, 14, 15, 16, 17 } } },
217         { { "outputTensor", { 0,  1,  2,  3,  4,  5,
218                               6,  7,  8,  9, 10, 11,
219                              12, 13, 14, 15, 16, 17 } } });
220 }
221 
222 struct ConcatenationFixture3DDim2 : ConcatenationFixture
223 {
ConcatenationFixture3DDim2ConcatenationFixture3DDim2224     ConcatenationFixture3DDim2() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 2, 6]", "[ 1, 2, 9]", "2" ) {}
225 };
226 
227 TEST_CASE_FIXTURE(ConcatenationFixture3DDim2, "ParseConcatenation3DDim2")
228 {
229     RunTest<3, armnn::DataType::QAsymmU8>(
230         0,
231         { { "inputTensor1", { 0,  1,  2,
232                               3,  4,  5 } },
233           { "inputTensor2", { 6,  7,  8,  9, 10, 11,
234                              12, 13, 14, 15, 16, 17 } } },
235         { { "outputTensor", { 0,  1,  2,  6,  7,  8,  9, 10, 11,
236                               3,  4,  5, 12, 13, 14, 15, 16, 17 } } });
237 }
238 
239 }
240