1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 #include "OnnxParserTestUtils.hpp" 9 10 TEST_SUITE("OnnxParser_Reshape") 11 { 12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 13 { ReshapeMainFixtureReshapeMainFixture14 ReshapeMainFixture(const std::string& dataType) 15 { 16 m_Prototext = R"( 17 ir_version: 3 18 producer_name: "CNTK" 19 producer_version: "2.5.1" 20 domain: "ai.cntk" 21 model_version: 1 22 graph { 23 name: "CNTKGraph" 24 input { 25 name: "Input" 26 type { 27 tensor_type { 28 elem_type: )" + dataType + R"( 29 shape { 30 dim { 31 dim_value: 4 32 } 33 } 34 } 35 } 36 } 37 input { 38 name: "Shape" 39 type { 40 tensor_type { 41 elem_type: 7 42 shape { 43 dim { 44 dim_value: 2 45 } 46 } 47 } 48 } 49 } 50 node { 51 input: "Input" 52 input: "Shape" 53 output: "Output" 54 name: "reshape" 55 op_type: "Reshape" 56 57 } 58 initializer { 59 dims: 2 60 data_type: 7 61 int64_data: 2 62 int64_data: 2 63 name: "Shape" 64 } 65 output { 66 name: "Output" 67 type { 68 tensor_type { 69 elem_type: 1 70 shape { 71 dim { 72 dim_value: 2 73 } 74 dim { 75 dim_value: 2 76 } 77 } 78 } 79 } 80 } 81 } 82 opset_import { 83 version: 7 84 })"; 85 } 86 }; 87 88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 89 { ReshapeRank4FixtureReshapeRank4Fixture90 ReshapeRank4Fixture(const std::string& dataType) 91 { 92 m_Prototext = R"( 93 ir_version: 3 94 producer_name: "CNTK" 95 producer_version: "2.5.1" 96 domain: "ai.cntk" 97 model_version: 1 98 graph { 99 name: "CNTKGraph" 100 input { 101 name: "Input" 102 type { 103 tensor_type { 104 elem_type: )" + dataType + R"( 105 shape { 106 dim { 107 dim_value: 2 108 } 109 dim { 110 dim_value: 2 111 } 112 dim { 113 dim_value: 3 114 } 115 dim { 116 dim_value: 3 117 } 118 } 119 } 120 } 121 } 122 input { 123 name: "Shape" 124 type { 125 tensor_type { 126 elem_type: 7 127 shape { 128 dim { 129 dim_value: 2 130 } 131 } 132 } 133 } 134 } 135 node { 136 input: "Input" 137 input: "Shape" 138 output: "Output" 139 name: "reshape" 140 op_type: "Reshape" 141 142 } 143 initializer { 144 dims: 2 145 data_type: 7 146 int64_data: 2 147 int64_data: 2 148 name: "Shape" 149 } 150 output { 151 name: "Output" 152 type { 153 tensor_type { 154 elem_type: 1 155 shape { 156 dim { 157 dim_value: 6 158 } 159 dim { 160 dim_value: 6 161 } 162 } 163 } 164 } 165 } 166 } 167 opset_import { 168 version: 7 169 })"; 170 } 171 }; 172 173 struct ReshapeValidFixture : ReshapeMainFixture 174 { ReshapeValidFixtureReshapeValidFixture175 ReshapeValidFixture() : ReshapeMainFixture("1") { 176 Setup(); 177 } 178 }; 179 180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture 181 { ReshapeValidRank4FixtureReshapeValidRank4Fixture182 ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") { 183 Setup(); 184 } 185 }; 186 187 struct ReshapeInvalidFixture : ReshapeMainFixture 188 { ReshapeInvalidFixtureReshapeInvalidFixture189 ReshapeInvalidFixture() : ReshapeMainFixture("10") { } 190 }; 191 192 TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest") 193 { 194 RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}}); 195 } 196 197 TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest") 198 { 199 RunTest<2>( 200 {{"Input", 201 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 202 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 203 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}}, 204 {{"Output", 205 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 206 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 207 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}}); 208 } 209 210 TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape") 211 { 212 CHECK_THROWS_AS(Setup(), armnn::ParseException); 213 } 214 215 struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 216 { ReshapeNegativeReshapeFixtureReshapeNegativeReshapeFixture217 ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape, 218 const std::vector<int>& shapeInputShape, 219 const std::vector<int>& outputShape, 220 const std::string& shape) 221 { 222 m_Prototext = R"( 223 ir_version: 3 224 producer_name: "onnx-example" 225 graph { 226 name: "ReshapeGrapn" 227 input { 228 name: "Input" 229 type { 230 tensor_type { 231 elem_type: 1 232 shape { 233 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 234 } 235 } 236 } 237 } 238 input { 239 name: "Shape" 240 type { 241 tensor_type { 242 elem_type: 7 243 shape { 244 )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( 245 } 246 } 247 } 248 } 249 node { 250 input: "Input" 251 input: "Shape" 252 output: "Output" 253 name: "reshape" 254 op_type: "Reshape" 255 } 256 initializer { 257 dims: 2 258 data_type: 7 259 )" + shape + R"( 260 name: "Shape" 261 } 262 output { 263 name: "Output" 264 type { 265 tensor_type { 266 elem_type: 1 267 shape { 268 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 269 } 270 } 271 } 272 } 273 } 274 opset_import { 275 version: 7 276 })"; 277 } 278 }; 279 280 struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture 281 { ReshapeNegativeReshape1DFixtureReshapeNegativeReshape1DFixture282 ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1") 283 { 284 Setup(); 285 } 286 }; 287 288 struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture 289 { ReshapeNegativeReshape2DFixtureReshapeNegativeReshape2DFixture290 ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, 291 { 2 }, 292 { 2, 6 }, 293 "int64_data: -1 int64_data: 6") 294 { 295 Setup(); 296 } 297 }; 298 299 struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture 300 { ReshapeNegativeReshape3DFixtureReshapeNegativeReshape3DFixture301 ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 }, 302 { 3 }, 303 { 3, 1, 4 }, 304 "int64_data: 3 int64_data: -1 int64_data: 4") 305 { 306 Setup(); 307 } 308 }; 309 310 struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture 311 { ReshapeNegativeReshape4DFixtureReshapeNegativeReshape4DFixture312 ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture( 313 { 2, 3, 1, 2 }, 314 { 4 }, 315 { 3, 1, 2, 2 }, 316 "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1") 317 { 318 Setup(); 319 } 320 }; 321 322 TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest") 323 { 324 RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 325 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 326 } 327 328 TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest") 329 { 330 RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 331 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 332 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 333 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 334 } 335 336 TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest") 337 { 338 RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 339 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 340 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 341 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 342 } 343 344 TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest") 345 { 346 RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 347 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}, 348 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 349 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}}); 350 } 351 352 struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 353 { ReshapeNonConstShapeFixtureReshapeNonConstShapeFixture354 ReshapeNonConstShapeFixture(const std::vector<int>& inputShape, 355 const std::vector<int>& shapeInputShape, 356 const std::vector<int>& outputShape) 357 { 358 m_Prototext = R"( 359 ir_version: 3 360 producer_name: "onnx-example" 361 graph { 362 name: "ReshapeGrapn" 363 input { 364 name: "Input" 365 type { 366 tensor_type { 367 elem_type: 1 368 shape { 369 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 370 } 371 } 372 } 373 } 374 input { 375 name: "Shape" 376 type { 377 tensor_type { 378 elem_type: 7 379 shape { 380 )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"( 381 } 382 } 383 } 384 } 385 node { 386 input: "Input" 387 input: "Shape" 388 output: "Output" 389 name: "reshape" 390 op_type: "Reshape" 391 } 392 output { 393 name: "Output" 394 type { 395 tensor_type { 396 elem_type: 1 397 shape { 398 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 399 } 400 } 401 } 402 } 403 } 404 opset_import { 405 version: 7 406 })"; 407 } 408 }; 409 410 struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture 411 { ReshapeNonConst1DShapeFixtureReshapeNonConst1DShapeFixture412 ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }) 413 { 414 Setup(); 415 } 416 }; 417 418 struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture 419 { ReshapeNonConst2DShapeFixtureReshapeNonConst2DShapeFixture420 ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 }) 421 { 422 Setup(); 423 } 424 }; 425 426 struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture 427 { ReshapeInvalidNonConstShapeFixtureReshapeInvalidNonConstShapeFixture428 ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 }) 429 { 430 } 431 }; 432 433 struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture 434 { ReshapeInvalidDimNonConstShapeFixtureReshapeInvalidDimNonConstShapeFixture435 ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 }) 436 { 437 } 438 }; 439 440 TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest") 441 { 442 RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 443 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 444 } 445 446 TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest") 447 { 448 RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 449 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 450 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 451 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, 452 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 453 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 454 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 455 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}); 456 } 457 458 TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest") 459 { 460 CHECK_THROWS_AS(Setup(), armnn::ParseException); 461 } 462 463 TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest") 464 { 465 CHECK_THROWS_AS(Setup(), armnn::ParseException); 466 } 467 468 } 469