1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Concat") 11 { 12 13 struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14 { ConcatFixtureConcatFixture15 ConcatFixture(const std::string& axis, 16 const std::vector<int>& input0Shape, 17 const std::vector<int>& input1Shape, 18 const std::vector<int>& outputShape) 19 { 20 m_Prototext = R"( 21 ir_version: 8 22 producer_name: "onnx-example" 23 graph { 24 node { 25 input: "Input0" 26 input: "Input1" 27 output: "Output" 28 op_type: "Concat" 29 attribute { 30 name: "axis" 31 i: )" + axis + R"( 32 type: INT 33 } 34 } 35 name: "concat-model" 36 input { 37 name: "Input0" 38 type { 39 tensor_type { 40 elem_type: 1 41 shape { 42 )" + armnnUtils::ConstructTensorShapeString(input0Shape) + R"( 43 } 44 } 45 } 46 } 47 input { 48 name: "Input1" 49 type { 50 tensor_type { 51 elem_type: 1 52 shape { 53 )" + armnnUtils::ConstructTensorShapeString(input1Shape) + R"( 54 } 55 } 56 } 57 } 58 output { 59 name: "Output" 60 type { 61 tensor_type { 62 elem_type: 1 63 shape { 64 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 65 } 66 } 67 } 68 } 69 })"; 70 Setup(); 71 } 72 }; 73 74 struct ConcatAxis0Fixture : ConcatFixture 75 { ConcatAxis0FixtureConcatAxis0Fixture76 ConcatAxis0Fixture() : ConcatFixture("0", { 1, 3, 2, 5 }, { 1, 3, 2, 5 }, { 2, 3, 2, 5 }) {} 77 }; 78 79 struct ConcatAxis1Fixture : ConcatFixture 80 { ConcatAxis1FixtureConcatAxis1Fixture81 ConcatAxis1Fixture() : ConcatFixture("1", { 2, 2, 1, 3 }, { 2, 1, 1, 3 }, { 2, 3, 1, 3 }) {} 82 }; 83 84 struct ConcatAxis2Fixture : ConcatFixture 85 { ConcatAxis2FixtureConcatAxis2Fixture86 ConcatAxis2Fixture() : ConcatFixture("2", { 2, 3, 1, 1 }, { 2, 3, 2, 1 }, { 2, 3, 3, 1 }) {} 87 }; 88 89 struct ConcatAxis3Fixture : ConcatFixture 90 { ConcatAxis3FixtureConcatAxis3Fixture91 ConcatAxis3Fixture() : ConcatFixture("3", { 1, 3, 2, 2 }, { 1, 3, 2, 2 }, { 1, 3, 2, 4 }) {} 92 }; 93 94 struct ConcatNegativeAxisFixture : ConcatFixture 95 { ConcatNegativeAxisFixtureConcatNegativeAxisFixture96 ConcatNegativeAxisFixture() : ConcatFixture("-1", { 1, 2, 5 }, { 1, 2, 3 }, { 1, 2, 8 }) {} 97 }; 98 99 TEST_CASE_FIXTURE(ConcatAxis0Fixture, "ConcatAxis0Test") 100 { 101 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 102 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 103 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 104 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 105 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 106 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}, 107 {"Input1", { 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 108 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 109 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 110 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 111 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 112 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}}, 113 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 114 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 115 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 116 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 117 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 118 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 119 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 120 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 121 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 122 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 123 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 124 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}}); 125 } 126 127 TEST_CASE_FIXTURE(ConcatAxis1Fixture, "ConcatAxis1est") 128 { 129 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}, 130 {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, 131 {{"Output", { 1.0f, 2.0f, 3.0f, 132 4.0f, 5.0f, 6.0f, 133 13.0f, 14.0f, 15.0f, 134 7.0f, 8.0f, 9.0f, 135 10.0f, 11.0f, 12.0f, 136 16.0f, 17.0f, 18.0f }}}); 137 } 138 139 TEST_CASE_FIXTURE(ConcatAxis2Fixture, "ConcatAxis2Test") 140 { 141 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}, 142 {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, 143 {{"Output", { 1.0f, 7.0f, 8.0f, 144 2.0f, 9.0f, 10.0f, 145 3.0f, 11.0f, 12.0f, 146 4.0f, 13.0f, 14.0f, 147 5.0f, 15.0f, 16.0f, 148 6.0f, 17.0f, 18.0f }}}); 149 } 150 151 TEST_CASE_FIXTURE(ConcatAxis3Fixture, "ConcatAxis3Test") 152 { 153 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 154 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}, 155 {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 156 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, 157 {{"Output", { 1.0f, 2.0f, 13.0f, 14.0f, 158 3.0f, 4.0f, 15.0f, 16.0f, 159 5.0f, 6.0f, 17.0f, 18.0f, 160 7.0f, 8.0f, 19.0f, 20.0f, 161 9.0f, 10.0f, 21.0f, 22.0f, 162 11.0f, 12.0f, 23.0f, 24.0f }}}); 163 } 164 165 TEST_CASE_FIXTURE(ConcatNegativeAxisFixture, "ConcatNegativeAxisTest") 166 { 167 RunTest<3, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 168 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}, 169 {"Input1", { 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f }}}, 170 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f, 12.0f, 13.0f, 171 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 14.0f, 15.0f, 16.0f }}}); 172 } 173 174 struct ConcatMultipleInputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 175 { ConcatMultipleInputsFixtureConcatMultipleInputsFixture176 ConcatMultipleInputsFixture() 177 { 178 m_Prototext = R"( 179 ir_version: 8 180 producer_name: "onnx-example" 181 graph { 182 node { 183 input: "Input0" 184 input: "Input1" 185 input: "Input2" 186 output: "Output" 187 op_type: "Concat" 188 attribute { 189 name: "axis" 190 i: 1 191 type: INT 192 } 193 } 194 name: "concat-model" 195 input { 196 name: "Input0" 197 type { 198 tensor_type { 199 elem_type: 1 200 shape { 201 dim { 202 dim_value: 3 203 } 204 dim { 205 dim_value: 2 206 } 207 } 208 } 209 } 210 } 211 input { 212 name: "Input1" 213 type { 214 tensor_type { 215 elem_type: 1 216 shape { 217 dim { 218 dim_value: 3 219 } 220 dim { 221 dim_value: 3 222 } 223 } 224 } 225 } 226 } 227 input { 228 name: "Input2" 229 type { 230 tensor_type { 231 elem_type: 1 232 shape { 233 dim { 234 dim_value: 3 235 } 236 dim { 237 dim_value: 1 238 } 239 } 240 } 241 } 242 } 243 output { 244 name: "Output" 245 type { 246 tensor_type { 247 elem_type: 1 248 shape { 249 dim { 250 dim_value: 3 251 } 252 dim { 253 dim_value: 6 254 } 255 } 256 } 257 } 258 } 259 })"; 260 Setup(); 261 } 262 }; 263 264 TEST_CASE_FIXTURE(ConcatMultipleInputsFixture, "ConcatMultipleInputsTest") 265 { 266 RunTest<2, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}, 267 {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f }}, 268 {"Input2", { 16.0f, 17.0f, 18.0f }}}, 269 {{"Output", { 1.0f, 2.0f, 7.0f, 8.0f, 9.0f, 16.0f, 270 3.0f, 4.0f, 10.0f, 11.0f, 12.0f, 17.0f, 271 5.0f, 6.0f, 13.0f, 14.0f, 15.0f, 18.0f }}}); 272 } 273 274 } 275