1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_Flatter") 10 { 11 struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { FlattenMainFixtureFlattenMainFixture13 FlattenMainFixture(const std::string& dataType) 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input" 25 type { 26 tensor_type { 27 elem_type: )" + dataType + R"( 28 shape { 29 dim { 30 dim_value: 2 31 } 32 dim { 33 dim_value: 2 34 } 35 dim { 36 dim_value: 3 37 } 38 dim { 39 dim_value: 3 40 } 41 } 42 } 43 } 44 } 45 node { 46 input: "Input" 47 output: "Output" 48 name: "flatten" 49 op_type: "Flatten" 50 attribute { 51 name: "axis" 52 i: 2 53 type: INT 54 } 55 } 56 output { 57 name: "Output" 58 type { 59 tensor_type { 60 elem_type: 1 61 shape { 62 dim { 63 dim_value: 4 64 } 65 dim { 66 dim_value: 9 67 } 68 } 69 } 70 } 71 } 72 } 73 opset_import { 74 version: 7 75 })"; 76 } 77 }; 78 79 struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 80 { FlattenDefaultAxisFixtureFlattenDefaultAxisFixture81 FlattenDefaultAxisFixture(const std::string& dataType) 82 { 83 m_Prototext = R"( 84 ir_version: 3 85 producer_name: "CNTK" 86 producer_version: "2.5.1" 87 domain: "ai.cntk" 88 model_version: 1 89 graph { 90 name: "CNTKGraph" 91 input { 92 name: "Input" 93 type { 94 tensor_type { 95 elem_type: )" + dataType + R"( 96 shape { 97 dim { 98 dim_value: 2 99 } 100 dim { 101 dim_value: 2 102 } 103 dim { 104 dim_value: 3 105 } 106 dim { 107 dim_value: 3 108 } 109 } 110 } 111 } 112 } 113 node { 114 input: "Input" 115 output: "Output" 116 name: "flatten" 117 op_type: "Flatten" 118 } 119 output { 120 name: "Output" 121 type { 122 tensor_type { 123 elem_type: 1 124 shape { 125 dim { 126 dim_value: 2 127 } 128 dim { 129 dim_value: 18 130 } 131 } 132 } 133 } 134 } 135 } 136 opset_import { 137 version: 7 138 })"; 139 } 140 }; 141 142 struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 143 { FlattenAxisZeroFixtureFlattenAxisZeroFixture144 FlattenAxisZeroFixture(const std::string& dataType) 145 { 146 m_Prototext = R"( 147 ir_version: 3 148 producer_name: "CNTK" 149 producer_version: "2.5.1" 150 domain: "ai.cntk" 151 model_version: 1 152 graph { 153 name: "CNTKGraph" 154 input { 155 name: "Input" 156 type { 157 tensor_type { 158 elem_type: )" + dataType + R"( 159 shape { 160 dim { 161 dim_value: 2 162 } 163 dim { 164 dim_value: 2 165 } 166 dim { 167 dim_value: 3 168 } 169 dim { 170 dim_value: 3 171 } 172 } 173 } 174 } 175 } 176 node { 177 input: "Input" 178 output: "Output" 179 name: "flatten" 180 op_type: "Flatten" 181 attribute { 182 name: "axis" 183 i: 0 184 type: INT 185 } 186 } 187 output { 188 name: "Output" 189 type { 190 tensor_type { 191 elem_type: 1 192 shape { 193 dim { 194 dim_value: 1 195 } 196 dim { 197 dim_value: 36 198 } 199 } 200 } 201 } 202 } 203 } 204 opset_import { 205 version: 7 206 })"; 207 } 208 }; 209 210 struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 211 { FlattenNegativeAxisFixtureFlattenNegativeAxisFixture212 FlattenNegativeAxisFixture(const std::string& dataType) 213 { 214 m_Prototext = R"( 215 ir_version: 3 216 producer_name: "CNTK" 217 producer_version: "2.5.1" 218 domain: "ai.cntk" 219 model_version: 1 220 graph { 221 name: "CNTKGraph" 222 input { 223 name: "Input" 224 type { 225 tensor_type { 226 elem_type: )" + dataType + R"( 227 shape { 228 dim { 229 dim_value: 2 230 } 231 dim { 232 dim_value: 2 233 } 234 dim { 235 dim_value: 3 236 } 237 dim { 238 dim_value: 3 239 } 240 } 241 } 242 } 243 } 244 node { 245 input: "Input" 246 output: "Output" 247 name: "flatten" 248 op_type: "Flatten" 249 attribute { 250 name: "axis" 251 i: -1 252 type: INT 253 } 254 } 255 output { 256 name: "Output" 257 type { 258 tensor_type { 259 elem_type: 1 260 shape { 261 dim { 262 dim_value: 12 263 } 264 dim { 265 dim_value: 3 266 } 267 } 268 } 269 } 270 } 271 } 272 opset_import { 273 version: 7 274 })"; 275 } 276 }; 277 278 struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 279 { FlattenInvalidNegativeAxisFixtureFlattenInvalidNegativeAxisFixture280 FlattenInvalidNegativeAxisFixture(const std::string& dataType) 281 { 282 m_Prototext = R"( 283 ir_version: 3 284 producer_name: "CNTK" 285 producer_version: "2.5.1" 286 domain: "ai.cntk" 287 model_version: 1 288 graph { 289 name: "CNTKGraph" 290 input { 291 name: "Input" 292 type { 293 tensor_type { 294 elem_type: )" + dataType + R"( 295 shape { 296 dim { 297 dim_value: 2 298 } 299 dim { 300 dim_value: 2 301 } 302 dim { 303 dim_value: 3 304 } 305 dim { 306 dim_value: 3 307 } 308 } 309 } 310 } 311 } 312 node { 313 input: "Input" 314 output: "Output" 315 name: "flatten" 316 op_type: "Flatten" 317 attribute { 318 name: "axis" 319 i: -5 320 type: INT 321 } 322 } 323 output { 324 name: "Output" 325 type { 326 tensor_type { 327 elem_type: 1 328 shape { 329 dim { 330 dim_value: 12 331 } 332 dim { 333 dim_value: 3 334 } 335 } 336 } 337 } 338 } 339 } 340 opset_import { 341 version: 7 342 })"; 343 } 344 }; 345 346 struct FlattenValidFixture : FlattenMainFixture 347 { FlattenValidFixtureFlattenValidFixture348 FlattenValidFixture() : FlattenMainFixture("1") { 349 Setup(); 350 } 351 }; 352 353 struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture 354 { FlattenDefaultValidFixtureFlattenDefaultValidFixture355 FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") { 356 Setup(); 357 } 358 }; 359 360 struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture 361 { FlattenAxisZeroValidFixtureFlattenAxisZeroValidFixture362 FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") { 363 Setup(); 364 } 365 }; 366 367 struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture 368 { FlattenNegativeAxisValidFixtureFlattenNegativeAxisValidFixture369 FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") { 370 Setup(); 371 } 372 }; 373 374 struct FlattenInvalidFixture : FlattenMainFixture 375 { FlattenInvalidFixtureFlattenInvalidFixture376 FlattenInvalidFixture() : FlattenMainFixture("10") { } 377 }; 378 379 struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture 380 { FlattenInvalidAxisFixtureFlattenInvalidAxisFixture381 FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { } 382 }; 383 384 TEST_CASE_FIXTURE(FlattenValidFixture, "ValidFlattenTest") 385 { 386 RunTest<2>({{"Input", 387 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 388 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 389 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, 390 {{"Output", 391 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 392 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 393 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}); 394 } 395 396 TEST_CASE_FIXTURE(FlattenDefaultValidFixture, "ValidFlattenDefaultTest") 397 { 398 RunTest<2>({{"Input", 399 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 400 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 401 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, 402 {{"Output", 403 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 404 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 405 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}); 406 } 407 408 TEST_CASE_FIXTURE(FlattenAxisZeroValidFixture, "ValidFlattenAxisZeroTest") 409 { 410 RunTest<2>({{"Input", 411 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 412 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 413 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, 414 {{"Output", 415 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 416 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 417 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}); 418 } 419 420 TEST_CASE_FIXTURE(FlattenNegativeAxisValidFixture, "ValidFlattenNegativeAxisTest") 421 { 422 RunTest<2>({{"Input", 423 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 424 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 425 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, 426 {{"Output", 427 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 428 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 429 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}); 430 } 431 432 TEST_CASE_FIXTURE(FlattenInvalidFixture, "IncorrectDataTypeFlatten") 433 { 434 CHECK_THROWS_AS(Setup(), armnn::ParseException); 435 } 436 437 TEST_CASE_FIXTURE(FlattenInvalidAxisFixture, "IncorrectAxisFlatten") 438 { 439 CHECK_THROWS_AS(Setup(), armnn::ParseException); 440 } 441 442 } 443